import os import sys with open(sys.argv[0]) as f: code = f.read() # read the code of this file ASAP, for logging import copy import glob import math import threading import time import uuid from dataclasses import dataclass from collections import defaultdict from itertools import accumulate from pathlib import Path os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import torch torch.empty( 1, device="cuda", requires_grad=True ).backward() # prevents a bug on some systems import torch._dynamo as dynamo import torch.distributed as dist import torch.nn.functional as F # torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min import triton import triton.language as tl from kernels import get_kernel from torch import Tensor, nn dynamo.config.recompile_limit = 64 # ----------------------------------------------------------------------------- # Custom operators: FP8 matmul by @YouJiacheng @torch.library.custom_op("nanogpt::mm", mutates_args=()) def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: @torch.compile def impl(x: Tensor, w: Tensor): assert x.is_contiguous() and w.is_contiguous() x_f8 = x.div(x_s).to(torch.float8_e4m3fn) w_f8 = w.div(w_s).to(torch.float8_e4m3fn) out = torch._scaled_mm( x_f8, w_f8.T, out_dtype=torch.bfloat16, scale_a=x.new_tensor(x_s, dtype=torch.float32), scale_b=x.new_tensor(w_s, dtype=torch.float32), use_fast_accum=True, ) return out, x_f8, w_f8 return impl(x, w) @mm_op.register_fake def _(x: Tensor, w: Tensor, *_): assert x.ndim == w.ndim == 2 assert x.shape[1] == w.shape[1] assert x.device == w.device assert x.is_contiguous() and w.is_contiguous() return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) @torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: @torch.compile def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): assert grad.is_contiguous() x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) grad_x = torch._scaled_mm( grad_f8, w_f8.T.contiguous().T, out_dtype=torch.bfloat16, scale_a=grad_inv_s, scale_b=w_inv_s, use_fast_accum=False, ) # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) grad_w = torch._scaled_mm( x_f8.T.contiguous(), grad_f8.T.contiguous().T, out_dtype=torch.float32, scale_a=x_inv_s, scale_b=grad_inv_s, use_fast_accum=False, ).T return grad_x, grad_w return impl(g, x_f8, w_f8) @mm_backward_op.register_fake def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) def backward(ctx, grad_out: Tensor, *_): x_f8, w_f8 = ctx.saved_tensors x_s, w_s, grad_s = ctx.scales grad_x, grad_w = torch.ops.nanogpt.mm_backward( grad_out, x_f8, w_f8, x_s, w_s, grad_s ) return grad_x, grad_w, None, None, None def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): *_, x_s, w_s, grad_s = inputs _, x_f8, w_f8 = output ctx.save_for_backward(x_f8, w_f8) ctx.scales = x_s, w_s, grad_s ctx.set_materialize_grads(False) mm_op.register_autograd(backward, setup_context=setup_context) # ----------------------------------------------------------------------------- # Triton kernel for symmetric matrix multiplication by @byronxu99 def _get_autotune_configs(): return [ triton.Config( { "BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": bn, "BLOCK_SIZE_K": bk, "GROUP_SIZE_M": 8, "LOWER_UPPER": 1, }, num_stages=stages, num_warps=warps, ) for bm in [64, 128] for bn in [64, 128, 256] for bk in [64, 128] for stages, warps in [(3, 4), (3, 8), (4, 4)] if bm // bn <= 2 and bn // bm <= 2 ] @triton.jit def _pid_to_block( pid, M, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr, ): # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) # Map PID to a single matrix in batch batch_idx = pid // (num_pid_m * num_pid_n) pid = pid % (num_pid_m * num_pid_n) # Map PID to 2D grid of blocks pid_m = pid // num_pid_n pid_n = pid % num_pid_n pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) m_idx = pid_m * BLOCK_SIZE_M n_idx = pid_n * BLOCK_SIZE_N return batch_idx, m_idx, n_idx @triton.autotune( configs=_get_autotune_configs(), key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], ) @triton.jit def XXT_kernel( A_ptr, C_ptr, M, K, a_stride_b, a_stride_r, a_stride_c, c_stride_b, c_stride_r, c_stride_c, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, LOWER_UPPER: tl.constexpr, ): pid = tl.program_id(axis=0) batch_idx, m_idx, n_idx = _pid_to_block( pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M ) # Skip blocks that don't need to be computed skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) if skip_block_below_diag or skip_block_above_diag: return # Index into one matrix of batch A_ptr += batch_idx * a_stride_b C_ptr += batch_idx * c_stride_b # Create pointer arrays for A and A.T offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # Accumulate over blocks of K for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) accumulator = tl.dot(a, at, accumulator) a_ptrs += BLOCK_SIZE_K * a_stride_c at_ptrs += BLOCK_SIZE_K * a_stride_c out_dtype = C_ptr.dtype.element_ty output = accumulator.to(out_dtype) # Store block of C offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) tl.store(c_ptrs, output, mask=c_mask) # Store block of C mirrored across the diagonal c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) tl.store(c_ptrs_t, output.T, mask=c_mask_t) def XXT(A: torch.Tensor, out: torch.Tensor): """ Launch Triton kernel to compute C = A @ A.T """ assert A.ndim == 2 or A.ndim == 3 M, K = A.shape[-2:] assert out.size(-2) == M, "Output matrix has incorrect shape" assert out.size(-1) == M, "Output matrix has incorrect shape" batch_size = A.size(0) if A.ndim == 3 else 1 input_batch_stride = A.stride(0) if A.ndim == 3 else 0 output_batch_stride = out.stride(0) if out.ndim == 3 else 0 grid = lambda meta: ( batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), ) XXT_kernel[grid]( A_ptr=A, C_ptr=out, M=M, K=K, a_stride_b=input_batch_stride, a_stride_r=A.stride(-2), a_stride_c=A.stride(-1), c_stride_b=output_batch_stride, c_stride_r=out.stride(-2), c_stride_c=out.stride(-1), ) return out @triton.autotune( configs=_get_autotune_configs(), key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], ) @triton.jit def ba_plus_cAA_kernel( A_ptr, C_ptr, M, a_stride_b, a_stride_r, a_stride_c, c_stride_b, c_stride_r, c_stride_c, alpha, beta, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, LOWER_UPPER: tl.constexpr, ): # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A # Performance is slightly slower than XXT_kernel, so we use two separate kernels pid = tl.program_id(axis=0) batch_idx, m_idx, n_idx = _pid_to_block( pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M ) # Skip blocks that don't need to be computed skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) if skip_block_below_diag or skip_block_above_diag: return # Index into one matrix of batch A_ptr += batch_idx * a_stride_b C_ptr += batch_idx * c_stride_b # Create pointer arrays for A and A.T offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # Accumulate over blocks of K for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) accumulator = tl.dot(a, at, accumulator) a_ptrs += BLOCK_SIZE_K * a_stride_c at_ptrs += BLOCK_SIZE_K * a_stride_c # Load block of A to add (corresponds to the current block of C) offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) # Apply alpha and beta accumulator *= alpha accumulator += a_add * beta out_dtype = C_ptr.dtype.element_ty output = accumulator.to(out_dtype) # Store block of C offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) tl.store(c_ptrs, output, mask=c_mask) # Store block of C mirrored across the diagonal c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) tl.store(c_ptrs_t, output.T, mask=c_mask_t) def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): """ Launch Triton kernel to compute C = alpha * A @ A.T + beta * A """ assert A.ndim == 2 or A.ndim == 3 M, K = A.shape[-2:] assert M == K, "Input matrix must be square" assert out.size(-2) == M assert out.size(-1) == M batch_size = A.size(0) if A.ndim == 3 else 1 input_batch_stride = A.stride(0) if A.ndim == 3 else 0 output_batch_stride = out.stride(0) if out.ndim == 3 else 0 grid = lambda meta: ( batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), ) ba_plus_cAA_kernel[grid]( A_ptr=A, C_ptr=out, M=M, a_stride_b=input_batch_stride, a_stride_r=A.stride(-2), a_stride_c=A.stride(-1), c_stride_b=output_batch_stride, c_stride_r=out.stride(-2), c_stride_c=out.stride(-1), alpha=alpha, beta=beta, ) return out # Computed for num_iters=5, safety_factor=2e-2, cushion=2 polar_express_coeffs = [ (8.156554524902461, -22.48329292557795, 15.878769915207462), (4.042929935166739, -2.808917465908714, 0.5000178451051316), (3.8916678022926607, -2.772484153217685, 0.5060648178503393), (3.285753657755655, -2.3681294933425376, 0.46449024233003106), (2.3465413258596377, -1.7097828382687081, 0.42323551169305323) ] @torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower def polar_express(G: torch.Tensor): """ Polar Express Sign Method: https://arxiv.org/pdf/2505.16932 by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower. """ X = G.bfloat16() if G.size(-2) > G.size(-1): X = X.mT # Ensure spectral norm is at most 1 X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6) # Allocate buffers X = X.contiguous() A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) B = torch.empty_like(A) C = torch.empty_like(X) aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm # Perform the iterations for a, b, c in polar_express_coeffs: XXT(X, out=A) # A = X @ X.mT ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X X, C = C, X # Swap references to avoid unnecessary copies if G.size(-2) > G.size(-1): X = X.mT return X # ----------------------------------------------------------------------------- # Muon optimizer class NorMuon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz https://kellerjordan.github.io/posts/muon/ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- processing step, in which each 2D parameter's update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has the advantage that it can be stably run in bfloat16 on the GPU. Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). Differences from standard Muon: - Newton-Shulz is replaced with Polar Express for the orthogonalization step - NorMuon adds a low-rank variance estimator similar to Adafactor. - small 1D parameters handled here instead of in Adam - Cautious weight decay, a gated version of decoupled weight decay - Custom distributed sizing: The model stores all attn and mlp weights in the same shape, and then updates the view as needed on the forward pass. This enables attn and mlp weights to be contained within the same dist.reduce_scatter_tensor() call. The model architecture has been customized to enable (n_attn_layers+n_mlp_layers*2)%8==0 for batching across 8 GPUs with zero padding on mlp and attn. The scheduling is: 1. reduce scatter smear_gate (1 param 7 padding params) 2. reduce scatter attn_gate (10 params 6 padding params) 3. reduce scatter attn/mlp round 1 (10 attn params 6 mlp params) 4. reduce scatter attn/mlp round 2 (16 mlp params) 5. wait on step 1, then compute update of 1 and schedule all gather 6. wait on step 2, then compute update of 2 and schedule all gather 7. wait on step 3, then compute update of 3 and schedule all gather GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP] GPUs that receive params of type attn reshape before computing update 8. wait on 4, then compute update of 4 and schedule all gather 9. wait for each all gather to complete and update params Empirically, leading with small params provides an additional 0.2s improvement. """ def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, beta2=0.95, custom_sizing=True): defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, beta2=beta2) self.world_size = dist.get_world_size() if dist.is_initialized() else 1 # custom sizing requires 8 GPUs if custom_sizing and dist.get_world_size()==8: param_groups = self.generate_custom_param_groups(params) else: param_groups = self.generate_standard_param_groups(params) super().__init__(param_groups, defaults) def reset(self): # expose a reset for clearing buffers for group in self.param_groups: group["momentum_buffer"].zero_() group["second_momentum_buffer"].zero_() def generate_standard_param_groups(self, params): """ Use this method if running on less than 8 GPU or experimenting with additional attn or mlp modules. Creates one param group per module. """ groups = defaultdict(list) for param in params: groups[param.label].append(param) param_groups = [] for module_name, group_params in groups.items(): chunk_size = (len(group_params) + self.world_size - 1) // self.world_size param_groups.append(dict(params=group_params, chunk_size=chunk_size)) return param_groups def generate_custom_param_groups(self, params): """ Implementation requires that a single GPU does not receive both attn and mlp params when a param group is split across GPUs. """ module_group_order = ['smear_gate', 'attn_gate', 'attn', 'mlp'] params_list = list(params) params_list.sort(key=lambda x: module_group_order.index(x.label)) idx = 0 group_sizes = [1, 10, 16, 16] assert len(params_list) == sum(group_sizes) param_groups = [] for size in group_sizes: chunk_size = (size + self.world_size - 1) // self.world_size group_params = params_list[idx: idx + size] param_groups.append(dict(params=group_params, chunk_size=chunk_size)) idx += size return param_groups @torch.no_grad() def step(self): # Efficient systems-wise implementation of step developed by @YouJiacheng, # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, # @ryanyang0, @vagrawal, and @varunneal. rank = dist.get_rank() group_infos = [] for group in self.param_groups: params: list[Tensor] = group["params"] if not params: continue chunk_size = group["chunk_size"] padded_num_params = chunk_size * self.world_size stacked_grads = torch.empty( (padded_num_params, *params[0].shape), dtype=params[0].dtype, device=params[0].device ) for i, p in enumerate(params): stacked_grads[i].copy_(p.grad, non_blocking=True) if len(params) < padded_num_params: stacked_grads[len(params):].zero_() grad_chunk = torch.empty_like(stacked_grads[:chunk_size]) reduce_future = dist.reduce_scatter_tensor( grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True ).get_future() group_infos.append(dict(grad_chunk=grad_chunk, reduce_future=reduce_future)) all_gather_infos = [] # Second pass: wait for gradients, compute updates for the local shard of parameters, # and launch all async all_gather operations. for group, info in zip(self.param_groups, group_infos): info["reduce_future"].wait() params = group["params"] grad_chunk = info["grad_chunk"] chunk_size = group["chunk_size"] padded_num_params = chunk_size * self.world_size start_idx = rank * chunk_size module_idx = start_idx if start_idx < len(params) else 0 num_params = min(chunk_size, max(0, len(params) - start_idx)) # num params for this rank if "momentum_buffer" not in group: group["momentum_buffer"] = torch.zeros_like(grad_chunk[:num_params]) momentum_buffer = group["momentum_buffer"] # Apply momentum update to the persistent momentum buffer in-place momentum_buffer.lerp_(grad_chunk[:num_params], 1 - group["momentum"]) updated_grads = grad_chunk[:num_params].lerp_(momentum_buffer, group["momentum"]) grad_shape = updated_grads.shape if params[module_idx].label == 'attn': # Reshape attn params from [hdim, dim*4] to [4,hdim,dim] for p in params[module_idx:module_idx + num_params]: assert p.label == 'attn' updated_grads = updated_grads.view(4 * grad_shape[0], grad_shape[1], grad_shape[2] // 4) ref_param = params[module_idx] param_shape = ref_param.shape if "second_momentum_buffer" not in group: group["second_momentum_buffer"] = (torch.zeros_like(updated_grads[..., :, :1]) if param_shape[-2] >= param_shape[-1] else torch.zeros_like(updated_grads[..., :1, :]) ) second_momentum_buffer = group["second_momentum_buffer"] if "param_lr" not in group: group["param_lr"] = ( max(1., param_shape[-2] / param_shape[-1]) ** 0.5 * ref_param.new_tensor( [getattr(param, "lr_mul", 1.0) for param in params[module_idx:module_idx + num_params]] ).view(-1, 1, 1) ) group["param_wd"] = ref_param.new_tensor( [getattr(param, "wd_mul", 1.0) for param in params[module_idx:module_idx + num_params]] ).view(-1, 1, 1) # Determine LR and WR eff_lr = group["lr"] * group["param_lr"] eff_wd = group["lr"] * group["weight_decay"] * group["param_wd"] # Compute zeropower for the entire chunk in a single, batched call. if num_params == 0: v_chunk = updated_grads else: v_chunk = polar_express(updated_grads) # NorMuon: second_momentum_buffer tracks squared magnitude of gradients along one dim (https://arxiv.org/pdf/2510.05491) v_norm = v_chunk.norm(dim=(-2, -1), keepdim=True) v_mean = v_chunk.square().mean(dim=-1 if param_shape[-2] >= param_shape[-1] else -2, keepdim=True) second_momentum_buffer.lerp_(v_mean.to(dtype=ref_param.dtype), 1 - group["beta2"]) step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt_() v_chunk.mul_(step_size) v_norm_new = v_chunk.norm(dim=(-2, -1), keepdim=True) v_chunk.mul_(v_norm / v_norm_new.clamp_min_(1e-10)) v_chunk = v_chunk.view(grad_shape) updated_params = torch.empty_like(grad_chunk) param_chunk = torch.stack(params[module_idx:module_idx + num_params]) if num_params > 0 else torch.zeros_like(v_chunk) # "Cautious" weight decay (https://arxiv.org/abs/2510.12402) mask = (v_chunk * param_chunk) >= 0 v_chunk.addcmul_(param_chunk, (eff_wd * mask).to(ref_param.dtype)) param_chunk.addcmul_(v_chunk, -eff_lr) updated_params[:num_params].copy_(param_chunk) if num_params < chunk_size: updated_params[num_params:].zero_() stacked_params = torch.empty( (padded_num_params, *param_shape), dtype=updated_params.dtype, device=updated_params.device, ) gather_future = dist.all_gather_into_tensor( stacked_params, updated_params, async_op=True ).get_future() all_gather_infos.append( { "gather_future": gather_future, "stacked_params": stacked_params, "orig_params": params, } ) # Final pass: wait for all_gather to complete and copy results back into original parameter tensors. for info in all_gather_infos: info["gather_future"].wait() stacked_params = info["stacked_params"] orig_params = info["orig_params"] unstacked_params = torch.unbind(stacked_params) for i, p in enumerate(orig_params): p.copy_(unstacked_params[i], non_blocking=True) class DistAdam(torch.optim.Optimizer): def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): self.world_size = dist.get_world_size() if dist.is_initialized() else 1 defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) params = list(params) sizes = {p.shape for p in params} # create one buffer per unique parameter-size param_groups = [] for size in sizes: group_params = [p for p in params if p.shape == size] param_groups.append(dict(params=group_params)) super().__init__(param_groups, defaults) # init state for p in params: chunk_size = p.size(0) // self.world_size exp_avg = torch.zeros_like(p[:chunk_size], dtype=torch.bfloat16, device=p[0].device) exp_avg_sq = torch.zeros_like(exp_avg) self.state[p] = dict(step=0, exp_avg=exp_avg, exp_avg_sq=exp_avg_sq) # DistributedAdam implementation by @vagrawal, @akash5474 self.should_sync = False self._reduce_scatter_hooks = [] self._reduce_scatter_futures = {} self.register_backward_hooks() def register_backward_hooks(self): for group in self.param_groups: params: list[Tensor] = group["params"] for param in params: hook = param.register_post_accumulate_grad_hook(self._sync_gradient) self._reduce_scatter_hooks.append(hook) @torch.compile @torch.no_grad() def _sync_gradient(self, param): if not self.should_sync: return grad = param.grad rank_size = grad.shape[0] // self.world_size grad_slice = torch.empty_like(grad[:rank_size]) self._reduce_scatter_futures[param] = ( dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future(), grad_slice ) @torch.compile @torch.no_grad() def step(self): rank = dist.get_rank() all_gather_futures: list[torch.Future] = [] for group in reversed(self.param_groups): beta1, beta2 = group['betas'] eps = group['eps'] wd = group['weight_decay'] for param in reversed(group['params']): if param not in self._reduce_scatter_futures: continue fut, g_slice = self._reduce_scatter_futures[param] fut.wait() rank_size = param.shape[0] // self.world_size p_slice = param[rank * rank_size:(rank + 1) * rank_size] lr = group['lr'] * getattr(param, "lr_mul", 1.0) state = self.state[param] exp_avg = state["exp_avg"] exp_avg_sq = state["exp_avg_sq"] state["step"] += 1 t = state["step"] # weight decay if wd != 0: eff_weight_decay = lr * wd * getattr(param, "wd_mul", 1.0) p_slice.mul_(1 - eff_weight_decay) # update running averages exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) # bias corrections bias1 = 1 - beta1 ** t bias2 = 1 - beta2 ** t # compute step denom = exp_avg_sq.sqrt().add_(eps) step_size = lr * (bias2 ** 0.5 / bias1) update = exp_avg.div(denom).mul_(step_size) p_slice.add_(other=update, alpha=-1.0) all_gather_futures.append(dist.all_gather_into_tensor(param, p_slice, async_op=True).get_future()) self._reduce_scatter_futures.clear() torch.futures.collect_all(all_gather_futures).wait() # ----------------------------------------------------------------------------- # PyTorch nn.Module definitions for the model def norm(x: Tensor): return F.rms_norm(x, (x.size(-1),)) class CastedLinear(nn.Linear): def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): super().__init__(in_features, out_features, bias=False) self.use_fp8 = use_fp8 self.x_s = x_s self.w_s = w_s self.grad_s = grad_s def reset_parameters(self) -> None: with torch.no_grad(): self.weight.zero_() # @Grad62304977 and others def forward(self, x: Tensor): if self.use_fp8 and self.training: _x = x.flatten(0, -2) out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] return out.reshape(*x.shape[:-1], -1) else: return F.linear(x, self.weight.type_as(x)) # yarn implementation @classiclarryd class Yarn(nn.Module): def __init__(self, head_dim, max_seq_len): super().__init__() self.head_dim = head_dim self.max_seq_len = max_seq_len self.reset() def reset(self): angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device) # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)]) t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device) theta = torch.outer(t, angular_freq) self.cos = nn.Buffer( theta.cos().to(torch.bfloat16), persistent=False ) self.sin = nn.Buffer( theta.sin().to(torch.bfloat16), persistent=False ) self.angular_freq = angular_freq # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 self.attn_scale = 0.1 def apply(self, old_window: int, new_window: int, alpha: int=1, beta: int=32): rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi) scaling_factor = old_window / new_window interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1) self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor) t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device) theta = torch.outer(t, self.angular_freq) self.cos.copy_(theta.cos()) self.sin.copy_(theta.sin()) self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1 def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor): assert cos.size(0) >= x_BTHD.size(-3) cos, sin = ( cos[None, : x_BTHD.size(-3), None, :], sin[None, : x_BTHD.size(-3), None, :], ) x1, x2 = x_BTHD.chunk(2, dim=-1) y1 = x1 * cos + x2 * sin y2 = x1 * (-sin) + x2 * cos return torch.cat((y1, y2), 3) @dataclass class AttnArgs: ve: torch.Tensor sa_lambdas: torch.Tensor seqlens: torch.Tensor bm_size: int cos: torch.Tensor sin: torch.Tensor attn_scale: float flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface class CausalSelfAttention(nn.Module): def __init__(self, dim: int, head_dim: int, num_heads: int): super().__init__() self.num_heads = num_heads self.head_dim = head_dim self.dim = dim self.hdim = num_heads * head_dim assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim" std = 0.5 * (self.dim ** -0.5) bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng # https://x.com/hi_tysam/status/1879699187107033311 # make matrices the same shape as MLP to enable batched call in optimizer self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4)) # label module to enable custom optimizer sizing self.qkvo_w.label='attn' with torch.no_grad(): self.qkvo_w.view(4,self.hdim, self.dim)[:3].uniform_(-bound, bound) # init QKV weights self.qkvo_w.view(4,self.hdim, self.dim)[3].zero_() # init output weights to zero # sparse gated attention to enable context based no-op by @classiclarryd self.attn_gate = CastedLinear(12, num_heads) # label module to enable custom optimizer sizing self.attn_gate.weight.label = 'attn_gate' def forward(self, x: Tensor, attn_args: AttnArgs): B, T = x.size(0), x.size(1) # batch size, sequence length assert B == 1, "varlen sequences requires B == 1" assert T % 16 == 0 # unpack attention args cos, sin = attn_args.cos, attn_args.sin ve, sa_lambdas = attn_args.ve, attn_args.sa_lambdas seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size q, k, v = F.linear(x, sa_lambdas[0] * self.qkvo_w.view(4, self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) q, k = norm(q), norm(k) # QK norm @Grad62304977 q, k = rotary(q, cos, sin), rotary(k, cos, sin) if ve is not None: v = v + sa_lambdas[1] * ve.view_as(v) # @ KoszarskyB & @Grad62304977 max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0)) y = y.view(B, T, self.num_heads, self.head_dim) y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1) y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side y = F.linear(y, self.qkvo_w.view(4, self.hdim, self.dim)[3].type_as(y)) return y class MLP(nn.Module): def __init__(self, dim: int): super().__init__() hdim = 4 * dim # make matrices the same shape to enable batched call in optimizer self.c_fc = nn.Parameter(torch.empty(dim, hdim)) self.c_proj = nn.Parameter(torch.empty(dim, hdim)) # label modules to enable custom optimizer sizing self.c_fc.label = 'mlp' self.c_proj.label = 'mlp' # corrective factor to account for transpose self.c_fc.lr_mul = 2. std = 0.5 * (dim ** -0.5) bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng with torch.no_grad(): self.c_fc.uniform_(-bound, bound) self.c_proj.zero_() # zero init suggested by @Grad62304977 def forward(self, x: Tensor): x = F.linear(x, self.c_fc.T.type_as(x)) x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 x = F.linear(x, self.c_proj.type_as(x)) return x class Block(nn.Module): def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int): super().__init__() # skip attention of blocks.7 (the 8th layer) by @YouJiacheng self.attn = CausalSelfAttention(dim, head_dim, num_heads) if layer_idx not in [0, 7] else None # skip MLP blocks for first MLP layer by @EmelyanenkoK self.mlp = MLP(dim) if layer_idx != 0 else None def forward(self, x: Tensor, x0: Tensor, lambdas: Tensor, attn_args: AttnArgs): x = lambdas[0] * x + lambdas[1] * x0 if self.attn is not None: x = x + self.attn(norm(x), attn_args) if self.mlp is not None: x = x + self.mlp(norm(x)) return x # ----------------------------------------------------------------------------- # The main model def next_multiple_of_n(v: float | int, *, n: int): return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) class GPT(nn.Module): def __init__(self, vocab_size: int, num_layers: int, num_heads: int, head_dim: int, model_dim: int, max_seq_len: int): super().__init__() vocab_size = next_multiple_of_n(vocab_size, n=128) self.embed = nn.Embedding(vocab_size, model_dim) self.smear_gate = CastedLinear(12, 1) # label modules to enable custom optimizer sizing self.smear_gate.weight.label = 'smear_gate' # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) self.blocks = nn.ModuleList([Block(model_dim, head_dim, num_heads, i) for i in range(num_layers)]) self.yarn = Yarn(head_dim, max_seq_len) # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. # suggested to me by @Grad62304977. this originates from Karpathy's experiments. use_fp8 = not os.environ.get("DISABLE_FP8", False) self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) # Add learnable skip connection weights for decoder layers assert num_layers % 2 == 0 pad = (-num_layers * 5 - 2) % dist.get_world_size() self.scalars = nn.Parameter( torch.cat( [ -1.5 * torch.ones(num_layers), # skip_weights -> σ(-1.5) ≈ 0.18 *[ torch.tensor([1.1, 0.0]) for _ in range(num_layers) ], # block lambdas. 1.1 init such that layer i weight is i^(num_layers-i). # ~3x higher weight to layer 1 compared to 12 at init. *[ torch.tensor([0.5, 0.5]) for _ in range(num_layers) ], # SA lambdas torch.zeros(1), # smear_lambda 0.5*torch.ones(1), # backout_lambda torch.ones(pad), ] ) ) # set learning rates for param in self.embed.parameters(): param.lr_mul = 75. for param in self.value_embeds.parameters(): param.lr_mul = 75. self.lm_head.weight.lr_mul = 1.0 self.scalars.lr_mul = 5.0 def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws_short: int, ws_long: int): assert input_seq.ndim == 1 ve = [value_embed(input_seq) for value_embed in self.value_embeds] # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure # dropping first layer updates this to .12 ... 012 ve = [None, ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] assert len(ve) == len(self.blocks) short_bm = ws_short * args.block_size long_bm = ws_long * args.block_size bm_sizes = [None, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, None, short_bm, short_bm, short_bm, long_bm] assert len(bm_sizes) == len(self.blocks) x = self.embed(input_seq) skip_weights = self.scalars[:(len(self.blocks) // 2)] lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) smear_lambda = self.scalars[5 * len(self.blocks)] backout_lambda = self.scalars[5 * len(self.blocks)+1] # smear token embed forward 1 position @classiclarryd smear_gate_out = smear_lambda * torch.sigmoid(self.smear_gate(x[1:, :self.smear_gate.weight.size(-1)])) x = torch.cat([x[:1], x[1:] + smear_gate_out * x[:-1]]) x = x0 = norm(x[None]) # create skip connection from layer 4 (long attn window) to layer 7 (no attn op) skip_connections = [] n = len(self.blocks) // 2 skip_in = [4] skip_out = [7] x_backout = None backout_layer = 8 # skip layer zero for i in range(1,len(self.blocks)): attn_args = AttnArgs( ve=ve[i], sa_lambdas=sa_lambdas[i], seqlens=seqlens, bm_size=bm_sizes[i], cos=self.yarn.cos, sin=self.yarn.sin, attn_scale=self.yarn.attn_scale ) # since layer 0 is skipped, layer 11 does not have skip_connection if i in skip_out: gate = torch.sigmoid(skip_weights[i - n]) # in (0, 1) x = x + gate * skip_connections.pop() x = self.blocks[i](x, x0, lambdas[i], attn_args) if i in skip_in: skip_connections.append(x) if i == backout_layer: x_backout = x # back out contributions from first 8 layers that are only required for downstream context and not direct prediction x -= backout_lambda * x_backout x = norm(x) logits = self.lm_head(x) # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) logits = 30 * torch.sigmoid(logits / 7.5) logits_for_loss = logits.float() if not self.training else logits loss = F.cross_entropy( logits_for_loss.view(-1, logits_for_loss.size(-1)), target_seq, reduction="sum" if self.training else "mean", ) return loss # ----------------------------------------------------------------------------- # Distributed data loader def _load_data_shard(file: Path): header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 assert header[0] == 20240520, "magic number mismatch in the data .bin file" assert header[1] == 1, "unsupported version" num_tokens = int(header[2]) # number of tokens (claimed) with file.open("rb", buffering=0) as f: tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng f.seek(256 * 4) nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng assert nbytes == 2 * num_tokens, "number of tokens read does not match header" return tokens BOS_ID = 50256 class BOSFinder: # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False): # Precompute BOS positions once per shard self.tokens=tokens self.size = tokens.numel() self.quickload = quickload if quickload: # only scan first 4 million tokens, then kickoff async thread to scan rest self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() self.thread = None self.ready = threading.Event() self.start() else: self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() self.i = 0 self.world_size = world_size self.batch_iter = 0 def _load(self): self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() self.ready.set() def start(self): self.ready.clear() self.thread = threading.Thread(target=self._load) self.thread.start() def get(self): if self.thread: self.ready.wait() self.thread.join() self.bos_idx = self.bos_idx_async def next_batch(self, num_tokens_local: int, max_seq_len: int): # if quickload was used, repoint to the full dataset after 5 batches if self.quickload and self.batch_iter==5: self.get() n = len(self.bos_idx) starts = [[] for _ in range(self.world_size)] ends = [[] for _ in range(self.world_size)] idx = self.i for r in range(self.world_size): cur_len = 0 while cur_len <= num_tokens_local: if idx >= n: raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") cur = self.bos_idx[idx] starts[r].append(cur) end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, cur + max_seq_len, cur + num_tokens_local - cur_len + 1) ends[r].append(end) cur_len += end - cur idx += 1 assert cur_len == num_tokens_local + 1 self.i = idx self.batch_iter+=1 return starts, ends class DataPreloader: # Helper for asynchronously loading next shard and indexing bos tokens def __init__(self, file_iter, world_size: int = 1): self.file_iter = file_iter self.world_size = world_size self.thread = None self.data = None self.ready = threading.Event() def _load(self): tokens = _load_data_shard(next(self.file_iter)) self.data = (tokens, BOSFinder(tokens, self.world_size)) self.ready.set() def start(self): self.ready.clear() self.thread = threading.Thread(target=self._load) self.thread.start() def get(self): if self.thread: self.ready.wait() self.thread.join() return self.data def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" num_tokens = num_tokens // grad_accum_steps files = [Path(file) for file in sorted(glob.glob(filename_pattern))] if not files: raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training tokens = _load_data_shard(next(file_iter)) if align_to_bos: finder = BOSFinder(tokens, world_size=world_size, quickload=True) preloader = DataPreloader(file_iter, world_size) preloader.start() else: pos = 0 # for unaligned case while True: num_tokens_local = num_tokens // world_size max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 if align_to_bos: try: seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) except StopIteration: # This shard is exhausted, load the next one in the next loop iteration. tokens, finder = preloader.get() preloader.start() continue buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) _inputs = buf[:-1] _targets = buf[1:] end_idxs[-1] -= 1 # last document was too long to account for _targets offset cum_lengths = (end_idxs - start_idxs).cumsum(0) else: if pos + num_tokens + 1 >= len(tokens): # should not occur for val data tokens, pos = _load_data_shard(next(file_iter)), 0 pos_local = pos + rank * num_tokens_local buf = tokens[pos_local: pos_local + num_tokens_local + 1] _inputs = buf[:-1].view(num_tokens_local, ) _targets = buf[1:].view(num_tokens_local, ) cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] pos += num_tokens _cum_lengths = torch.full((max_num_docs,), num_tokens_local) _cum_lengths[0] = 0 _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths new_params = yield ( _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) ) if new_params is not None: # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" num_tokens = new_num_tokens max_seq_len = new_max_seq_len grad_accum_steps = new_grad_accum_steps # ----------------------------------------------------------------------------- # int main @dataclass class Hyperparameters: # data train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons train_batch_size: int = 2048 * 16 * 8 train_max_seq_len: int = 128 * 16 val_batch_size: int = 4 * 64 * 1024 * 8 # optimization num_scheduled_iterations: int = 2185 # number of steps to complete lr and ws schedule num_extension_iterations: int = 40 # number of steps to continue training at final lr and ws num_iterations: int = num_scheduled_iterations + num_extension_iterations cooldown_frac: float = 0.50 # fraction of num_scheduled_iterations spent cooling down the learning rate # evaluation and logging run_id: str = f"{uuid.uuid4()}" val_loss_every: int = 250 # every how many steps to evaluate val loss? 0 for only at the end save_checkpoint: bool = False # attention masking block_size: int = 128 ws_schedule: tuple = (3, 7, 11) ws_final: int = 13 # increase final validation ws, used for YaRN extension and short window size @classiclarryd ws_validate_post_yarn_ext: int = 20 # extend long windows out even further after applying YaRN args = Hyperparameters() data_path = os.environ.get("DATA_PATH", ".") args.train_files = os.path.join(data_path, args.train_files) args.val_files = os.path.join(data_path, args.val_files) # torchrun sets these env variables rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) assert 8 % world_size == 0, "world_size must be a divisor of 8" grad_accum_steps = 8 // world_size assert torch.cuda.is_available() device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) torch.cuda.set_device(device) dist.init_process_group(backend="nccl", device_id=device) dist.barrier() master_process = (rank == 0) # this process will do logging, checkpointing etc. # begin logging logfile = None if master_process: run_id = args.run_id os.makedirs("logs", exist_ok=True) logfile = f"logs/{run_id}.txt" print(logfile) def print0(s, console=False): if master_process: with open(logfile, "a") as f: if console: print(s) print(s, file=f) # begin by printing this file (the Python code) print0(code) print0("="*100) # log information about the hardware/software environment this is running on print0(f"Running Python {sys.version}") print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") print0(f"Running Triton version {triton.__version__}") def nvidia_smi(): import subprocess # avoid top level import return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout print0(nvidia_smi()) print0("="*100) model: nn.Module = GPT( vocab_size=50257, num_layers=12, num_heads=6, head_dim=128, model_dim=768, max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) ).cuda() for m in model.modules(): if isinstance(m, (nn.Embedding, nn.Linear)): m.bfloat16() for param in model.parameters(): dist.broadcast(param.detach(), 0) # collect the parameters to optimize hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n and "gate" not in n] embed_params = [p for n, p in model.named_parameters() if "embed" in n] scalar_params = [p for p in model.parameters() if p.ndim < 2] head_params = [model.lm_head.weight] gate_params = [p for n, p in model.named_parameters() if "gate" in n] # init the optimizer(s) # small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence # discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 optimizer1 = DistAdam( scalar_params + head_params + embed_params, lr=0.008, betas=(0.65, 0.95), eps=1e-8, weight_decay=0.0, ) optimizer2 = NorMuon(hidden_matrix_params + gate_params, lr=0.03, momentum=0.95, beta2=0.95, weight_decay=1.2) optimizers = [optimizer1, optimizer2] for opt in optimizers: for group in opt.param_groups: group["initial_lr"] = group["lr"] # learning rate schedule: flat, then linear decay, then flat def get_lr(step: int): x = min(0.9999, step / args.num_scheduled_iterations) assert 0 <= x < 1 lr = 1.0 if x >= 1 - args.cooldown_frac: w = (1 - x) / args.cooldown_frac lr = w * 1.0 + (1 - w) * 0.1 return lr def get_ws(step: int): # set short window size to half of long window size # Higher ws on "extension" steps if step >= args.num_scheduled_iterations: return args.ws_final // 2, args.ws_final x = step / args.num_scheduled_iterations assert 0 <= x < 1 ws_idx = int(len(args.ws_schedule) * x) return args.ws_schedule[ws_idx] // 2, args.ws_schedule[ws_idx] def get_muon_momentum(step: int, muon_warmup_steps=300, muon_cooldown_steps=50, momentum_min=0.85, momentum_max=0.95): # warmup phase: linearly increase momentum from min to max # cooldown phase: linearly decrease momentum from max to min momentum_cd_start = args.num_iterations - muon_cooldown_steps if step < muon_warmup_steps: frac = step / muon_warmup_steps momentum = momentum_min + frac * (momentum_max - momentum_min) elif step > momentum_cd_start: frac = (step - momentum_cd_start) / muon_cooldown_steps momentum = momentum_max - frac * (momentum_max - momentum_min) else: momentum = momentum_max return momentum def step_optimizers(step: int, optimizers, model): # update lr for optimizer in optimizers: for group in optimizer.param_groups: group["lr"] = group["initial_lr"] * get_lr(step) # set muon momentum based on step momentum = get_muon_momentum(step) for group in optimizers[1].param_groups: group["momentum"] = momentum # on even steps, only step Muon params # on odd steps, step all params if step%2==0: optimizers[1].step() optimizers[1].zero_grad(set_to_none=True) else: for optimizer in optimizers: optimizer.step() model.zero_grad(set_to_none=True) # disable sync in the next training step for the adam optimizer optimizers[0].should_sync = False model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) ######################################## # Warmup kernels # ######################################## # Warmup the training kernels, then re-initialize the state so we aren't cheating warmup_steps = 30 initial_state = dict(model=copy.deepcopy(model.state_dict()), optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) ws_schedule = list(args.ws_schedule) + [args.ws_final] ws_long = ws_schedule[0] for step in range(warmup_steps): inputs, targets, cum_seqlens = next(train_loader) # each window size is a new graph, need to warm up each with Yarn.attn_scale ws_idx = step % len(ws_schedule) if ws_idx==0: model.yarn.reset() ws_long = ws_schedule[0] else: new_ws_long = ws_schedule[ws_idx] model.yarn.apply(ws_long, new_ws_long) ws_long = new_ws_long model(inputs, targets, cum_seqlens, ws_long//2, ws_long).backward() for opt in optimizers: opt.step() model.zero_grad(set_to_none=True) model.yarn.reset() # rotary buffer is not stored in state_dict model.load_state_dict(initial_state["model"]) optimizer2.reset() # muon momentum buffers not in state dict for opt, opt_state in zip(optimizers, initial_state["optimizers"]): opt.load_state_dict(opt_state) del train_loader, initial_state ######################################## # Training and validation # ######################################## train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) training_time_ms = 0 # start the clock torch.cuda.synchronize() t0 = time.perf_counter() # begin training train_steps = args.num_iterations ws_short, ws_long = get_ws(0) for step in range(train_steps + 1): last_step = (step == train_steps) ws_short, new_ws_long = get_ws(step) if new_ws_long != ws_long: model.yarn.apply(ws_long, new_ws_long) ws_long=new_ws_long # --------------- VALIDATION SECTION ----------------- if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): if last_step: ws_long = args.ws_validate_post_yarn_ext # stop the clock torch.cuda.synchronize() training_time_ms += 1000 * (time.perf_counter() - t0) model.eval() assert args.val_tokens % args.val_batch_size == 0 val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) val_loss = 0 with torch.no_grad(): for _ in range(val_steps): inputs, targets, cum_seqlens = next(val_loader) val_loss += model(inputs, targets, cum_seqlens, ws_short, ws_long) val_loss /= val_steps del val_loader dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) model.train() # start the clock again torch.cuda.synchronize() t0 = time.perf_counter() if last_step: if master_process and args.save_checkpoint: log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) os.makedirs(f"logs/{run_id}", exist_ok=True) torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") # the last step only has the validation loop, so break to avoid training break # --------------- TRAINING SECTION ----------------- for idx in range(grad_accum_steps): # enable gradient sync for the DistAdam optimizer on the last iteration before we step it if idx == grad_accum_steps - 1 and step % 2 == 1: optimizers[0].should_sync = True inputs, targets, cum_seqlens = next(train_loader) (model(inputs, targets, cum_seqlens, ws_short, ws_long) / grad_accum_steps).backward() step_optimizers(step, optimizers, model) # logging approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) dist.destroy_process_group() ==================================================================================================== Running Python 3.10.12 (main, Aug 15 2025, 14:32:43) [GCC 11.4.0] Running PyTorch 2.10.0.dev20251120+cu126 compiled for CUDA 12.6 Running Triton version 3.5.1 Thu Nov 20 12:19:06 2025 +-----------------------------------------------------------------------------------------+ | NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | |-----------------------------------------+------------------------+----------------------+ | GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | | | | MIG M. | |=========================================+========================+======================| | 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | | N/A 32C P0 116W / 700W | 5858MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | | N/A 28C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | | N/A 26C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | | N/A 24C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | | N/A 26C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | | N/A 35C P0 128W / 700W | 1520MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | | N/A 25C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ | 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | | N/A 23C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ +-----------------------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | | ID ID Usage | |=========================================================================================| | 0 N/A N/A 19327 C /usr/bin/python3 1510MiB | | 0 N/A N/A 19328 C /usr/bin/python3 614MiB | | 0 N/A N/A 19329 C /usr/bin/python3 614MiB | | 0 N/A N/A 19330 C /usr/bin/python3 614MiB | | 0 N/A N/A 19331 C /usr/bin/python3 614MiB | | 0 N/A N/A 19332 C /usr/bin/python3 614MiB | | 0 N/A N/A 19333 C /usr/bin/python3 614MiB | | 0 N/A N/A 19334 C /usr/bin/python3 614MiB | | 1 N/A N/A 19328 C /usr/bin/python3 1510MiB | | 2 N/A N/A 19329 C /usr/bin/python3 1510MiB | | 3 N/A N/A 19330 C /usr/bin/python3 1510MiB | | 4 N/A N/A 19331 C /usr/bin/python3 1510MiB | | 5 N/A N/A 19332 C /usr/bin/python3 1510MiB | | 6 N/A N/A 19333 C /usr/bin/python3 1510MiB | | 7 N/A N/A 19334 C /usr/bin/python3 1510MiB | +-----------------------------------------------------------------------------------------+ ==================================================================================================== step:0/2225 val_loss:10.8258 train_time:0ms step_avg:0.02ms step:1/2225 train_time:131ms step_avg:131.05ms step:2/2225 train_time:174ms step_avg:87.02ms step:3/2225 train_time:194ms step_avg:64.70ms step:4/2225 train_time:249ms step_avg:62.29ms step:5/2225 train_time:308ms step_avg:61.63ms step:6/2225 train_time:366ms step_avg:61.00ms step:7/2225 train_time:426ms step_avg:60.91ms step:8/2225 train_time:485ms step_avg:60.60ms step:9/2225 train_time:545ms step_avg:60.55ms step:10/2225 train_time:604ms step_avg:60.37ms step:11/2225 train_time:664ms step_avg:60.36ms step:12/2225 train_time:722ms step_avg:60.19ms step:13/2225 train_time:782ms step_avg:60.18ms step:14/2225 train_time:841ms step_avg:60.05ms step:15/2225 train_time:901ms step_avg:60.07ms step:16/2225 train_time:959ms step_avg:59.97ms step:17/2225 train_time:1020ms step_avg:59.99ms step:18/2225 train_time:1080ms step_avg:60.00ms step:19/2225 train_time:1143ms step_avg:60.17ms step:20/2225 train_time:1205ms step_avg:60.26ms step:21/2225 train_time:1267ms step_avg:60.31ms step:22/2225 train_time:1325ms step_avg:60.24ms step:23/2225 train_time:1386ms step_avg:60.27ms step:24/2225 train_time:1445ms step_avg:60.21ms step:25/2225 train_time:1506ms step_avg:60.25ms step:26/2225 train_time:1564ms step_avg:60.17ms step:27/2225 train_time:1625ms step_avg:60.18ms step:28/2225 train_time:1683ms step_avg:60.12ms step:29/2225 train_time:1744ms step_avg:60.12ms step:30/2225 train_time:1802ms step_avg:60.07ms step:31/2225 train_time:1863ms step_avg:60.08ms step:32/2225 train_time:1921ms step_avg:60.03ms step:33/2225 train_time:1981ms step_avg:60.03ms step:34/2225 train_time:2040ms step_avg:59.99ms step:35/2225 train_time:2101ms step_avg:60.03ms step:36/2225 train_time:2161ms step_avg:60.04ms step:37/2225 train_time:2223ms step_avg:60.08ms step:38/2225 train_time:2283ms step_avg:60.07ms step:39/2225 train_time:2344ms step_avg:60.11ms step:40/2225 train_time:2404ms step_avg:60.09ms step:41/2225 train_time:2465ms step_avg:60.11ms step:42/2225 train_time:2523ms step_avg:60.07ms step:43/2225 train_time:2584ms step_avg:60.08ms step:44/2225 train_time:2642ms step_avg:60.05ms step:45/2225 train_time:2702ms step_avg:60.05ms step:46/2225 train_time:2761ms step_avg:60.02ms step:47/2225 train_time:2822ms step_avg:60.04ms step:48/2225 train_time:2880ms step_avg:60.01ms step:49/2225 train_time:2941ms step_avg:60.01ms step:50/2225 train_time:2999ms step_avg:59.98ms step:51/2225 train_time:3060ms step_avg:60.00ms step:52/2225 train_time:3119ms step_avg:59.98ms step:53/2225 train_time:3180ms step_avg:60.00ms step:54/2225 train_time:3239ms step_avg:59.99ms step:55/2225 train_time:3300ms step_avg:60.01ms step:56/2225 train_time:3360ms step_avg:60.01ms step:57/2225 train_time:3421ms step_avg:60.02ms step:58/2225 train_time:3480ms step_avg:60.00ms step:59/2225 train_time:3541ms step_avg:60.02ms step:60/2225 train_time:3600ms step_avg:60.00ms step:61/2225 train_time:3661ms step_avg:60.02ms step:62/2225 train_time:3719ms step_avg:59.99ms step:63/2225 train_time:3780ms step_avg:60.00ms step:64/2225 train_time:3839ms step_avg:59.98ms step:65/2225 train_time:3899ms step_avg:59.99ms step:66/2225 train_time:3958ms step_avg:59.97ms step:67/2225 train_time:4019ms step_avg:59.99ms step:68/2225 train_time:4078ms step_avg:59.97ms step:69/2225 train_time:4139ms step_avg:59.98ms step:70/2225 train_time:4198ms step_avg:59.97ms step:71/2225 train_time:4260ms step_avg:60.00ms step:72/2225 train_time:4319ms step_avg:59.99ms step:73/2225 train_time:4380ms step_avg:60.00ms step:74/2225 train_time:4439ms step_avg:59.99ms step:75/2225 train_time:4500ms step_avg:60.00ms step:76/2225 train_time:4560ms step_avg:60.00ms step:77/2225 train_time:4621ms step_avg:60.01ms step:78/2225 train_time:4679ms step_avg:59.99ms step:79/2225 train_time:4740ms step_avg:60.00ms step:80/2225 train_time:4799ms step_avg:59.99ms step:81/2225 train_time:4860ms step_avg:60.00ms step:82/2225 train_time:4918ms step_avg:59.98ms step:83/2225 train_time:4979ms step_avg:59.99ms step:84/2225 train_time:5038ms step_avg:59.97ms step:85/2225 train_time:5099ms step_avg:59.98ms step:86/2225 train_time:5158ms step_avg:59.98ms step:87/2225 train_time:5220ms step_avg:60.00ms step:88/2225 train_time:5279ms step_avg:59.99ms step:89/2225 train_time:5340ms step_avg:60.00ms step:90/2225 train_time:5400ms step_avg:60.00ms step:91/2225 train_time:5460ms step_avg:60.00ms step:92/2225 train_time:5519ms step_avg:59.99ms step:93/2225 train_time:5580ms step_avg:60.00ms step:94/2225 train_time:5639ms step_avg:59.99ms step:95/2225 train_time:5700ms step_avg:60.00ms step:96/2225 train_time:5759ms step_avg:59.99ms step:97/2225 train_time:5819ms step_avg:59.99ms step:98/2225 train_time:5878ms step_avg:59.98ms step:99/2225 train_time:5938ms step_avg:59.98ms step:100/2225 train_time:5997ms step_avg:59.97ms step:101/2225 train_time:6058ms step_avg:59.98ms step:102/2225 train_time:6117ms step_avg:59.97ms step:103/2225 train_time:6177ms step_avg:59.97ms step:104/2225 train_time:6236ms step_avg:59.96ms step:105/2225 train_time:6297ms step_avg:59.97ms step:106/2225 train_time:6357ms step_avg:59.97ms step:107/2225 train_time:6418ms step_avg:59.98ms step:108/2225 train_time:6477ms step_avg:59.98ms step:109/2225 train_time:6538ms step_avg:59.98ms step:110/2225 train_time:6597ms step_avg:59.97ms step:111/2225 train_time:6658ms step_avg:59.98ms step:112/2225 train_time:6717ms step_avg:59.97ms step:113/2225 train_time:6778ms step_avg:59.98ms step:114/2225 train_time:6837ms step_avg:59.97ms step:115/2225 train_time:6897ms step_avg:59.98ms step:116/2225 train_time:6957ms step_avg:59.97ms step:117/2225 train_time:7017ms step_avg:59.97ms step:118/2225 train_time:7076ms step_avg:59.97ms step:119/2225 train_time:7137ms step_avg:59.97ms step:120/2225 train_time:7196ms step_avg:59.97ms step:121/2225 train_time:7257ms step_avg:59.98ms step:122/2225 train_time:7317ms step_avg:59.98ms step:123/2225 train_time:7378ms step_avg:59.98ms step:124/2225 train_time:7437ms step_avg:59.98ms step:125/2225 train_time:7498ms step_avg:59.99ms step:126/2225 train_time:7558ms step_avg:59.99ms step:127/2225 train_time:7619ms step_avg:59.99ms step:128/2225 train_time:7678ms step_avg:59.99ms step:129/2225 train_time:7739ms step_avg:59.99ms step:130/2225 train_time:7798ms step_avg:59.99ms step:131/2225 train_time:7859ms step_avg:59.99ms step:132/2225 train_time:7917ms step_avg:59.98ms step:133/2225 train_time:7978ms step_avg:59.99ms step:134/2225 train_time:8037ms step_avg:59.98ms step:135/2225 train_time:8098ms step_avg:59.98ms step:136/2225 train_time:8157ms step_avg:59.98ms step:137/2225 train_time:8218ms step_avg:59.99ms step:138/2225 train_time:8277ms step_avg:59.98ms step:139/2225 train_time:8338ms step_avg:59.98ms step:140/2225 train_time:8397ms step_avg:59.98ms step:141/2225 train_time:8458ms step_avg:59.99ms step:142/2225 train_time:8517ms step_avg:59.98ms step:143/2225 train_time:8578ms step_avg:59.99ms step:144/2225 train_time:8637ms step_avg:59.98ms step:145/2225 train_time:8698ms step_avg:59.99ms step:146/2225 train_time:8757ms step_avg:59.98ms step:147/2225 train_time:8818ms step_avg:59.99ms step:148/2225 train_time:8877ms step_avg:59.98ms step:149/2225 train_time:8937ms step_avg:59.98ms step:150/2225 train_time:8996ms step_avg:59.97ms step:151/2225 train_time:9056ms step_avg:59.97ms step:152/2225 train_time:9115ms step_avg:59.97ms step:153/2225 train_time:9176ms step_avg:59.98ms step:154/2225 train_time:9235ms step_avg:59.97ms step:155/2225 train_time:9296ms step_avg:59.97ms step:156/2225 train_time:9355ms step_avg:59.97ms step:157/2225 train_time:9416ms step_avg:59.98ms step:158/2225 train_time:9476ms step_avg:59.97ms step:159/2225 train_time:9537ms step_avg:59.98ms step:160/2225 train_time:9596ms step_avg:59.98ms step:161/2225 train_time:9657ms step_avg:59.98ms step:162/2225 train_time:9716ms step_avg:59.98ms step:163/2225 train_time:9777ms step_avg:59.98ms step:164/2225 train_time:9836ms step_avg:59.97ms step:165/2225 train_time:9897ms step_avg:59.98ms step:166/2225 train_time:9956ms step_avg:59.97ms step:167/2225 train_time:10016ms step_avg:59.98ms step:168/2225 train_time:10075ms step_avg:59.97ms step:169/2225 train_time:10136ms step_avg:59.97ms step:170/2225 train_time:10194ms step_avg:59.97ms step:171/2225 train_time:10255ms step_avg:59.97ms step:172/2225 train_time:10315ms step_avg:59.97ms step:173/2225 train_time:10375ms step_avg:59.97ms step:174/2225 train_time:10434ms step_avg:59.97ms step:175/2225 train_time:10495ms step_avg:59.97ms step:176/2225 train_time:10555ms step_avg:59.97ms step:177/2225 train_time:10616ms step_avg:59.98ms step:178/2225 train_time:10675ms step_avg:59.97ms step:179/2225 train_time:10735ms step_avg:59.97ms step:180/2225 train_time:10794ms step_avg:59.97ms step:181/2225 train_time:10855ms step_avg:59.97ms step:182/2225 train_time:10914ms step_avg:59.97ms step:183/2225 train_time:10975ms step_avg:59.97ms step:184/2225 train_time:11034ms step_avg:59.97ms step:185/2225 train_time:11094ms step_avg:59.97ms step:186/2225 train_time:11154ms step_avg:59.97ms step:187/2225 train_time:11214ms step_avg:59.97ms step:188/2225 train_time:11273ms step_avg:59.96ms step:189/2225 train_time:11334ms step_avg:59.97ms step:190/2225 train_time:11393ms step_avg:59.97ms step:191/2225 train_time:11455ms step_avg:59.97ms step:192/2225 train_time:11514ms step_avg:59.97ms step:193/2225 train_time:11575ms step_avg:59.97ms step:194/2225 train_time:11634ms step_avg:59.97ms step:195/2225 train_time:11694ms step_avg:59.97ms step:196/2225 train_time:11753ms step_avg:59.97ms step:197/2225 train_time:11814ms step_avg:59.97ms step:198/2225 train_time:11873ms step_avg:59.96ms step:199/2225 train_time:11934ms step_avg:59.97ms step:200/2225 train_time:11993ms step_avg:59.97ms step:201/2225 train_time:12054ms step_avg:59.97ms step:202/2225 train_time:12113ms step_avg:59.96ms step:203/2225 train_time:12173ms step_avg:59.97ms step:204/2225 train_time:12232ms step_avg:59.96ms step:205/2225 train_time:12293ms step_avg:59.97ms step:206/2225 train_time:12352ms step_avg:59.96ms step:207/2225 train_time:12413ms step_avg:59.97ms step:208/2225 train_time:12472ms step_avg:59.96ms step:209/2225 train_time:12534ms step_avg:59.97ms step:210/2225 train_time:12593ms step_avg:59.97ms step:211/2225 train_time:12654ms step_avg:59.97ms step:212/2225 train_time:12713ms step_avg:59.97ms step:213/2225 train_time:12773ms step_avg:59.97ms step:214/2225 train_time:12833ms step_avg:59.97ms step:215/2225 train_time:12893ms step_avg:59.97ms step:216/2225 train_time:12952ms step_avg:59.96ms step:217/2225 train_time:13013ms step_avg:59.97ms step:218/2225 train_time:13072ms step_avg:59.96ms step:219/2225 train_time:13133ms step_avg:59.97ms step:220/2225 train_time:13192ms step_avg:59.96ms step:221/2225 train_time:13253ms step_avg:59.97ms step:222/2225 train_time:13311ms step_avg:59.96ms step:223/2225 train_time:13373ms step_avg:59.97ms step:224/2225 train_time:13432ms step_avg:59.96ms step:225/2225 train_time:13493ms step_avg:59.97ms step:226/2225 train_time:13552ms step_avg:59.96ms step:227/2225 train_time:13613ms step_avg:59.97ms step:228/2225 train_time:13672ms step_avg:59.96ms step:229/2225 train_time:13732ms step_avg:59.97ms step:230/2225 train_time:13791ms step_avg:59.96ms step:231/2225 train_time:13852ms step_avg:59.97ms step:232/2225 train_time:13911ms step_avg:59.96ms step:233/2225 train_time:13972ms step_avg:59.96ms step:234/2225 train_time:14031ms step_avg:59.96ms step:235/2225 train_time:14091ms step_avg:59.96ms step:236/2225 train_time:14150ms step_avg:59.96ms step:237/2225 train_time:14211ms step_avg:59.96ms step:238/2225 train_time:14270ms step_avg:59.96ms step:239/2225 train_time:14331ms step_avg:59.96ms step:240/2225 train_time:14390ms step_avg:59.96ms step:241/2225 train_time:14451ms step_avg:59.96ms step:242/2225 train_time:14510ms step_avg:59.96ms step:243/2225 train_time:14571ms step_avg:59.96ms step:244/2225 train_time:14630ms step_avg:59.96ms step:245/2225 train_time:14690ms step_avg:59.96ms step:246/2225 train_time:14749ms step_avg:59.96ms step:247/2225 train_time:14810ms step_avg:59.96ms step:248/2225 train_time:14868ms step_avg:59.95ms step:249/2225 train_time:14929ms step_avg:59.96ms step:250/2225 train_time:14988ms step_avg:59.95ms step:250/2225 val_loss:4.0922 train_time:15050ms step_avg:60.20ms step:251/2225 train_time:15070ms step_avg:60.04ms step:252/2225 train_time:15110ms step_avg:59.96ms step:253/2225 train_time:15172ms step_avg:59.97ms step:254/2225 train_time:15237ms step_avg:59.99ms step:255/2225 train_time:15298ms step_avg:59.99ms step:256/2225 train_time:15356ms step_avg:59.99ms step:257/2225 train_time:15417ms step_avg:59.99ms step:258/2225 train_time:15474ms step_avg:59.98ms step:259/2225 train_time:15534ms step_avg:59.98ms step:260/2225 train_time:15592ms step_avg:59.97ms step:261/2225 train_time:15651ms step_avg:59.97ms step:262/2225 train_time:15709ms step_avg:59.96ms step:263/2225 train_time:15769ms step_avg:59.96ms step:264/2225 train_time:15826ms step_avg:59.95ms step:265/2225 train_time:15886ms step_avg:59.95ms step:266/2225 train_time:15945ms step_avg:59.94ms step:267/2225 train_time:16007ms step_avg:59.95ms step:268/2225 train_time:16067ms step_avg:59.95ms step:269/2225 train_time:16130ms step_avg:59.96ms step:270/2225 train_time:16191ms step_avg:59.96ms step:271/2225 train_time:16252ms step_avg:59.97ms step:272/2225 train_time:16310ms step_avg:59.96ms step:273/2225 train_time:16371ms step_avg:59.97ms step:274/2225 train_time:16430ms step_avg:59.96ms step:275/2225 train_time:16491ms step_avg:59.97ms step:276/2225 train_time:16549ms step_avg:59.96ms step:277/2225 train_time:16609ms step_avg:59.96ms step:278/2225 train_time:16667ms step_avg:59.95ms step:279/2225 train_time:16727ms step_avg:59.95ms step:280/2225 train_time:16785ms step_avg:59.95ms step:281/2225 train_time:16845ms step_avg:59.95ms step:282/2225 train_time:16904ms step_avg:59.94ms step:283/2225 train_time:16964ms step_avg:59.94ms step:284/2225 train_time:17024ms step_avg:59.94ms step:285/2225 train_time:17086ms step_avg:59.95ms step:286/2225 train_time:17145ms step_avg:59.95ms step:287/2225 train_time:17207ms step_avg:59.96ms step:288/2225 train_time:17267ms step_avg:59.95ms step:289/2225 train_time:17328ms step_avg:59.96ms step:290/2225 train_time:17388ms step_avg:59.96ms step:291/2225 train_time:17449ms step_avg:59.96ms step:292/2225 train_time:17507ms step_avg:59.96ms step:293/2225 train_time:17568ms step_avg:59.96ms step:294/2225 train_time:17627ms step_avg:59.95ms step:295/2225 train_time:17687ms step_avg:59.96ms step:296/2225 train_time:17745ms step_avg:59.95ms step:297/2225 train_time:17805ms step_avg:59.95ms step:298/2225 train_time:17863ms step_avg:59.94ms step:299/2225 train_time:17924ms step_avg:59.95ms step:300/2225 train_time:17983ms step_avg:59.94ms step:301/2225 train_time:18044ms step_avg:59.95ms step:302/2225 train_time:18103ms step_avg:59.94ms step:303/2225 train_time:18165ms step_avg:59.95ms step:304/2225 train_time:18225ms step_avg:59.95ms step:305/2225 train_time:18287ms step_avg:59.96ms step:306/2225 train_time:18346ms step_avg:59.96ms step:307/2225 train_time:18408ms step_avg:59.96ms step:308/2225 train_time:18466ms step_avg:59.96ms step:309/2225 train_time:18527ms step_avg:59.96ms step:310/2225 train_time:18586ms step_avg:59.96ms step:311/2225 train_time:18647ms step_avg:59.96ms step:312/2225 train_time:18705ms step_avg:59.95ms step:313/2225 train_time:18765ms step_avg:59.95ms step:314/2225 train_time:18823ms step_avg:59.95ms step:315/2225 train_time:18883ms step_avg:59.95ms step:316/2225 train_time:18942ms step_avg:59.94ms step:317/2225 train_time:19003ms step_avg:59.95ms step:318/2225 train_time:19063ms step_avg:59.95ms step:319/2225 train_time:19123ms step_avg:59.95ms step:320/2225 train_time:19184ms step_avg:59.95ms step:321/2225 train_time:19244ms step_avg:59.95ms step:322/2225 train_time:19303ms step_avg:59.95ms step:323/2225 train_time:19365ms step_avg:59.95ms step:324/2225 train_time:19424ms step_avg:59.95ms step:325/2225 train_time:19485ms step_avg:59.95ms step:326/2225 train_time:19544ms step_avg:59.95ms step:327/2225 train_time:19605ms step_avg:59.95ms step:328/2225 train_time:19663ms step_avg:59.95ms step:329/2225 train_time:19725ms step_avg:59.95ms step:330/2225 train_time:19782ms step_avg:59.95ms step:331/2225 train_time:19843ms step_avg:59.95ms step:332/2225 train_time:19901ms step_avg:59.94ms step:333/2225 train_time:19962ms step_avg:59.95ms step:334/2225 train_time:20021ms step_avg:59.94ms step:335/2225 train_time:20082ms step_avg:59.95ms step:336/2225 train_time:20141ms step_avg:59.94ms step:337/2225 train_time:20202ms step_avg:59.95ms step:338/2225 train_time:20261ms step_avg:59.94ms step:339/2225 train_time:20322ms step_avg:59.95ms step:340/2225 train_time:20382ms step_avg:59.95ms step:341/2225 train_time:20442ms step_avg:59.95ms step:342/2225 train_time:20501ms step_avg:59.94ms step:343/2225 train_time:20562ms step_avg:59.95ms step:344/2225 train_time:20621ms step_avg:59.94ms step:345/2225 train_time:20681ms step_avg:59.95ms step:346/2225 train_time:20740ms step_avg:59.94ms step:347/2225 train_time:20800ms step_avg:59.94ms step:348/2225 train_time:20858ms step_avg:59.94ms step:349/2225 train_time:20919ms step_avg:59.94ms step:350/2225 train_time:20977ms step_avg:59.94ms step:351/2225 train_time:21038ms step_avg:59.94ms step:352/2225 train_time:21096ms step_avg:59.93ms step:353/2225 train_time:21157ms step_avg:59.94ms step:354/2225 train_time:21215ms step_avg:59.93ms step:355/2225 train_time:21276ms step_avg:59.93ms step:356/2225 train_time:21335ms step_avg:59.93ms step:357/2225 train_time:21395ms step_avg:59.93ms step:358/2225 train_time:21454ms step_avg:59.93ms step:359/2225 train_time:21514ms step_avg:59.93ms step:360/2225 train_time:21573ms step_avg:59.92ms step:361/2225 train_time:21633ms step_avg:59.92ms step:362/2225 train_time:21691ms step_avg:59.92ms step:363/2225 train_time:21752ms step_avg:59.92ms step:364/2225 train_time:21811ms step_avg:59.92ms step:365/2225 train_time:21871ms step_avg:59.92ms step:366/2225 train_time:21930ms step_avg:59.92ms step:367/2225 train_time:21990ms step_avg:59.92ms step:368/2225 train_time:22049ms step_avg:59.92ms step:369/2225 train_time:22110ms step_avg:59.92ms step:370/2225 train_time:22168ms step_avg:59.91ms step:371/2225 train_time:22230ms step_avg:59.92ms step:372/2225 train_time:22289ms step_avg:59.92ms step:373/2225 train_time:22350ms step_avg:59.92ms step:374/2225 train_time:22408ms step_avg:59.92ms step:375/2225 train_time:22468ms step_avg:59.92ms step:376/2225 train_time:22527ms step_avg:59.91ms step:377/2225 train_time:22588ms step_avg:59.91ms step:378/2225 train_time:22647ms step_avg:59.91ms step:379/2225 train_time:22707ms step_avg:59.91ms step:380/2225 train_time:22766ms step_avg:59.91ms step:381/2225 train_time:22826ms step_avg:59.91ms step:382/2225 train_time:22885ms step_avg:59.91ms step:383/2225 train_time:22945ms step_avg:59.91ms step:384/2225 train_time:23004ms step_avg:59.91ms step:385/2225 train_time:23065ms step_avg:59.91ms step:386/2225 train_time:23124ms step_avg:59.91ms step:387/2225 train_time:23185ms step_avg:59.91ms step:388/2225 train_time:23244ms step_avg:59.91ms step:389/2225 train_time:23305ms step_avg:59.91ms step:390/2225 train_time:23364ms step_avg:59.91ms step:391/2225 train_time:23425ms step_avg:59.91ms step:392/2225 train_time:23484ms step_avg:59.91ms step:393/2225 train_time:23545ms step_avg:59.91ms step:394/2225 train_time:23604ms step_avg:59.91ms step:395/2225 train_time:23665ms step_avg:59.91ms step:396/2225 train_time:23724ms step_avg:59.91ms step:397/2225 train_time:23785ms step_avg:59.91ms step:398/2225 train_time:23844ms step_avg:59.91ms step:399/2225 train_time:23904ms step_avg:59.91ms step:400/2225 train_time:23963ms step_avg:59.91ms step:401/2225 train_time:24024ms step_avg:59.91ms step:402/2225 train_time:24083ms step_avg:59.91ms step:403/2225 train_time:24144ms step_avg:59.91ms step:404/2225 train_time:24203ms step_avg:59.91ms step:405/2225 train_time:24264ms step_avg:59.91ms step:406/2225 train_time:24322ms step_avg:59.91ms step:407/2225 train_time:24383ms step_avg:59.91ms step:408/2225 train_time:24442ms step_avg:59.91ms step:409/2225 train_time:24503ms step_avg:59.91ms step:410/2225 train_time:24562ms step_avg:59.91ms step:411/2225 train_time:24622ms step_avg:59.91ms step:412/2225 train_time:24681ms step_avg:59.91ms step:413/2225 train_time:24742ms step_avg:59.91ms step:414/2225 train_time:24801ms step_avg:59.91ms step:415/2225 train_time:24862ms step_avg:59.91ms step:416/2225 train_time:24920ms step_avg:59.90ms step:417/2225 train_time:24981ms step_avg:59.91ms step:418/2225 train_time:25040ms step_avg:59.90ms step:419/2225 train_time:25101ms step_avg:59.91ms step:420/2225 train_time:25160ms step_avg:59.90ms step:421/2225 train_time:25221ms step_avg:59.91ms step:422/2225 train_time:25280ms step_avg:59.90ms step:423/2225 train_time:25340ms step_avg:59.91ms step:424/2225 train_time:25399ms step_avg:59.90ms step:425/2225 train_time:25460ms step_avg:59.91ms step:426/2225 train_time:25519ms step_avg:59.90ms step:427/2225 train_time:25579ms step_avg:59.91ms step:428/2225 train_time:25638ms step_avg:59.90ms step:429/2225 train_time:25699ms step_avg:59.91ms step:430/2225 train_time:25758ms step_avg:59.90ms step:431/2225 train_time:25818ms step_avg:59.90ms step:432/2225 train_time:25877ms step_avg:59.90ms step:433/2225 train_time:25937ms step_avg:59.90ms step:434/2225 train_time:25996ms step_avg:59.90ms step:435/2225 train_time:26057ms step_avg:59.90ms step:436/2225 train_time:26115ms step_avg:59.90ms step:437/2225 train_time:26176ms step_avg:59.90ms step:438/2225 train_time:26235ms step_avg:59.90ms step:439/2225 train_time:26295ms step_avg:59.90ms step:440/2225 train_time:26354ms step_avg:59.90ms step:441/2225 train_time:26414ms step_avg:59.89ms step:442/2225 train_time:26472ms step_avg:59.89ms step:443/2225 train_time:26532ms step_avg:59.89ms step:444/2225 train_time:26590ms step_avg:59.89ms step:445/2225 train_time:26651ms step_avg:59.89ms step:446/2225 train_time:26709ms step_avg:59.89ms step:447/2225 train_time:26769ms step_avg:59.89ms step:448/2225 train_time:26828ms step_avg:59.88ms step:449/2225 train_time:26888ms step_avg:59.89ms step:450/2225 train_time:26948ms step_avg:59.88ms step:451/2225 train_time:27009ms step_avg:59.89ms step:452/2225 train_time:27068ms step_avg:59.88ms step:453/2225 train_time:27129ms step_avg:59.89ms step:454/2225 train_time:27188ms step_avg:59.89ms step:455/2225 train_time:27249ms step_avg:59.89ms step:456/2225 train_time:27308ms step_avg:59.89ms step:457/2225 train_time:27368ms step_avg:59.89ms step:458/2225 train_time:27427ms step_avg:59.89ms step:459/2225 train_time:27488ms step_avg:59.89ms step:460/2225 train_time:27547ms step_avg:59.88ms step:461/2225 train_time:27607ms step_avg:59.88ms step:462/2225 train_time:27665ms step_avg:59.88ms step:463/2225 train_time:27726ms step_avg:59.88ms step:464/2225 train_time:27784ms step_avg:59.88ms step:465/2225 train_time:27845ms step_avg:59.88ms step:466/2225 train_time:27903ms step_avg:59.88ms step:467/2225 train_time:27965ms step_avg:59.88ms step:468/2225 train_time:28024ms step_avg:59.88ms step:469/2225 train_time:28085ms step_avg:59.88ms step:470/2225 train_time:28144ms step_avg:59.88ms step:471/2225 train_time:28205ms step_avg:59.88ms step:472/2225 train_time:28265ms step_avg:59.88ms step:473/2225 train_time:28325ms step_avg:59.88ms step:474/2225 train_time:28384ms step_avg:59.88ms step:475/2225 train_time:28446ms step_avg:59.89ms step:476/2225 train_time:28505ms step_avg:59.89ms step:477/2225 train_time:28566ms step_avg:59.89ms step:478/2225 train_time:28625ms step_avg:59.88ms step:479/2225 train_time:28685ms step_avg:59.89ms step:480/2225 train_time:28744ms step_avg:59.88ms step:481/2225 train_time:28805ms step_avg:59.89ms step:482/2225 train_time:28864ms step_avg:59.88ms step:483/2225 train_time:28925ms step_avg:59.89ms step:484/2225 train_time:28983ms step_avg:59.88ms step:485/2225 train_time:29044ms step_avg:59.88ms step:486/2225 train_time:29103ms step_avg:59.88ms step:487/2225 train_time:29164ms step_avg:59.89ms step:488/2225 train_time:29224ms step_avg:59.88ms step:489/2225 train_time:29284ms step_avg:59.89ms step:490/2225 train_time:29343ms step_avg:59.88ms step:491/2225 train_time:29404ms step_avg:59.89ms step:492/2225 train_time:29464ms step_avg:59.89ms step:493/2225 train_time:29524ms step_avg:59.89ms step:494/2225 train_time:29584ms step_avg:59.89ms step:495/2225 train_time:29645ms step_avg:59.89ms step:496/2225 train_time:29704ms step_avg:59.89ms step:497/2225 train_time:29764ms step_avg:59.89ms step:498/2225 train_time:29823ms step_avg:59.89ms step:499/2225 train_time:29884ms step_avg:59.89ms step:500/2225 train_time:29943ms step_avg:59.89ms step:500/2225 val_loss:3.8204 train_time:30005ms step_avg:60.01ms step:501/2225 train_time:30028ms step_avg:59.94ms step:502/2225 train_time:30064ms step_avg:59.89ms step:503/2225 train_time:30130ms step_avg:59.90ms step:504/2225 train_time:30195ms step_avg:59.91ms step:505/2225 train_time:30255ms step_avg:59.91ms step:506/2225 train_time:30313ms step_avg:59.91ms step:507/2225 train_time:30373ms step_avg:59.91ms step:508/2225 train_time:30431ms step_avg:59.90ms step:509/2225 train_time:30491ms step_avg:59.90ms step:510/2225 train_time:30549ms step_avg:59.90ms step:511/2225 train_time:30608ms step_avg:59.90ms step:512/2225 train_time:30666ms step_avg:59.89ms step:513/2225 train_time:30726ms step_avg:59.89ms step:514/2225 train_time:30784ms step_avg:59.89ms step:515/2225 train_time:30843ms step_avg:59.89ms step:516/2225 train_time:30901ms step_avg:59.89ms step:517/2225 train_time:30963ms step_avg:59.89ms step:518/2225 train_time:31023ms step_avg:59.89ms step:519/2225 train_time:31087ms step_avg:59.90ms step:520/2225 train_time:31147ms step_avg:59.90ms step:521/2225 train_time:31209ms step_avg:59.90ms step:522/2225 train_time:31269ms step_avg:59.90ms step:523/2225 train_time:31330ms step_avg:59.90ms step:524/2225 train_time:31389ms step_avg:59.90ms step:525/2225 train_time:31449ms step_avg:59.90ms step:526/2225 train_time:31508ms step_avg:59.90ms step:527/2225 train_time:31568ms step_avg:59.90ms step:528/2225 train_time:31627ms step_avg:59.90ms step:529/2225 train_time:31687ms step_avg:59.90ms step:530/2225 train_time:31745ms step_avg:59.90ms step:531/2225 train_time:31805ms step_avg:59.90ms step:532/2225 train_time:31864ms step_avg:59.89ms step:533/2225 train_time:31925ms step_avg:59.90ms step:534/2225 train_time:31984ms step_avg:59.90ms step:535/2225 train_time:32046ms step_avg:59.90ms step:536/2225 train_time:32106ms step_avg:59.90ms step:537/2225 train_time:32168ms step_avg:59.90ms step:538/2225 train_time:32228ms step_avg:59.90ms step:539/2225 train_time:32289ms step_avg:59.91ms step:540/2225 train_time:32348ms step_avg:59.90ms step:541/2225 train_time:32409ms step_avg:59.91ms step:542/2225 train_time:32468ms step_avg:59.90ms step:543/2225 train_time:32529ms step_avg:59.91ms step:544/2225 train_time:32587ms step_avg:59.90ms step:545/2225 train_time:32648ms step_avg:59.90ms step:546/2225 train_time:32706ms step_avg:59.90ms step:547/2225 train_time:32766ms step_avg:59.90ms step:548/2225 train_time:32824ms step_avg:59.90ms step:549/2225 train_time:32885ms step_avg:59.90ms step:550/2225 train_time:32944ms step_avg:59.90ms step:551/2225 train_time:33005ms step_avg:59.90ms step:552/2225 train_time:33064ms step_avg:59.90ms step:553/2225 train_time:33126ms step_avg:59.90ms step:554/2225 train_time:33185ms step_avg:59.90ms step:555/2225 train_time:33246ms step_avg:59.90ms step:556/2225 train_time:33306ms step_avg:59.90ms step:557/2225 train_time:33367ms step_avg:59.90ms step:558/2225 train_time:33426ms step_avg:59.90ms step:559/2225 train_time:33487ms step_avg:59.91ms step:560/2225 train_time:33546ms step_avg:59.90ms step:561/2225 train_time:33606ms step_avg:59.90ms step:562/2225 train_time:33665ms step_avg:59.90ms step:563/2225 train_time:33726ms step_avg:59.90ms step:564/2225 train_time:33785ms step_avg:59.90ms step:565/2225 train_time:33846ms step_avg:59.90ms step:566/2225 train_time:33904ms step_avg:59.90ms step:567/2225 train_time:33965ms step_avg:59.90ms step:568/2225 train_time:34025ms step_avg:59.90ms step:569/2225 train_time:34086ms step_avg:59.91ms step:570/2225 train_time:34146ms step_avg:59.91ms step:571/2225 train_time:34208ms step_avg:59.91ms step:572/2225 train_time:34267ms step_avg:59.91ms step:573/2225 train_time:34329ms step_avg:59.91ms step:574/2225 train_time:34388ms step_avg:59.91ms step:575/2225 train_time:34449ms step_avg:59.91ms step:576/2225 train_time:34507ms step_avg:59.91ms step:577/2225 train_time:34567ms step_avg:59.91ms step:578/2225 train_time:34627ms step_avg:59.91ms step:579/2225 train_time:34687ms step_avg:59.91ms step:580/2225 train_time:34746ms step_avg:59.91ms step:581/2225 train_time:34806ms step_avg:59.91ms step:582/2225 train_time:34865ms step_avg:59.91ms step:583/2225 train_time:34926ms step_avg:59.91ms step:584/2225 train_time:34985ms step_avg:59.91ms step:585/2225 train_time:35046ms step_avg:59.91ms step:586/2225 train_time:35105ms step_avg:59.91ms step:587/2225 train_time:35166ms step_avg:59.91ms step:588/2225 train_time:35226ms step_avg:59.91ms step:589/2225 train_time:35287ms step_avg:59.91ms step:590/2225 train_time:35346ms step_avg:59.91ms step:591/2225 train_time:35407ms step_avg:59.91ms step:592/2225 train_time:35466ms step_avg:59.91ms step:593/2225 train_time:35527ms step_avg:59.91ms step:594/2225 train_time:35586ms step_avg:59.91ms step:595/2225 train_time:35646ms step_avg:59.91ms step:596/2225 train_time:35705ms step_avg:59.91ms step:597/2225 train_time:35765ms step_avg:59.91ms step:598/2225 train_time:35824ms step_avg:59.91ms step:599/2225 train_time:35885ms step_avg:59.91ms step:600/2225 train_time:35943ms step_avg:59.91ms step:601/2225 train_time:36004ms step_avg:59.91ms step:602/2225 train_time:36063ms step_avg:59.91ms step:603/2225 train_time:36125ms step_avg:59.91ms step:604/2225 train_time:36184ms step_avg:59.91ms step:605/2225 train_time:36245ms step_avg:59.91ms step:606/2225 train_time:36304ms step_avg:59.91ms step:607/2225 train_time:36366ms step_avg:59.91ms step:608/2225 train_time:36426ms step_avg:59.91ms step:609/2225 train_time:36487ms step_avg:59.91ms step:610/2225 train_time:36545ms step_avg:59.91ms step:611/2225 train_time:36606ms step_avg:59.91ms step:612/2225 train_time:36665ms step_avg:59.91ms step:613/2225 train_time:36725ms step_avg:59.91ms step:614/2225 train_time:36783ms step_avg:59.91ms step:615/2225 train_time:36844ms step_avg:59.91ms step:616/2225 train_time:36903ms step_avg:59.91ms step:617/2225 train_time:36964ms step_avg:59.91ms step:618/2225 train_time:37023ms step_avg:59.91ms step:619/2225 train_time:37084ms step_avg:59.91ms step:620/2225 train_time:37144ms step_avg:59.91ms step:621/2225 train_time:37205ms step_avg:59.91ms step:622/2225 train_time:37265ms step_avg:59.91ms step:623/2225 train_time:37326ms step_avg:59.91ms step:624/2225 train_time:37385ms step_avg:59.91ms step:625/2225 train_time:37446ms step_avg:59.91ms step:626/2225 train_time:37505ms step_avg:59.91ms step:627/2225 train_time:37566ms step_avg:59.91ms step:628/2225 train_time:37625ms step_avg:59.91ms step:629/2225 train_time:37686ms step_avg:59.91ms step:630/2225 train_time:37745ms step_avg:59.91ms step:631/2225 train_time:37805ms step_avg:59.91ms step:632/2225 train_time:37864ms step_avg:59.91ms step:633/2225 train_time:37924ms step_avg:59.91ms step:634/2225 train_time:37983ms step_avg:59.91ms step:635/2225 train_time:38044ms step_avg:59.91ms step:636/2225 train_time:38103ms step_avg:59.91ms step:637/2225 train_time:38164ms step_avg:59.91ms step:638/2225 train_time:38223ms step_avg:59.91ms step:639/2225 train_time:38284ms step_avg:59.91ms step:640/2225 train_time:38343ms step_avg:59.91ms step:641/2225 train_time:38404ms step_avg:59.91ms step:642/2225 train_time:38463ms step_avg:59.91ms step:643/2225 train_time:38524ms step_avg:59.91ms step:644/2225 train_time:38583ms step_avg:59.91ms step:645/2225 train_time:38644ms step_avg:59.91ms step:646/2225 train_time:38703ms step_avg:59.91ms step:647/2225 train_time:38764ms step_avg:59.91ms step:648/2225 train_time:38823ms step_avg:59.91ms step:649/2225 train_time:38883ms step_avg:59.91ms step:650/2225 train_time:38943ms step_avg:59.91ms step:651/2225 train_time:39003ms step_avg:59.91ms step:652/2225 train_time:39062ms step_avg:59.91ms step:653/2225 train_time:39123ms step_avg:59.91ms step:654/2225 train_time:39182ms step_avg:59.91ms step:655/2225 train_time:39243ms step_avg:59.91ms step:656/2225 train_time:39302ms step_avg:59.91ms step:657/2225 train_time:39363ms step_avg:59.91ms step:658/2225 train_time:39422ms step_avg:59.91ms step:659/2225 train_time:39483ms step_avg:59.91ms step:660/2225 train_time:39542ms step_avg:59.91ms step:661/2225 train_time:39603ms step_avg:59.91ms step:662/2225 train_time:39662ms step_avg:59.91ms step:663/2225 train_time:39722ms step_avg:59.91ms step:664/2225 train_time:39781ms step_avg:59.91ms step:665/2225 train_time:39842ms step_avg:59.91ms step:666/2225 train_time:39901ms step_avg:59.91ms step:667/2225 train_time:39962ms step_avg:59.91ms step:668/2225 train_time:40020ms step_avg:59.91ms step:669/2225 train_time:40081ms step_avg:59.91ms step:670/2225 train_time:40140ms step_avg:59.91ms step:671/2225 train_time:40201ms step_avg:59.91ms step:672/2225 train_time:40260ms step_avg:59.91ms step:673/2225 train_time:40321ms step_avg:59.91ms step:674/2225 train_time:40379ms step_avg:59.91ms step:675/2225 train_time:40441ms step_avg:59.91ms step:676/2225 train_time:40500ms step_avg:59.91ms step:677/2225 train_time:40560ms step_avg:59.91ms step:678/2225 train_time:40619ms step_avg:59.91ms step:679/2225 train_time:40679ms step_avg:59.91ms step:680/2225 train_time:40738ms step_avg:59.91ms step:681/2225 train_time:40799ms step_avg:59.91ms step:682/2225 train_time:40858ms step_avg:59.91ms step:683/2225 train_time:40918ms step_avg:59.91ms step:684/2225 train_time:40977ms step_avg:59.91ms step:685/2225 train_time:41037ms step_avg:59.91ms step:686/2225 train_time:41096ms step_avg:59.91ms step:687/2225 train_time:41156ms step_avg:59.91ms step:688/2225 train_time:41215ms step_avg:59.91ms step:689/2225 train_time:41276ms step_avg:59.91ms step:690/2225 train_time:41334ms step_avg:59.90ms step:691/2225 train_time:41394ms step_avg:59.90ms step:692/2225 train_time:41452ms step_avg:59.90ms step:693/2225 train_time:41512ms step_avg:59.90ms step:694/2225 train_time:41571ms step_avg:59.90ms step:695/2225 train_time:41631ms step_avg:59.90ms step:696/2225 train_time:41690ms step_avg:59.90ms step:697/2225 train_time:41751ms step_avg:59.90ms step:698/2225 train_time:41809ms step_avg:59.90ms step:699/2225 train_time:41871ms step_avg:59.90ms step:700/2225 train_time:41930ms step_avg:59.90ms step:701/2225 train_time:41990ms step_avg:59.90ms step:702/2225 train_time:42049ms step_avg:59.90ms step:703/2225 train_time:42109ms step_avg:59.90ms step:704/2225 train_time:42168ms step_avg:59.90ms step:705/2225 train_time:42229ms step_avg:59.90ms step:706/2225 train_time:42288ms step_avg:59.90ms step:707/2225 train_time:42348ms step_avg:59.90ms step:708/2225 train_time:42407ms step_avg:59.90ms step:709/2225 train_time:42468ms step_avg:59.90ms step:710/2225 train_time:42527ms step_avg:59.90ms step:711/2225 train_time:42588ms step_avg:59.90ms step:712/2225 train_time:42647ms step_avg:59.90ms step:713/2225 train_time:42708ms step_avg:59.90ms step:714/2225 train_time:42767ms step_avg:59.90ms step:715/2225 train_time:42828ms step_avg:59.90ms step:716/2225 train_time:42887ms step_avg:59.90ms step:717/2225 train_time:42948ms step_avg:59.90ms step:718/2225 train_time:43006ms step_avg:59.90ms step:719/2225 train_time:43068ms step_avg:59.90ms step:720/2225 train_time:43127ms step_avg:59.90ms step:721/2225 train_time:43188ms step_avg:59.90ms step:722/2225 train_time:43246ms step_avg:59.90ms step:723/2225 train_time:43307ms step_avg:59.90ms step:724/2225 train_time:43366ms step_avg:59.90ms step:725/2225 train_time:43427ms step_avg:59.90ms step:726/2225 train_time:43486ms step_avg:59.90ms step:727/2225 train_time:43546ms step_avg:59.90ms step:728/2225 train_time:43605ms step_avg:59.90ms step:729/2225 train_time:43665ms step_avg:59.90ms step:730/2225 train_time:43725ms step_avg:59.90ms step:731/2225 train_time:43787ms step_avg:59.90ms step:732/2225 train_time:43846ms step_avg:59.90ms step:733/2225 train_time:43908ms step_avg:59.90ms step:734/2225 train_time:43967ms step_avg:59.90ms step:735/2225 train_time:44030ms step_avg:59.90ms step:736/2225 train_time:44090ms step_avg:59.90ms step:737/2225 train_time:44151ms step_avg:59.91ms step:738/2225 train_time:44210ms step_avg:59.91ms step:739/2225 train_time:44272ms step_avg:59.91ms step:740/2225 train_time:44331ms step_avg:59.91ms step:741/2225 train_time:44393ms step_avg:59.91ms step:742/2225 train_time:44452ms step_avg:59.91ms step:743/2225 train_time:44514ms step_avg:59.91ms step:744/2225 train_time:44573ms step_avg:59.91ms step:745/2225 train_time:44634ms step_avg:59.91ms step:746/2225 train_time:44694ms step_avg:59.91ms step:747/2225 train_time:44754ms step_avg:59.91ms step:748/2225 train_time:44813ms step_avg:59.91ms step:749/2225 train_time:44874ms step_avg:59.91ms step:750/2225 train_time:44934ms step_avg:59.91ms step:750/2225 val_loss:3.6685 train_time:44995ms step_avg:59.99ms step:751/2225 train_time:45020ms step_avg:59.95ms step:752/2225 train_time:45056ms step_avg:59.91ms step:753/2225 train_time:45118ms step_avg:59.92ms step:754/2225 train_time:45181ms step_avg:59.92ms step:755/2225 train_time:45242ms step_avg:59.92ms step:756/2225 train_time:45301ms step_avg:59.92ms step:757/2225 train_time:45362ms step_avg:59.92ms step:758/2225 train_time:45420ms step_avg:59.92ms step:759/2225 train_time:45480ms step_avg:59.92ms step:760/2225 train_time:45539ms step_avg:59.92ms step:761/2225 train_time:45599ms step_avg:59.92ms step:762/2225 train_time:45658ms step_avg:59.92ms step:763/2225 train_time:45718ms step_avg:59.92ms step:764/2225 train_time:45777ms step_avg:59.92ms step:765/2225 train_time:45837ms step_avg:59.92ms step:766/2225 train_time:45902ms step_avg:59.92ms step:767/2225 train_time:45971ms step_avg:59.94ms step:768/2225 train_time:46030ms step_avg:59.94ms step:769/2225 train_time:46092ms step_avg:59.94ms step:770/2225 train_time:46151ms step_avg:59.94ms step:771/2225 train_time:46212ms step_avg:59.94ms step:772/2225 train_time:46272ms step_avg:59.94ms step:773/2225 train_time:46332ms step_avg:59.94ms step:774/2225 train_time:46392ms step_avg:59.94ms step:775/2225 train_time:46452ms step_avg:59.94ms step:776/2225 train_time:46511ms step_avg:59.94ms step:777/2225 train_time:46571ms step_avg:59.94ms step:778/2225 train_time:46630ms step_avg:59.94ms step:779/2225 train_time:46691ms step_avg:59.94ms step:780/2225 train_time:46749ms step_avg:59.94ms step:781/2225 train_time:46810ms step_avg:59.94ms step:782/2225 train_time:46871ms step_avg:59.94ms step:783/2225 train_time:46933ms step_avg:59.94ms step:784/2225 train_time:46992ms step_avg:59.94ms step:785/2225 train_time:47054ms step_avg:59.94ms step:786/2225 train_time:47114ms step_avg:59.94ms step:787/2225 train_time:47175ms step_avg:59.94ms step:788/2225 train_time:47235ms step_avg:59.94ms step:789/2225 train_time:47297ms step_avg:59.95ms step:790/2225 train_time:47356ms step_avg:59.94ms step:791/2225 train_time:47417ms step_avg:59.95ms step:792/2225 train_time:47476ms step_avg:59.94ms step:793/2225 train_time:47537ms step_avg:59.95ms step:794/2225 train_time:47596ms step_avg:59.94ms step:795/2225 train_time:47657ms step_avg:59.95ms step:796/2225 train_time:47717ms step_avg:59.95ms step:797/2225 train_time:47778ms step_avg:59.95ms step:798/2225 train_time:47838ms step_avg:59.95ms step:799/2225 train_time:47900ms step_avg:59.95ms step:800/2225 train_time:47961ms step_avg:59.95ms step:801/2225 train_time:48022ms step_avg:59.95ms step:802/2225 train_time:48082ms step_avg:59.95ms step:803/2225 train_time:48143ms step_avg:59.95ms step:804/2225 train_time:48203ms step_avg:59.95ms step:805/2225 train_time:48264ms step_avg:59.96ms step:806/2225 train_time:48323ms step_avg:59.95ms step:807/2225 train_time:48384ms step_avg:59.95ms step:808/2225 train_time:48442ms step_avg:59.95ms step:809/2225 train_time:48503ms step_avg:59.95ms step:810/2225 train_time:48562ms step_avg:59.95ms step:811/2225 train_time:48624ms step_avg:59.96ms step:812/2225 train_time:48682ms step_avg:59.95ms step:813/2225 train_time:48743ms step_avg:59.95ms step:814/2225 train_time:48803ms step_avg:59.95ms step:815/2225 train_time:48865ms step_avg:59.96ms step:816/2225 train_time:48923ms step_avg:59.96ms step:817/2225 train_time:48984ms step_avg:59.96ms step:818/2225 train_time:49043ms step_avg:59.96ms step:819/2225 train_time:49105ms step_avg:59.96ms step:820/2225 train_time:49163ms step_avg:59.96ms step:821/2225 train_time:49224ms step_avg:59.96ms step:822/2225 train_time:49284ms step_avg:59.96ms step:823/2225 train_time:49344ms step_avg:59.96ms step:824/2225 train_time:49403ms step_avg:59.96ms step:825/2225 train_time:49464ms step_avg:59.96ms step:826/2225 train_time:49523ms step_avg:59.96ms step:827/2225 train_time:49584ms step_avg:59.96ms step:828/2225 train_time:49643ms step_avg:59.96ms step:829/2225 train_time:49704ms step_avg:59.96ms step:830/2225 train_time:49763ms step_avg:59.96ms step:831/2225 train_time:49823ms step_avg:59.96ms step:832/2225 train_time:49883ms step_avg:59.96ms step:833/2225 train_time:49944ms step_avg:59.96ms step:834/2225 train_time:50003ms step_avg:59.96ms step:835/2225 train_time:50064ms step_avg:59.96ms step:836/2225 train_time:50123ms step_avg:59.96ms step:837/2225 train_time:50184ms step_avg:59.96ms step:838/2225 train_time:50243ms step_avg:59.96ms step:839/2225 train_time:50304ms step_avg:59.96ms step:840/2225 train_time:50364ms step_avg:59.96ms step:841/2225 train_time:50425ms step_avg:59.96ms step:842/2225 train_time:50484ms step_avg:59.96ms step:843/2225 train_time:50544ms step_avg:59.96ms step:844/2225 train_time:50604ms step_avg:59.96ms step:845/2225 train_time:50665ms step_avg:59.96ms step:846/2225 train_time:50724ms step_avg:59.96ms step:847/2225 train_time:50785ms step_avg:59.96ms step:848/2225 train_time:50844ms step_avg:59.96ms step:849/2225 train_time:50905ms step_avg:59.96ms step:850/2225 train_time:50964ms step_avg:59.96ms step:851/2225 train_time:51026ms step_avg:59.96ms step:852/2225 train_time:51085ms step_avg:59.96ms step:853/2225 train_time:51145ms step_avg:59.96ms step:854/2225 train_time:51204ms step_avg:59.96ms step:855/2225 train_time:51265ms step_avg:59.96ms step:856/2225 train_time:51324ms step_avg:59.96ms step:857/2225 train_time:51385ms step_avg:59.96ms step:858/2225 train_time:51445ms step_avg:59.96ms step:859/2225 train_time:51505ms step_avg:59.96ms step:860/2225 train_time:51565ms step_avg:59.96ms step:861/2225 train_time:51626ms step_avg:59.96ms step:862/2225 train_time:51685ms step_avg:59.96ms step:863/2225 train_time:51746ms step_avg:59.96ms step:864/2225 train_time:51805ms step_avg:59.96ms step:865/2225 train_time:51866ms step_avg:59.96ms step:866/2225 train_time:51925ms step_avg:59.96ms step:867/2225 train_time:51986ms step_avg:59.96ms step:868/2225 train_time:52045ms step_avg:59.96ms step:869/2225 train_time:52106ms step_avg:59.96ms step:870/2225 train_time:52166ms step_avg:59.96ms step:871/2225 train_time:52226ms step_avg:59.96ms step:872/2225 train_time:52285ms step_avg:59.96ms step:873/2225 train_time:52345ms step_avg:59.96ms step:874/2225 train_time:52405ms step_avg:59.96ms step:875/2225 train_time:52466ms step_avg:59.96ms step:876/2225 train_time:52525ms step_avg:59.96ms step:877/2225 train_time:52586ms step_avg:59.96ms step:878/2225 train_time:52644ms step_avg:59.96ms step:879/2225 train_time:52705ms step_avg:59.96ms step:880/2225 train_time:52764ms step_avg:59.96ms step:881/2225 train_time:52825ms step_avg:59.96ms step:882/2225 train_time:52884ms step_avg:59.96ms step:883/2225 train_time:52944ms step_avg:59.96ms step:884/2225 train_time:53003ms step_avg:59.96ms step:885/2225 train_time:53065ms step_avg:59.96ms step:886/2225 train_time:53124ms step_avg:59.96ms step:887/2225 train_time:53185ms step_avg:59.96ms step:888/2225 train_time:53244ms step_avg:59.96ms step:889/2225 train_time:53305ms step_avg:59.96ms step:890/2225 train_time:53365ms step_avg:59.96ms step:891/2225 train_time:53425ms step_avg:59.96ms step:892/2225 train_time:53484ms step_avg:59.96ms step:893/2225 train_time:53545ms step_avg:59.96ms step:894/2225 train_time:53605ms step_avg:59.96ms step:895/2225 train_time:53666ms step_avg:59.96ms step:896/2225 train_time:53725ms step_avg:59.96ms step:897/2225 train_time:53786ms step_avg:59.96ms step:898/2225 train_time:53846ms step_avg:59.96ms step:899/2225 train_time:53906ms step_avg:59.96ms step:900/2225 train_time:53966ms step_avg:59.96ms step:901/2225 train_time:54027ms step_avg:59.96ms step:902/2225 train_time:54086ms step_avg:59.96ms step:903/2225 train_time:54146ms step_avg:59.96ms step:904/2225 train_time:54205ms step_avg:59.96ms step:905/2225 train_time:54267ms step_avg:59.96ms step:906/2225 train_time:54326ms step_avg:59.96ms step:907/2225 train_time:54386ms step_avg:59.96ms step:908/2225 train_time:54445ms step_avg:59.96ms step:909/2225 train_time:54506ms step_avg:59.96ms step:910/2225 train_time:54565ms step_avg:59.96ms step:911/2225 train_time:54626ms step_avg:59.96ms step:912/2225 train_time:54685ms step_avg:59.96ms step:913/2225 train_time:54746ms step_avg:59.96ms step:914/2225 train_time:54805ms step_avg:59.96ms step:915/2225 train_time:54866ms step_avg:59.96ms step:916/2225 train_time:54925ms step_avg:59.96ms step:917/2225 train_time:54986ms step_avg:59.96ms step:918/2225 train_time:55045ms step_avg:59.96ms step:919/2225 train_time:55106ms step_avg:59.96ms step:920/2225 train_time:55165ms step_avg:59.96ms step:921/2225 train_time:55226ms step_avg:59.96ms step:922/2225 train_time:55285ms step_avg:59.96ms step:923/2225 train_time:55346ms step_avg:59.96ms step:924/2225 train_time:55405ms step_avg:59.96ms step:925/2225 train_time:55466ms step_avg:59.96ms step:926/2225 train_time:55525ms step_avg:59.96ms step:927/2225 train_time:55585ms step_avg:59.96ms step:928/2225 train_time:55644ms step_avg:59.96ms step:929/2225 train_time:55705ms step_avg:59.96ms step:930/2225 train_time:55764ms step_avg:59.96ms step:931/2225 train_time:55825ms step_avg:59.96ms step:932/2225 train_time:55884ms step_avg:59.96ms step:933/2225 train_time:55945ms step_avg:59.96ms step:934/2225 train_time:56004ms step_avg:59.96ms step:935/2225 train_time:56065ms step_avg:59.96ms step:936/2225 train_time:56124ms step_avg:59.96ms step:937/2225 train_time:56185ms step_avg:59.96ms step:938/2225 train_time:56244ms step_avg:59.96ms step:939/2225 train_time:56305ms step_avg:59.96ms step:940/2225 train_time:56364ms step_avg:59.96ms step:941/2225 train_time:56425ms step_avg:59.96ms step:942/2225 train_time:56485ms step_avg:59.96ms step:943/2225 train_time:56545ms step_avg:59.96ms step:944/2225 train_time:56604ms step_avg:59.96ms step:945/2225 train_time:56666ms step_avg:59.96ms step:946/2225 train_time:56725ms step_avg:59.96ms step:947/2225 train_time:56785ms step_avg:59.96ms step:948/2225 train_time:56844ms step_avg:59.96ms step:949/2225 train_time:56905ms step_avg:59.96ms step:950/2225 train_time:56965ms step_avg:59.96ms step:951/2225 train_time:57025ms step_avg:59.96ms step:952/2225 train_time:57084ms step_avg:59.96ms step:953/2225 train_time:57145ms step_avg:59.96ms step:954/2225 train_time:57205ms step_avg:59.96ms step:955/2225 train_time:57265ms step_avg:59.96ms step:956/2225 train_time:57325ms step_avg:59.96ms step:957/2225 train_time:57386ms step_avg:59.96ms step:958/2225 train_time:57445ms step_avg:59.96ms step:959/2225 train_time:57506ms step_avg:59.96ms step:960/2225 train_time:57565ms step_avg:59.96ms step:961/2225 train_time:57626ms step_avg:59.96ms step:962/2225 train_time:57685ms step_avg:59.96ms step:963/2225 train_time:57746ms step_avg:59.96ms step:964/2225 train_time:57805ms step_avg:59.96ms step:965/2225 train_time:57867ms step_avg:59.97ms step:966/2225 train_time:57926ms step_avg:59.96ms step:967/2225 train_time:57986ms step_avg:59.97ms step:968/2225 train_time:58046ms step_avg:59.96ms step:969/2225 train_time:58107ms step_avg:59.97ms step:970/2225 train_time:58166ms step_avg:59.96ms step:971/2225 train_time:58227ms step_avg:59.97ms step:972/2225 train_time:58286ms step_avg:59.96ms step:973/2225 train_time:58347ms step_avg:59.97ms step:974/2225 train_time:58406ms step_avg:59.96ms step:975/2225 train_time:58467ms step_avg:59.97ms step:976/2225 train_time:58526ms step_avg:59.97ms step:977/2225 train_time:58587ms step_avg:59.97ms step:978/2225 train_time:58646ms step_avg:59.97ms step:979/2225 train_time:58707ms step_avg:59.97ms step:980/2225 train_time:58766ms step_avg:59.97ms step:981/2225 train_time:58827ms step_avg:59.97ms step:982/2225 train_time:58886ms step_avg:59.97ms step:983/2225 train_time:58947ms step_avg:59.97ms step:984/2225 train_time:59007ms step_avg:59.97ms step:985/2225 train_time:59068ms step_avg:59.97ms step:986/2225 train_time:59127ms step_avg:59.97ms step:987/2225 train_time:59188ms step_avg:59.97ms step:988/2225 train_time:59247ms step_avg:59.97ms step:989/2225 train_time:59308ms step_avg:59.97ms step:990/2225 train_time:59367ms step_avg:59.97ms step:991/2225 train_time:59428ms step_avg:59.97ms step:992/2225 train_time:59488ms step_avg:59.97ms step:993/2225 train_time:59549ms step_avg:59.97ms step:994/2225 train_time:59608ms step_avg:59.97ms step:995/2225 train_time:59669ms step_avg:59.97ms step:996/2225 train_time:59728ms step_avg:59.97ms step:997/2225 train_time:59789ms step_avg:59.97ms step:998/2225 train_time:59848ms step_avg:59.97ms step:999/2225 train_time:59909ms step_avg:59.97ms step:1000/2225 train_time:59968ms step_avg:59.97ms step:1000/2225 val_loss:3.5945 train_time:60030ms step_avg:60.03ms step:1001/2225 train_time:60050ms step_avg:59.99ms step:1002/2225 train_time:60094ms step_avg:59.97ms step:1003/2225 train_time:60157ms step_avg:59.98ms step:1004/2225 train_time:60219ms step_avg:59.98ms step:1005/2225 train_time:60280ms step_avg:59.98ms step:1006/2225 train_time:60340ms step_avg:59.98ms step:1007/2225 train_time:60400ms step_avg:59.98ms step:1008/2225 train_time:60460ms step_avg:59.98ms step:1009/2225 train_time:60520ms step_avg:59.98ms step:1010/2225 train_time:60579ms step_avg:59.98ms step:1011/2225 train_time:60640ms step_avg:59.98ms step:1012/2225 train_time:60699ms step_avg:59.98ms step:1013/2225 train_time:60760ms step_avg:59.98ms step:1014/2225 train_time:60819ms step_avg:59.98ms step:1015/2225 train_time:60879ms step_avg:59.98ms step:1016/2225 train_time:60938ms step_avg:59.98ms step:1017/2225 train_time:61000ms step_avg:59.98ms step:1018/2225 train_time:61061ms step_avg:59.98ms step:1019/2225 train_time:61124ms step_avg:59.98ms step:1020/2225 train_time:61186ms step_avg:59.99ms step:1021/2225 train_time:61248ms step_avg:59.99ms step:1022/2225 train_time:61308ms step_avg:59.99ms step:1023/2225 train_time:61369ms step_avg:59.99ms step:1024/2225 train_time:61428ms step_avg:59.99ms step:1025/2225 train_time:61491ms step_avg:59.99ms step:1026/2225 train_time:61550ms step_avg:59.99ms step:1027/2225 train_time:61611ms step_avg:59.99ms step:1028/2225 train_time:61671ms step_avg:59.99ms step:1029/2225 train_time:61732ms step_avg:59.99ms step:1030/2225 train_time:61792ms step_avg:59.99ms step:1031/2225 train_time:61854ms step_avg:59.99ms step:1032/2225 train_time:61913ms step_avg:59.99ms step:1033/2225 train_time:61974ms step_avg:59.99ms step:1034/2225 train_time:62033ms step_avg:59.99ms step:1035/2225 train_time:62095ms step_avg:60.00ms step:1036/2225 train_time:62155ms step_avg:60.00ms step:1037/2225 train_time:62216ms step_avg:60.00ms step:1038/2225 train_time:62276ms step_avg:60.00ms step:1039/2225 train_time:62338ms step_avg:60.00ms step:1040/2225 train_time:62397ms step_avg:60.00ms step:1041/2225 train_time:62458ms step_avg:60.00ms step:1042/2225 train_time:62517ms step_avg:60.00ms step:1043/2225 train_time:62577ms step_avg:60.00ms step:1044/2225 train_time:62636ms step_avg:60.00ms step:1045/2225 train_time:62697ms step_avg:60.00ms step:1046/2225 train_time:62757ms step_avg:60.00ms step:1047/2225 train_time:62819ms step_avg:60.00ms step:1048/2225 train_time:62878ms step_avg:60.00ms step:1049/2225 train_time:62940ms step_avg:60.00ms step:1050/2225 train_time:63000ms step_avg:60.00ms step:1051/2225 train_time:63062ms step_avg:60.00ms step:1052/2225 train_time:63121ms step_avg:60.00ms step:1053/2225 train_time:63184ms step_avg:60.00ms step:1054/2225 train_time:63244ms step_avg:60.00ms step:1055/2225 train_time:63305ms step_avg:60.00ms step:1056/2225 train_time:63365ms step_avg:60.00ms step:1057/2225 train_time:63426ms step_avg:60.01ms step:1058/2225 train_time:63486ms step_avg:60.01ms step:1059/2225 train_time:63547ms step_avg:60.01ms step:1060/2225 train_time:63608ms step_avg:60.01ms step:1061/2225 train_time:63670ms step_avg:60.01ms step:1062/2225 train_time:63730ms step_avg:60.01ms step:1063/2225 train_time:63792ms step_avg:60.01ms step:1064/2225 train_time:63852ms step_avg:60.01ms step:1065/2225 train_time:63914ms step_avg:60.01ms step:1066/2225 train_time:63973ms step_avg:60.01ms step:1067/2225 train_time:64034ms step_avg:60.01ms step:1068/2225 train_time:64094ms step_avg:60.01ms step:1069/2225 train_time:64155ms step_avg:60.01ms step:1070/2225 train_time:64214ms step_avg:60.01ms step:1071/2225 train_time:64276ms step_avg:60.01ms step:1072/2225 train_time:64335ms step_avg:60.01ms step:1073/2225 train_time:64397ms step_avg:60.02ms step:1074/2225 train_time:64457ms step_avg:60.02ms step:1075/2225 train_time:64517ms step_avg:60.02ms step:1076/2225 train_time:64576ms step_avg:60.02ms step:1077/2225 train_time:64638ms step_avg:60.02ms step:1078/2225 train_time:64697ms step_avg:60.02ms step:1079/2225 train_time:64758ms step_avg:60.02ms step:1080/2225 train_time:64817ms step_avg:60.02ms step:1081/2225 train_time:64878ms step_avg:60.02ms step:1082/2225 train_time:64937ms step_avg:60.02ms step:1083/2225 train_time:64999ms step_avg:60.02ms step:1084/2225 train_time:65058ms step_avg:60.02ms step:1085/2225 train_time:65119ms step_avg:60.02ms step:1086/2225 train_time:65179ms step_avg:60.02ms step:1087/2225 train_time:65240ms step_avg:60.02ms step:1088/2225 train_time:65300ms step_avg:60.02ms step:1089/2225 train_time:65361ms step_avg:60.02ms step:1090/2225 train_time:65420ms step_avg:60.02ms step:1091/2225 train_time:65481ms step_avg:60.02ms step:1092/2225 train_time:65541ms step_avg:60.02ms step:1093/2225 train_time:65603ms step_avg:60.02ms step:1094/2225 train_time:65663ms step_avg:60.02ms step:1095/2225 train_time:65724ms step_avg:60.02ms step:1096/2225 train_time:65784ms step_avg:60.02ms step:1097/2225 train_time:65846ms step_avg:60.02ms step:1098/2225 train_time:65906ms step_avg:60.02ms step:1099/2225 train_time:65968ms step_avg:60.03ms step:1100/2225 train_time:66028ms step_avg:60.03ms step:1101/2225 train_time:66089ms step_avg:60.03ms step:1102/2225 train_time:66149ms step_avg:60.03ms step:1103/2225 train_time:66210ms step_avg:60.03ms step:1104/2225 train_time:66270ms step_avg:60.03ms step:1105/2225 train_time:66331ms step_avg:60.03ms step:1106/2225 train_time:66391ms step_avg:60.03ms step:1107/2225 train_time:66452ms step_avg:60.03ms step:1108/2225 train_time:66512ms step_avg:60.03ms step:1109/2225 train_time:66573ms step_avg:60.03ms step:1110/2225 train_time:66633ms step_avg:60.03ms step:1111/2225 train_time:66694ms step_avg:60.03ms step:1112/2225 train_time:66754ms step_avg:60.03ms step:1113/2225 train_time:66815ms step_avg:60.03ms step:1114/2225 train_time:66874ms step_avg:60.03ms step:1115/2225 train_time:66936ms step_avg:60.03ms step:1116/2225 train_time:66996ms step_avg:60.03ms step:1117/2225 train_time:67057ms step_avg:60.03ms step:1118/2225 train_time:67116ms step_avg:60.03ms step:1119/2225 train_time:67177ms step_avg:60.03ms step:1120/2225 train_time:67236ms step_avg:60.03ms step:1121/2225 train_time:67297ms step_avg:60.03ms step:1122/2225 train_time:67356ms step_avg:60.03ms step:1123/2225 train_time:67416ms step_avg:60.03ms step:1124/2225 train_time:67475ms step_avg:60.03ms step:1125/2225 train_time:67537ms step_avg:60.03ms step:1126/2225 train_time:67597ms step_avg:60.03ms step:1127/2225 train_time:67658ms step_avg:60.03ms step:1128/2225 train_time:67717ms step_avg:60.03ms step:1129/2225 train_time:67778ms step_avg:60.03ms step:1130/2225 train_time:67837ms step_avg:60.03ms step:1131/2225 train_time:67899ms step_avg:60.03ms step:1132/2225 train_time:67958ms step_avg:60.03ms step:1133/2225 train_time:68019ms step_avg:60.03ms step:1134/2225 train_time:68078ms step_avg:60.03ms step:1135/2225 train_time:68139ms step_avg:60.03ms step:1136/2225 train_time:68199ms step_avg:60.03ms step:1137/2225 train_time:68261ms step_avg:60.04ms step:1138/2225 train_time:68320ms step_avg:60.04ms step:1139/2225 train_time:68381ms step_avg:60.04ms step:1140/2225 train_time:68441ms step_avg:60.04ms step:1141/2225 train_time:68502ms step_avg:60.04ms step:1142/2225 train_time:68562ms step_avg:60.04ms step:1143/2225 train_time:68624ms step_avg:60.04ms step:1144/2225 train_time:68683ms step_avg:60.04ms step:1145/2225 train_time:68745ms step_avg:60.04ms step:1146/2225 train_time:68805ms step_avg:60.04ms step:1147/2225 train_time:68867ms step_avg:60.04ms step:1148/2225 train_time:68926ms step_avg:60.04ms step:1149/2225 train_time:68987ms step_avg:60.04ms step:1150/2225 train_time:69047ms step_avg:60.04ms step:1151/2225 train_time:69109ms step_avg:60.04ms step:1152/2225 train_time:69169ms step_avg:60.04ms step:1153/2225 train_time:69232ms step_avg:60.04ms step:1154/2225 train_time:69291ms step_avg:60.04ms step:1155/2225 train_time:69353ms step_avg:60.05ms step:1156/2225 train_time:69412ms step_avg:60.05ms step:1157/2225 train_time:69474ms step_avg:60.05ms step:1158/2225 train_time:69533ms step_avg:60.05ms step:1159/2225 train_time:69595ms step_avg:60.05ms step:1160/2225 train_time:69654ms step_avg:60.05ms step:1161/2225 train_time:69715ms step_avg:60.05ms step:1162/2225 train_time:69774ms step_avg:60.05ms step:1163/2225 train_time:69836ms step_avg:60.05ms step:1164/2225 train_time:69896ms step_avg:60.05ms step:1165/2225 train_time:69957ms step_avg:60.05ms step:1166/2225 train_time:70017ms step_avg:60.05ms step:1167/2225 train_time:70077ms step_avg:60.05ms step:1168/2225 train_time:70136ms step_avg:60.05ms step:1169/2225 train_time:70198ms step_avg:60.05ms step:1170/2225 train_time:70257ms step_avg:60.05ms step:1171/2225 train_time:70318ms step_avg:60.05ms step:1172/2225 train_time:70377ms step_avg:60.05ms step:1173/2225 train_time:70438ms step_avg:60.05ms step:1174/2225 train_time:70498ms step_avg:60.05ms step:1175/2225 train_time:70559ms step_avg:60.05ms step:1176/2225 train_time:70619ms step_avg:60.05ms step:1177/2225 train_time:70680ms step_avg:60.05ms step:1178/2225 train_time:70739ms step_avg:60.05ms step:1179/2225 train_time:70800ms step_avg:60.05ms step:1180/2225 train_time:70860ms step_avg:60.05ms step:1181/2225 train_time:70921ms step_avg:60.05ms step:1182/2225 train_time:70981ms step_avg:60.05ms step:1183/2225 train_time:71043ms step_avg:60.05ms step:1184/2225 train_time:71102ms step_avg:60.05ms step:1185/2225 train_time:71164ms step_avg:60.05ms step:1186/2225 train_time:71223ms step_avg:60.05ms step:1187/2225 train_time:71285ms step_avg:60.05ms step:1188/2225 train_time:71345ms step_avg:60.05ms step:1189/2225 train_time:71406ms step_avg:60.06ms step:1190/2225 train_time:71466ms step_avg:60.06ms step:1191/2225 train_time:71528ms step_avg:60.06ms step:1192/2225 train_time:71589ms step_avg:60.06ms step:1193/2225 train_time:71651ms step_avg:60.06ms step:1194/2225 train_time:71711ms step_avg:60.06ms step:1195/2225 train_time:71772ms step_avg:60.06ms step:1196/2225 train_time:71832ms step_avg:60.06ms step:1197/2225 train_time:71893ms step_avg:60.06ms step:1198/2225 train_time:71953ms step_avg:60.06ms step:1199/2225 train_time:72014ms step_avg:60.06ms step:1200/2225 train_time:72073ms step_avg:60.06ms step:1201/2225 train_time:72134ms step_avg:60.06ms step:1202/2225 train_time:72194ms step_avg:60.06ms step:1203/2225 train_time:72255ms step_avg:60.06ms step:1204/2225 train_time:72314ms step_avg:60.06ms step:1205/2225 train_time:72375ms step_avg:60.06ms step:1206/2225 train_time:72435ms step_avg:60.06ms step:1207/2225 train_time:72496ms step_avg:60.06ms step:1208/2225 train_time:72555ms step_avg:60.06ms step:1209/2225 train_time:72616ms step_avg:60.06ms step:1210/2225 train_time:72675ms step_avg:60.06ms step:1211/2225 train_time:72736ms step_avg:60.06ms step:1212/2225 train_time:72795ms step_avg:60.06ms step:1213/2225 train_time:72856ms step_avg:60.06ms step:1214/2225 train_time:72915ms step_avg:60.06ms step:1215/2225 train_time:72976ms step_avg:60.06ms step:1216/2225 train_time:73036ms step_avg:60.06ms step:1217/2225 train_time:73097ms step_avg:60.06ms step:1218/2225 train_time:73156ms step_avg:60.06ms step:1219/2225 train_time:73217ms step_avg:60.06ms step:1220/2225 train_time:73276ms step_avg:60.06ms step:1221/2225 train_time:73337ms step_avg:60.06ms step:1222/2225 train_time:73397ms step_avg:60.06ms step:1223/2225 train_time:73458ms step_avg:60.06ms step:1224/2225 train_time:73517ms step_avg:60.06ms step:1225/2225 train_time:73578ms step_avg:60.06ms step:1226/2225 train_time:73637ms step_avg:60.06ms step:1227/2225 train_time:73698ms step_avg:60.06ms step:1228/2225 train_time:73758ms step_avg:60.06ms step:1229/2225 train_time:73819ms step_avg:60.06ms step:1230/2225 train_time:73879ms step_avg:60.06ms step:1231/2225 train_time:73940ms step_avg:60.06ms step:1232/2225 train_time:74583ms step_avg:60.54ms step:1233/2225 train_time:74643ms step_avg:60.54ms step:1234/2225 train_time:74701ms step_avg:60.54ms step:1235/2225 train_time:74761ms step_avg:60.53ms step:1236/2225 train_time:74819ms step_avg:60.53ms step:1237/2225 train_time:74879ms step_avg:60.53ms step:1238/2225 train_time:74937ms step_avg:60.53ms step:1239/2225 train_time:74998ms step_avg:60.53ms step:1240/2225 train_time:75056ms step_avg:60.53ms step:1241/2225 train_time:75116ms step_avg:60.53ms step:1242/2225 train_time:75174ms step_avg:60.53ms step:1243/2225 train_time:75234ms step_avg:60.53ms step:1244/2225 train_time:75293ms step_avg:60.52ms step:1245/2225 train_time:75353ms step_avg:60.52ms step:1246/2225 train_time:75411ms step_avg:60.52ms step:1247/2225 train_time:75474ms step_avg:60.52ms step:1248/2225 train_time:75537ms step_avg:60.53ms step:1249/2225 train_time:75602ms step_avg:60.53ms step:1250/2225 train_time:75661ms step_avg:60.53ms step:1250/2225 val_loss:3.5201 train_time:75723ms step_avg:60.58ms step:1251/2225 train_time:75742ms step_avg:60.55ms step:1252/2225 train_time:75783ms step_avg:60.53ms step:1253/2225 train_time:75847ms step_avg:60.53ms step:1254/2225 train_time:75911ms step_avg:60.54ms step:1255/2225 train_time:75974ms step_avg:60.54ms step:1256/2225 train_time:76034ms step_avg:60.54ms step:1257/2225 train_time:76096ms step_avg:60.54ms step:1258/2225 train_time:76155ms step_avg:60.54ms step:1259/2225 train_time:76216ms step_avg:60.54ms step:1260/2225 train_time:76275ms step_avg:60.54ms step:1261/2225 train_time:76336ms step_avg:60.54ms step:1262/2225 train_time:76394ms step_avg:60.53ms step:1263/2225 train_time:76455ms step_avg:60.53ms step:1264/2225 train_time:76514ms step_avg:60.53ms step:1265/2225 train_time:76575ms step_avg:60.53ms step:1266/2225 train_time:76635ms step_avg:60.53ms step:1267/2225 train_time:76698ms step_avg:60.54ms step:1268/2225 train_time:76759ms step_avg:60.54ms step:1269/2225 train_time:76824ms step_avg:60.54ms step:1270/2225 train_time:76884ms step_avg:60.54ms step:1271/2225 train_time:76947ms step_avg:60.54ms step:1272/2225 train_time:77007ms step_avg:60.54ms step:1273/2225 train_time:77068ms step_avg:60.54ms step:1274/2225 train_time:77127ms step_avg:60.54ms step:1275/2225 train_time:77188ms step_avg:60.54ms step:1276/2225 train_time:77247ms step_avg:60.54ms step:1277/2225 train_time:77307ms step_avg:60.54ms step:1278/2225 train_time:77367ms step_avg:60.54ms step:1279/2225 train_time:77427ms step_avg:60.54ms step:1280/2225 train_time:77486ms step_avg:60.54ms step:1281/2225 train_time:77547ms step_avg:60.54ms step:1282/2225 train_time:77607ms step_avg:60.54ms step:1283/2225 train_time:77668ms step_avg:60.54ms step:1284/2225 train_time:77728ms step_avg:60.54ms step:1285/2225 train_time:77790ms step_avg:60.54ms step:1286/2225 train_time:77850ms step_avg:60.54ms step:1287/2225 train_time:77912ms step_avg:60.54ms step:1288/2225 train_time:77971ms step_avg:60.54ms step:1289/2225 train_time:78033ms step_avg:60.54ms step:1290/2225 train_time:78092ms step_avg:60.54ms step:1291/2225 train_time:78154ms step_avg:60.54ms step:1292/2225 train_time:78213ms step_avg:60.54ms step:1293/2225 train_time:78274ms step_avg:60.54ms step:1294/2225 train_time:78334ms step_avg:60.54ms step:1295/2225 train_time:78396ms step_avg:60.54ms step:1296/2225 train_time:78455ms step_avg:60.54ms step:1297/2225 train_time:78517ms step_avg:60.54ms step:1298/2225 train_time:78577ms step_avg:60.54ms step:1299/2225 train_time:78638ms step_avg:60.54ms step:1300/2225 train_time:78698ms step_avg:60.54ms step:1301/2225 train_time:78760ms step_avg:60.54ms step:1302/2225 train_time:78820ms step_avg:60.54ms step:1303/2225 train_time:78882ms step_avg:60.54ms step:1304/2225 train_time:78942ms step_avg:60.54ms step:1305/2225 train_time:79003ms step_avg:60.54ms step:1306/2225 train_time:79063ms step_avg:60.54ms step:1307/2225 train_time:79124ms step_avg:60.54ms step:1308/2225 train_time:79183ms step_avg:60.54ms step:1309/2225 train_time:79244ms step_avg:60.54ms step:1310/2225 train_time:79303ms step_avg:60.54ms step:1311/2225 train_time:79364ms step_avg:60.54ms step:1312/2225 train_time:79424ms step_avg:60.54ms step:1313/2225 train_time:79484ms step_avg:60.54ms step:1314/2225 train_time:79543ms step_avg:60.54ms step:1315/2225 train_time:79605ms step_avg:60.54ms step:1316/2225 train_time:79664ms step_avg:60.53ms step:1317/2225 train_time:79725ms step_avg:60.54ms step:1318/2225 train_time:79784ms step_avg:60.53ms step:1319/2225 train_time:79845ms step_avg:60.53ms step:1320/2225 train_time:79904ms step_avg:60.53ms step:1321/2225 train_time:79966ms step_avg:60.53ms step:1322/2225 train_time:80025ms step_avg:60.53ms step:1323/2225 train_time:80086ms step_avg:60.53ms step:1324/2225 train_time:80145ms step_avg:60.53ms step:1325/2225 train_time:80206ms step_avg:60.53ms step:1326/2225 train_time:80265ms step_avg:60.53ms step:1327/2225 train_time:80327ms step_avg:60.53ms step:1328/2225 train_time:80386ms step_avg:60.53ms step:1329/2225 train_time:80447ms step_avg:60.53ms step:1330/2225 train_time:80506ms step_avg:60.53ms step:1331/2225 train_time:80567ms step_avg:60.53ms step:1332/2225 train_time:80626ms step_avg:60.53ms step:1333/2225 train_time:80687ms step_avg:60.53ms step:1334/2225 train_time:80746ms step_avg:60.53ms step:1335/2225 train_time:80807ms step_avg:60.53ms step:1336/2225 train_time:80866ms step_avg:60.53ms step:1337/2225 train_time:80927ms step_avg:60.53ms step:1338/2225 train_time:80986ms step_avg:60.53ms step:1339/2225 train_time:81047ms step_avg:60.53ms step:1340/2225 train_time:81106ms step_avg:60.53ms step:1341/2225 train_time:81167ms step_avg:60.53ms step:1342/2225 train_time:81226ms step_avg:60.53ms step:1343/2225 train_time:81286ms step_avg:60.53ms step:1344/2225 train_time:81345ms step_avg:60.52ms step:1345/2225 train_time:81406ms step_avg:60.52ms step:1346/2225 train_time:81465ms step_avg:60.52ms step:1347/2225 train_time:81526ms step_avg:60.52ms step:1348/2225 train_time:81585ms step_avg:60.52ms step:1349/2225 train_time:81646ms step_avg:60.52ms step:1350/2225 train_time:81705ms step_avg:60.52ms step:1351/2225 train_time:81766ms step_avg:60.52ms step:1352/2225 train_time:81825ms step_avg:60.52ms step:1353/2225 train_time:81886ms step_avg:60.52ms step:1354/2225 train_time:81945ms step_avg:60.52ms step:1355/2225 train_time:82006ms step_avg:60.52ms step:1356/2225 train_time:82065ms step_avg:60.52ms step:1357/2225 train_time:82127ms step_avg:60.52ms step:1358/2225 train_time:82185ms step_avg:60.52ms step:1359/2225 train_time:82246ms step_avg:60.52ms step:1360/2225 train_time:82305ms step_avg:60.52ms step:1361/2225 train_time:82366ms step_avg:60.52ms step:1362/2225 train_time:82425ms step_avg:60.52ms step:1363/2225 train_time:82486ms step_avg:60.52ms step:1364/2225 train_time:82545ms step_avg:60.52ms step:1365/2225 train_time:82606ms step_avg:60.52ms step:1366/2225 train_time:82665ms step_avg:60.52ms step:1367/2225 train_time:82726ms step_avg:60.52ms step:1368/2225 train_time:82785ms step_avg:60.52ms step:1369/2225 train_time:82846ms step_avg:60.52ms step:1370/2225 train_time:82905ms step_avg:60.51ms step:1371/2225 train_time:82966ms step_avg:60.52ms step:1372/2225 train_time:83025ms step_avg:60.51ms step:1373/2225 train_time:83086ms step_avg:60.51ms step:1374/2225 train_time:83145ms step_avg:60.51ms step:1375/2225 train_time:83207ms step_avg:60.51ms step:1376/2225 train_time:83266ms step_avg:60.51ms step:1377/2225 train_time:83327ms step_avg:60.51ms step:1378/2225 train_time:83385ms step_avg:60.51ms step:1379/2225 train_time:83446ms step_avg:60.51ms step:1380/2225 train_time:83505ms step_avg:60.51ms step:1381/2225 train_time:83566ms step_avg:60.51ms step:1382/2225 train_time:83625ms step_avg:60.51ms step:1383/2225 train_time:83686ms step_avg:60.51ms step:1384/2225 train_time:83745ms step_avg:60.51ms step:1385/2225 train_time:83807ms step_avg:60.51ms step:1386/2225 train_time:83866ms step_avg:60.51ms step:1387/2225 train_time:83927ms step_avg:60.51ms step:1388/2225 train_time:83986ms step_avg:60.51ms step:1389/2225 train_time:84047ms step_avg:60.51ms step:1390/2225 train_time:84106ms step_avg:60.51ms step:1391/2225 train_time:84167ms step_avg:60.51ms step:1392/2225 train_time:84226ms step_avg:60.51ms step:1393/2225 train_time:84287ms step_avg:60.51ms step:1394/2225 train_time:84346ms step_avg:60.51ms step:1395/2225 train_time:84407ms step_avg:60.51ms step:1396/2225 train_time:84466ms step_avg:60.51ms step:1397/2225 train_time:84527ms step_avg:60.51ms step:1398/2225 train_time:84585ms step_avg:60.50ms step:1399/2225 train_time:84646ms step_avg:60.50ms step:1400/2225 train_time:84705ms step_avg:60.50ms step:1401/2225 train_time:84766ms step_avg:60.50ms step:1402/2225 train_time:84825ms step_avg:60.50ms step:1403/2225 train_time:84886ms step_avg:60.50ms step:1404/2225 train_time:84945ms step_avg:60.50ms step:1405/2225 train_time:85006ms step_avg:60.50ms step:1406/2225 train_time:85065ms step_avg:60.50ms step:1407/2225 train_time:85126ms step_avg:60.50ms step:1408/2225 train_time:85185ms step_avg:60.50ms step:1409/2225 train_time:85246ms step_avg:60.50ms step:1410/2225 train_time:85306ms step_avg:60.50ms step:1411/2225 train_time:85367ms step_avg:60.50ms step:1412/2225 train_time:85427ms step_avg:60.50ms step:1413/2225 train_time:85488ms step_avg:60.50ms step:1414/2225 train_time:85546ms step_avg:60.50ms step:1415/2225 train_time:85608ms step_avg:60.50ms step:1416/2225 train_time:85667ms step_avg:60.50ms step:1417/2225 train_time:85728ms step_avg:60.50ms step:1418/2225 train_time:85787ms step_avg:60.50ms step:1419/2225 train_time:85848ms step_avg:60.50ms step:1420/2225 train_time:85907ms step_avg:60.50ms step:1421/2225 train_time:85968ms step_avg:60.50ms step:1422/2225 train_time:86028ms step_avg:60.50ms step:1423/2225 train_time:86089ms step_avg:60.50ms step:1424/2225 train_time:86148ms step_avg:60.50ms step:1425/2225 train_time:86209ms step_avg:60.50ms step:1426/2225 train_time:86268ms step_avg:60.50ms step:1427/2225 train_time:86330ms step_avg:60.50ms step:1428/2225 train_time:86389ms step_avg:60.50ms step:1429/2225 train_time:86450ms step_avg:60.50ms step:1430/2225 train_time:86509ms step_avg:60.50ms step:1431/2225 train_time:86570ms step_avg:60.50ms step:1432/2225 train_time:86630ms step_avg:60.50ms step:1433/2225 train_time:86691ms step_avg:60.50ms step:1434/2225 train_time:86750ms step_avg:60.49ms step:1435/2225 train_time:86811ms step_avg:60.50ms step:1436/2225 train_time:86871ms step_avg:60.49ms step:1437/2225 train_time:86932ms step_avg:60.50ms step:1438/2225 train_time:86992ms step_avg:60.49ms step:1439/2225 train_time:87053ms step_avg:60.50ms step:1440/2225 train_time:87113ms step_avg:60.49ms step:1441/2225 train_time:87174ms step_avg:60.50ms step:1442/2225 train_time:87234ms step_avg:60.50ms step:1443/2225 train_time:87295ms step_avg:60.50ms step:1444/2225 train_time:87355ms step_avg:60.50ms step:1445/2225 train_time:87418ms step_avg:60.50ms step:1446/2225 train_time:87478ms step_avg:60.50ms step:1447/2225 train_time:87539ms step_avg:60.50ms step:1448/2225 train_time:87599ms step_avg:60.50ms step:1449/2225 train_time:87661ms step_avg:60.50ms step:1450/2225 train_time:87721ms step_avg:60.50ms step:1451/2225 train_time:87783ms step_avg:60.50ms step:1452/2225 train_time:87842ms step_avg:60.50ms step:1453/2225 train_time:87904ms step_avg:60.50ms step:1454/2225 train_time:87963ms step_avg:60.50ms step:1455/2225 train_time:88025ms step_avg:60.50ms step:1456/2225 train_time:88084ms step_avg:60.50ms step:1457/2225 train_time:88145ms step_avg:60.50ms step:1458/2225 train_time:88205ms step_avg:60.50ms step:1459/2225 train_time:88266ms step_avg:60.50ms step:1460/2225 train_time:88327ms step_avg:60.50ms step:1461/2225 train_time:88388ms step_avg:60.50ms step:1462/2225 train_time:88448ms step_avg:60.50ms step:1463/2225 train_time:88510ms step_avg:60.50ms step:1464/2225 train_time:88570ms step_avg:60.50ms step:1465/2225 train_time:88631ms step_avg:60.50ms step:1466/2225 train_time:88691ms step_avg:60.50ms step:1467/2225 train_time:88754ms step_avg:60.50ms step:1468/2225 train_time:88814ms step_avg:60.50ms step:1469/2225 train_time:88876ms step_avg:60.50ms step:1470/2225 train_time:88936ms step_avg:60.50ms step:1471/2225 train_time:88998ms step_avg:60.50ms step:1472/2225 train_time:89058ms step_avg:60.50ms step:1473/2225 train_time:89121ms step_avg:60.50ms step:1474/2225 train_time:89181ms step_avg:60.50ms step:1475/2225 train_time:89244ms step_avg:60.50ms step:1476/2225 train_time:89303ms step_avg:60.50ms step:1477/2225 train_time:89365ms step_avg:60.50ms step:1478/2225 train_time:89425ms step_avg:60.50ms step:1479/2225 train_time:89486ms step_avg:60.50ms step:1480/2225 train_time:89546ms step_avg:60.50ms step:1481/2225 train_time:89607ms step_avg:60.50ms step:1482/2225 train_time:89667ms step_avg:60.50ms step:1483/2225 train_time:89728ms step_avg:60.50ms step:1484/2225 train_time:89788ms step_avg:60.50ms step:1485/2225 train_time:89849ms step_avg:60.50ms step:1486/2225 train_time:89909ms step_avg:60.50ms step:1487/2225 train_time:89972ms step_avg:60.51ms step:1488/2225 train_time:90032ms step_avg:60.51ms step:1489/2225 train_time:90094ms step_avg:60.51ms step:1490/2225 train_time:90154ms step_avg:60.51ms step:1491/2225 train_time:90216ms step_avg:60.51ms step:1492/2225 train_time:90276ms step_avg:60.51ms step:1493/2225 train_time:90339ms step_avg:60.51ms step:1494/2225 train_time:90399ms step_avg:60.51ms step:1495/2225 train_time:90462ms step_avg:60.51ms step:1496/2225 train_time:90522ms step_avg:60.51ms step:1497/2225 train_time:90584ms step_avg:60.51ms step:1498/2225 train_time:90644ms step_avg:60.51ms step:1499/2225 train_time:90706ms step_avg:60.51ms step:1500/2225 train_time:90766ms step_avg:60.51ms step:1500/2225 val_loss:3.4392 train_time:90828ms step_avg:60.55ms step:1501/2225 train_time:90851ms step_avg:60.53ms step:1502/2225 train_time:90889ms step_avg:60.51ms step:1503/2225 train_time:90958ms step_avg:60.52ms step:1504/2225 train_time:91020ms step_avg:60.52ms step:1505/2225 train_time:91081ms step_avg:60.52ms step:1506/2225 train_time:91141ms step_avg:60.52ms step:1507/2225 train_time:91202ms step_avg:60.52ms step:1508/2225 train_time:91261ms step_avg:60.52ms step:1509/2225 train_time:91321ms step_avg:60.52ms step:1510/2225 train_time:91380ms step_avg:60.52ms step:1511/2225 train_time:91441ms step_avg:60.52ms step:1512/2225 train_time:91500ms step_avg:60.52ms step:1513/2225 train_time:91560ms step_avg:60.52ms step:1514/2225 train_time:91619ms step_avg:60.51ms step:1515/2225 train_time:91680ms step_avg:60.52ms step:1516/2225 train_time:91741ms step_avg:60.52ms step:1517/2225 train_time:91805ms step_avg:60.52ms step:1518/2225 train_time:91865ms step_avg:60.52ms step:1519/2225 train_time:91928ms step_avg:60.52ms step:1520/2225 train_time:91989ms step_avg:60.52ms step:1521/2225 train_time:92051ms step_avg:60.52ms step:1522/2225 train_time:92111ms step_avg:60.52ms step:1523/2225 train_time:92173ms step_avg:60.52ms step:1524/2225 train_time:92233ms step_avg:60.52ms step:1525/2225 train_time:92295ms step_avg:60.52ms step:1526/2225 train_time:92355ms step_avg:60.52ms step:1527/2225 train_time:92416ms step_avg:60.52ms step:1528/2225 train_time:92475ms step_avg:60.52ms step:1529/2225 train_time:92537ms step_avg:60.52ms step:1530/2225 train_time:92597ms step_avg:60.52ms step:1531/2225 train_time:92658ms step_avg:60.52ms step:1532/2225 train_time:92718ms step_avg:60.52ms step:1533/2225 train_time:92781ms step_avg:60.52ms step:1534/2225 train_time:92841ms step_avg:60.52ms step:1535/2225 train_time:92904ms step_avg:60.52ms step:1536/2225 train_time:92963ms step_avg:60.52ms step:1537/2225 train_time:93025ms step_avg:60.52ms step:1538/2225 train_time:93085ms step_avg:60.52ms step:1539/2225 train_time:93147ms step_avg:60.52ms step:1540/2225 train_time:93207ms step_avg:60.52ms step:1541/2225 train_time:93268ms step_avg:60.52ms step:1542/2225 train_time:93328ms step_avg:60.52ms step:1543/2225 train_time:93390ms step_avg:60.53ms step:1544/2225 train_time:93450ms step_avg:60.52ms step:1545/2225 train_time:93512ms step_avg:60.53ms step:1546/2225 train_time:93572ms step_avg:60.53ms step:1547/2225 train_time:93635ms step_avg:60.53ms step:1548/2225 train_time:93695ms step_avg:60.53ms step:1549/2225 train_time:93758ms step_avg:60.53ms step:1550/2225 train_time:93818ms step_avg:60.53ms step:1551/2225 train_time:93880ms step_avg:60.53ms step:1552/2225 train_time:93940ms step_avg:60.53ms step:1553/2225 train_time:94002ms step_avg:60.53ms step:1554/2225 train_time:94061ms step_avg:60.53ms step:1555/2225 train_time:94122ms step_avg:60.53ms step:1556/2225 train_time:94182ms step_avg:60.53ms step:1557/2225 train_time:94243ms step_avg:60.53ms step:1558/2225 train_time:94302ms step_avg:60.53ms step:1559/2225 train_time:94363ms step_avg:60.53ms step:1560/2225 train_time:94423ms step_avg:60.53ms step:1561/2225 train_time:94485ms step_avg:60.53ms step:1562/2225 train_time:94545ms step_avg:60.53ms step:1563/2225 train_time:94607ms step_avg:60.53ms step:1564/2225 train_time:94668ms step_avg:60.53ms step:1565/2225 train_time:94730ms step_avg:60.53ms step:1566/2225 train_time:94790ms step_avg:60.53ms step:1567/2225 train_time:94852ms step_avg:60.53ms step:1568/2225 train_time:94912ms step_avg:60.53ms step:1569/2225 train_time:94975ms step_avg:60.53ms step:1570/2225 train_time:95035ms step_avg:60.53ms step:1571/2225 train_time:95097ms step_avg:60.53ms step:1572/2225 train_time:95157ms step_avg:60.53ms step:1573/2225 train_time:95219ms step_avg:60.53ms step:1574/2225 train_time:95278ms step_avg:60.53ms step:1575/2225 train_time:95340ms step_avg:60.53ms step:1576/2225 train_time:95400ms step_avg:60.53ms step:1577/2225 train_time:95461ms step_avg:60.53ms step:1578/2225 train_time:95521ms step_avg:60.53ms step:1579/2225 train_time:95582ms step_avg:60.53ms step:1580/2225 train_time:95642ms step_avg:60.53ms step:1581/2225 train_time:95703ms step_avg:60.53ms step:1582/2225 train_time:95763ms step_avg:60.53ms step:1583/2225 train_time:95825ms step_avg:60.53ms step:1584/2225 train_time:95886ms step_avg:60.53ms step:1585/2225 train_time:95948ms step_avg:60.54ms step:1586/2225 train_time:96009ms step_avg:60.54ms step:1587/2225 train_time:96071ms step_avg:60.54ms step:1588/2225 train_time:96131ms step_avg:60.54ms step:1589/2225 train_time:96194ms step_avg:60.54ms step:1590/2225 train_time:96254ms step_avg:60.54ms step:1591/2225 train_time:96316ms step_avg:60.54ms step:1592/2225 train_time:96377ms step_avg:60.54ms step:1593/2225 train_time:96439ms step_avg:60.54ms step:1594/2225 train_time:96499ms step_avg:60.54ms step:1595/2225 train_time:96560ms step_avg:60.54ms step:1596/2225 train_time:96619ms step_avg:60.54ms step:1597/2225 train_time:96681ms step_avg:60.54ms step:1598/2225 train_time:96741ms step_avg:60.54ms step:1599/2225 train_time:96803ms step_avg:60.54ms step:1600/2225 train_time:96863ms step_avg:60.54ms step:1601/2225 train_time:96924ms step_avg:60.54ms step:1602/2225 train_time:96984ms step_avg:60.54ms step:1603/2225 train_time:97046ms step_avg:60.54ms step:1604/2225 train_time:97106ms step_avg:60.54ms step:1605/2225 train_time:97168ms step_avg:60.54ms step:1606/2225 train_time:97229ms step_avg:60.54ms step:1607/2225 train_time:97291ms step_avg:60.54ms step:1608/2225 train_time:97351ms step_avg:60.54ms step:1609/2225 train_time:97413ms step_avg:60.54ms step:1610/2225 train_time:97473ms step_avg:60.54ms step:1611/2225 train_time:97536ms step_avg:60.54ms step:1612/2225 train_time:97596ms step_avg:60.54ms step:1613/2225 train_time:97658ms step_avg:60.54ms step:1614/2225 train_time:97717ms step_avg:60.54ms step:1615/2225 train_time:97779ms step_avg:60.54ms step:1616/2225 train_time:97838ms step_avg:60.54ms step:1617/2225 train_time:97900ms step_avg:60.54ms step:1618/2225 train_time:97960ms step_avg:60.54ms step:1619/2225 train_time:98021ms step_avg:60.54ms step:1620/2225 train_time:98081ms step_avg:60.54ms step:1621/2225 train_time:98142ms step_avg:60.54ms step:1622/2225 train_time:98202ms step_avg:60.54ms step:1623/2225 train_time:98263ms step_avg:60.54ms step:1624/2225 train_time:98323ms step_avg:60.54ms step:1625/2225 train_time:98385ms step_avg:60.54ms step:1626/2225 train_time:98446ms step_avg:60.54ms step:1627/2225 train_time:98509ms step_avg:60.55ms step:1628/2225 train_time:98570ms step_avg:60.55ms step:1629/2225 train_time:98632ms step_avg:60.55ms step:1630/2225 train_time:98692ms step_avg:60.55ms step:1631/2225 train_time:98754ms step_avg:60.55ms step:1632/2225 train_time:98815ms step_avg:60.55ms step:1633/2225 train_time:98877ms step_avg:60.55ms step:1634/2225 train_time:98937ms step_avg:60.55ms step:1635/2225 train_time:98998ms step_avg:60.55ms step:1636/2225 train_time:99058ms step_avg:60.55ms step:1637/2225 train_time:99119ms step_avg:60.55ms step:1638/2225 train_time:99179ms step_avg:60.55ms step:1639/2225 train_time:99240ms step_avg:60.55ms step:1640/2225 train_time:99300ms step_avg:60.55ms step:1641/2225 train_time:99362ms step_avg:60.55ms step:1642/2225 train_time:99422ms step_avg:60.55ms step:1643/2225 train_time:99484ms step_avg:60.55ms step:1644/2225 train_time:99544ms step_avg:60.55ms step:1645/2225 train_time:99606ms step_avg:60.55ms step:1646/2225 train_time:99665ms step_avg:60.55ms step:1647/2225 train_time:99728ms step_avg:60.55ms step:1648/2225 train_time:99788ms step_avg:60.55ms step:1649/2225 train_time:99850ms step_avg:60.55ms step:1650/2225 train_time:99910ms step_avg:60.55ms step:1651/2225 train_time:99973ms step_avg:60.55ms step:1652/2225 train_time:100033ms step_avg:60.55ms step:1653/2225 train_time:100095ms step_avg:60.55ms step:1654/2225 train_time:100156ms step_avg:60.55ms step:1655/2225 train_time:100218ms step_avg:60.55ms step:1656/2225 train_time:100277ms step_avg:60.55ms step:1657/2225 train_time:100338ms step_avg:60.55ms step:1658/2225 train_time:100398ms step_avg:60.55ms step:1659/2225 train_time:100460ms step_avg:60.55ms step:1660/2225 train_time:100520ms step_avg:60.55ms step:1661/2225 train_time:100582ms step_avg:60.56ms step:1662/2225 train_time:100642ms step_avg:60.55ms step:1663/2225 train_time:100703ms step_avg:60.56ms step:1664/2225 train_time:100762ms step_avg:60.55ms step:1665/2225 train_time:100824ms step_avg:60.55ms step:1666/2225 train_time:100883ms step_avg:60.55ms step:1667/2225 train_time:100945ms step_avg:60.55ms step:1668/2225 train_time:101005ms step_avg:60.55ms step:1669/2225 train_time:101068ms step_avg:60.56ms step:1670/2225 train_time:101128ms step_avg:60.56ms step:1671/2225 train_time:101190ms step_avg:60.56ms step:1672/2225 train_time:101250ms step_avg:60.56ms step:1673/2225 train_time:101312ms step_avg:60.56ms step:1674/2225 train_time:101373ms step_avg:60.56ms step:1675/2225 train_time:101435ms step_avg:60.56ms step:1676/2225 train_time:101495ms step_avg:60.56ms step:1677/2225 train_time:101557ms step_avg:60.56ms step:1678/2225 train_time:101617ms step_avg:60.56ms step:1679/2225 train_time:101679ms step_avg:60.56ms step:1680/2225 train_time:101739ms step_avg:60.56ms step:1681/2225 train_time:101801ms step_avg:60.56ms step:1682/2225 train_time:101860ms step_avg:60.56ms step:1683/2225 train_time:101922ms step_avg:60.56ms step:1684/2225 train_time:101982ms step_avg:60.56ms step:1685/2225 train_time:102043ms step_avg:60.56ms step:1686/2225 train_time:102103ms step_avg:60.56ms step:1687/2225 train_time:102165ms step_avg:60.56ms step:1688/2225 train_time:102225ms step_avg:60.56ms step:1689/2225 train_time:102287ms step_avg:60.56ms step:1690/2225 train_time:102347ms step_avg:60.56ms step:1691/2225 train_time:102409ms step_avg:60.56ms step:1692/2225 train_time:102469ms step_avg:60.56ms step:1693/2225 train_time:102531ms step_avg:60.56ms step:1694/2225 train_time:102593ms step_avg:60.56ms step:1695/2225 train_time:102656ms step_avg:60.56ms step:1696/2225 train_time:102716ms step_avg:60.56ms step:1697/2225 train_time:102778ms step_avg:60.56ms step:1698/2225 train_time:102837ms step_avg:60.56ms step:1699/2225 train_time:102900ms step_avg:60.56ms step:1700/2225 train_time:102959ms step_avg:60.56ms step:1701/2225 train_time:103021ms step_avg:60.56ms step:1702/2225 train_time:103080ms step_avg:60.56ms step:1703/2225 train_time:103142ms step_avg:60.56ms step:1704/2225 train_time:103201ms step_avg:60.56ms step:1705/2225 train_time:103263ms step_avg:60.56ms step:1706/2225 train_time:103322ms step_avg:60.56ms step:1707/2225 train_time:103384ms step_avg:60.56ms step:1708/2225 train_time:103444ms step_avg:60.56ms step:1709/2225 train_time:103505ms step_avg:60.56ms step:1710/2225 train_time:103565ms step_avg:60.56ms step:1711/2225 train_time:103628ms step_avg:60.57ms step:1712/2225 train_time:103689ms step_avg:60.57ms step:1713/2225 train_time:103751ms step_avg:60.57ms step:1714/2225 train_time:103812ms step_avg:60.57ms step:1715/2225 train_time:103874ms step_avg:60.57ms step:1716/2225 train_time:103934ms step_avg:60.57ms step:1717/2225 train_time:103996ms step_avg:60.57ms step:1718/2225 train_time:104056ms step_avg:60.57ms step:1719/2225 train_time:104118ms step_avg:60.57ms step:1720/2225 train_time:104177ms step_avg:60.57ms step:1721/2225 train_time:104238ms step_avg:60.57ms step:1722/2225 train_time:104299ms step_avg:60.57ms step:1723/2225 train_time:104360ms step_avg:60.57ms step:1724/2225 train_time:104420ms step_avg:60.57ms step:1725/2225 train_time:104482ms step_avg:60.57ms step:1726/2225 train_time:104542ms step_avg:60.57ms step:1727/2225 train_time:104604ms step_avg:60.57ms step:1728/2225 train_time:104664ms step_avg:60.57ms step:1729/2225 train_time:104726ms step_avg:60.57ms step:1730/2225 train_time:104786ms step_avg:60.57ms step:1731/2225 train_time:104848ms step_avg:60.57ms step:1732/2225 train_time:104909ms step_avg:60.57ms step:1733/2225 train_time:104971ms step_avg:60.57ms step:1734/2225 train_time:105031ms step_avg:60.57ms step:1735/2225 train_time:105094ms step_avg:60.57ms step:1736/2225 train_time:105154ms step_avg:60.57ms step:1737/2225 train_time:105216ms step_avg:60.57ms step:1738/2225 train_time:105277ms step_avg:60.57ms step:1739/2225 train_time:105338ms step_avg:60.57ms step:1740/2225 train_time:105398ms step_avg:60.57ms step:1741/2225 train_time:105460ms step_avg:60.57ms step:1742/2225 train_time:105519ms step_avg:60.57ms step:1743/2225 train_time:105581ms step_avg:60.57ms step:1744/2225 train_time:105641ms step_avg:60.57ms step:1745/2225 train_time:105702ms step_avg:60.57ms step:1746/2225 train_time:105762ms step_avg:60.57ms step:1747/2225 train_time:105824ms step_avg:60.57ms step:1748/2225 train_time:105883ms step_avg:60.57ms step:1749/2225 train_time:105945ms step_avg:60.57ms step:1750/2225 train_time:106005ms step_avg:60.57ms step:1750/2225 val_loss:3.3744 train_time:106068ms step_avg:60.61ms step:1751/2225 train_time:106088ms step_avg:60.59ms step:1752/2225 train_time:106129ms step_avg:60.58ms step:1753/2225 train_time:106195ms step_avg:60.58ms step:1754/2225 train_time:106259ms step_avg:60.58ms step:1755/2225 train_time:106322ms step_avg:60.58ms step:1756/2225 train_time:106383ms step_avg:60.58ms step:1757/2225 train_time:106445ms step_avg:60.58ms step:1758/2225 train_time:106504ms step_avg:60.58ms step:1759/2225 train_time:106566ms step_avg:60.58ms step:1760/2225 train_time:106626ms step_avg:60.58ms step:1761/2225 train_time:106687ms step_avg:60.58ms step:1762/2225 train_time:106746ms step_avg:60.58ms step:1763/2225 train_time:106807ms step_avg:60.58ms step:1764/2225 train_time:106867ms step_avg:60.58ms step:1765/2225 train_time:106928ms step_avg:60.58ms step:1766/2225 train_time:106988ms step_avg:60.58ms step:1767/2225 train_time:107050ms step_avg:60.58ms step:1768/2225 train_time:107111ms step_avg:60.58ms step:1769/2225 train_time:107174ms step_avg:60.58ms step:1770/2225 train_time:107236ms step_avg:60.59ms step:1771/2225 train_time:107299ms step_avg:60.59ms step:1772/2225 train_time:107360ms step_avg:60.59ms step:1773/2225 train_time:107422ms step_avg:60.59ms step:1774/2225 train_time:107481ms step_avg:60.59ms step:1775/2225 train_time:107543ms step_avg:60.59ms step:1776/2225 train_time:107603ms step_avg:60.59ms step:1777/2225 train_time:107665ms step_avg:60.59ms step:1778/2225 train_time:107725ms step_avg:60.59ms step:1779/2225 train_time:107787ms step_avg:60.59ms step:1780/2225 train_time:107845ms step_avg:60.59ms step:1781/2225 train_time:107906ms step_avg:60.59ms step:1782/2225 train_time:107965ms step_avg:60.59ms step:1783/2225 train_time:108027ms step_avg:60.59ms step:1784/2225 train_time:108088ms step_avg:60.59ms step:1785/2225 train_time:108151ms step_avg:60.59ms step:1786/2225 train_time:108212ms step_avg:60.59ms step:1787/2225 train_time:108274ms step_avg:60.59ms step:1788/2225 train_time:108334ms step_avg:60.59ms step:1789/2225 train_time:108396ms step_avg:60.59ms step:1790/2225 train_time:108456ms step_avg:60.59ms step:1791/2225 train_time:108518ms step_avg:60.59ms step:1792/2225 train_time:108578ms step_avg:60.59ms step:1793/2225 train_time:108640ms step_avg:60.59ms step:1794/2225 train_time:108701ms step_avg:60.59ms step:1795/2225 train_time:108763ms step_avg:60.59ms step:1796/2225 train_time:108823ms step_avg:60.59ms step:1797/2225 train_time:108884ms step_avg:60.59ms step:1798/2225 train_time:108945ms step_avg:60.59ms step:1799/2225 train_time:109007ms step_avg:60.59ms step:1800/2225 train_time:109067ms step_avg:60.59ms step:1801/2225 train_time:109129ms step_avg:60.59ms step:1802/2225 train_time:109189ms step_avg:60.59ms step:1803/2225 train_time:109251ms step_avg:60.59ms step:1804/2225 train_time:109312ms step_avg:60.59ms step:1805/2225 train_time:109373ms step_avg:60.59ms step:1806/2225 train_time:109434ms step_avg:60.59ms step:1807/2225 train_time:109496ms step_avg:60.60ms step:1808/2225 train_time:109555ms step_avg:60.59ms step:1809/2225 train_time:109617ms step_avg:60.60ms step:1810/2225 train_time:109677ms step_avg:60.60ms step:1811/2225 train_time:109739ms step_avg:60.60ms step:1812/2225 train_time:109799ms step_avg:60.60ms step:1813/2225 train_time:109861ms step_avg:60.60ms step:1814/2225 train_time:109921ms step_avg:60.60ms step:1815/2225 train_time:109984ms step_avg:60.60ms step:1816/2225 train_time:110044ms step_avg:60.60ms step:1817/2225 train_time:110106ms step_avg:60.60ms step:1818/2225 train_time:110166ms step_avg:60.60ms step:1819/2225 train_time:110228ms step_avg:60.60ms step:1820/2225 train_time:110288ms step_avg:60.60ms step:1821/2225 train_time:110350ms step_avg:60.60ms step:1822/2225 train_time:110409ms step_avg:60.60ms step:1823/2225 train_time:110471ms step_avg:60.60ms step:1824/2225 train_time:110531ms step_avg:60.60ms step:1825/2225 train_time:110592ms step_avg:60.60ms step:1826/2225 train_time:110652ms step_avg:60.60ms step:1827/2225 train_time:110714ms step_avg:60.60ms step:1828/2225 train_time:110773ms step_avg:60.60ms step:1829/2225 train_time:110836ms step_avg:60.60ms step:1830/2225 train_time:110896ms step_avg:60.60ms step:1831/2225 train_time:110959ms step_avg:60.60ms step:1832/2225 train_time:111020ms step_avg:60.60ms step:1833/2225 train_time:111082ms step_avg:60.60ms step:1834/2225 train_time:111141ms step_avg:60.60ms step:1835/2225 train_time:111204ms step_avg:60.60ms step:1836/2225 train_time:111264ms step_avg:60.60ms step:1837/2225 train_time:111327ms step_avg:60.60ms step:1838/2225 train_time:111386ms step_avg:60.60ms step:1839/2225 train_time:111448ms step_avg:60.60ms step:1840/2225 train_time:111508ms step_avg:60.60ms step:1841/2225 train_time:111569ms step_avg:60.60ms step:1842/2225 train_time:111629ms step_avg:60.60ms step:1843/2225 train_time:111690ms step_avg:60.60ms step:1844/2225 train_time:111751ms step_avg:60.60ms step:1845/2225 train_time:111812ms step_avg:60.60ms step:1846/2225 train_time:111872ms step_avg:60.60ms step:1847/2225 train_time:111934ms step_avg:60.60ms step:1848/2225 train_time:111993ms step_avg:60.60ms step:1849/2225 train_time:112055ms step_avg:60.60ms step:1850/2225 train_time:112116ms step_avg:60.60ms step:1851/2225 train_time:112178ms step_avg:60.60ms step:1852/2225 train_time:112238ms step_avg:60.60ms step:1853/2225 train_time:112301ms step_avg:60.60ms step:1854/2225 train_time:112361ms step_avg:60.60ms step:1855/2225 train_time:112423ms step_avg:60.61ms step:1856/2225 train_time:112484ms step_avg:60.61ms step:1857/2225 train_time:112546ms step_avg:60.61ms step:1858/2225 train_time:112606ms step_avg:60.61ms step:1859/2225 train_time:112667ms step_avg:60.61ms step:1860/2225 train_time:112727ms step_avg:60.61ms step:1861/2225 train_time:112789ms step_avg:60.61ms step:1862/2225 train_time:112848ms step_avg:60.61ms step:1863/2225 train_time:112910ms step_avg:60.61ms step:1864/2225 train_time:112969ms step_avg:60.61ms step:1865/2225 train_time:113031ms step_avg:60.61ms step:1866/2225 train_time:113091ms step_avg:60.61ms step:1867/2225 train_time:113153ms step_avg:60.61ms step:1868/2225 train_time:113213ms step_avg:60.61ms step:1869/2225 train_time:113274ms step_avg:60.61ms step:1870/2225 train_time:113334ms step_avg:60.61ms step:1871/2225 train_time:113397ms step_avg:60.61ms step:1872/2225 train_time:113457ms step_avg:60.61ms step:1873/2225 train_time:113520ms step_avg:60.61ms step:1874/2225 train_time:113580ms step_avg:60.61ms step:1875/2225 train_time:113643ms step_avg:60.61ms step:1876/2225 train_time:113703ms step_avg:60.61ms step:1877/2225 train_time:113765ms step_avg:60.61ms step:1878/2225 train_time:113826ms step_avg:60.61ms step:1879/2225 train_time:113888ms step_avg:60.61ms step:1880/2225 train_time:113947ms step_avg:60.61ms step:1881/2225 train_time:114008ms step_avg:60.61ms step:1882/2225 train_time:114068ms step_avg:60.61ms step:1883/2225 train_time:114130ms step_avg:60.61ms step:1884/2225 train_time:114190ms step_avg:60.61ms step:1885/2225 train_time:114251ms step_avg:60.61ms step:1886/2225 train_time:114311ms step_avg:60.61ms step:1887/2225 train_time:114372ms step_avg:60.61ms step:1888/2225 train_time:114432ms step_avg:60.61ms step:1889/2225 train_time:114495ms step_avg:60.61ms step:1890/2225 train_time:114556ms step_avg:60.61ms step:1891/2225 train_time:114619ms step_avg:60.61ms step:1892/2225 train_time:114679ms step_avg:60.61ms step:1893/2225 train_time:114741ms step_avg:60.61ms step:1894/2225 train_time:114802ms step_avg:60.61ms step:1895/2225 train_time:114865ms step_avg:60.61ms step:1896/2225 train_time:114925ms step_avg:60.61ms step:1897/2225 train_time:114987ms step_avg:60.62ms step:1898/2225 train_time:115047ms step_avg:60.61ms step:1899/2225 train_time:115109ms step_avg:60.62ms step:1900/2225 train_time:115168ms step_avg:60.61ms step:1901/2225 train_time:115230ms step_avg:60.62ms step:1902/2225 train_time:115291ms step_avg:60.62ms step:1903/2225 train_time:115351ms step_avg:60.62ms step:1904/2225 train_time:115411ms step_avg:60.61ms step:1905/2225 train_time:115472ms step_avg:60.62ms step:1906/2225 train_time:115532ms step_avg:60.61ms step:1907/2225 train_time:115593ms step_avg:60.62ms step:1908/2225 train_time:115653ms step_avg:60.61ms step:1909/2225 train_time:115716ms step_avg:60.62ms step:1910/2225 train_time:115776ms step_avg:60.62ms step:1911/2225 train_time:115839ms step_avg:60.62ms step:1912/2225 train_time:115900ms step_avg:60.62ms step:1913/2225 train_time:115962ms step_avg:60.62ms step:1914/2225 train_time:116022ms step_avg:60.62ms step:1915/2225 train_time:116085ms step_avg:60.62ms step:1916/2225 train_time:116145ms step_avg:60.62ms step:1917/2225 train_time:116207ms step_avg:60.62ms step:1918/2225 train_time:116266ms step_avg:60.62ms step:1919/2225 train_time:116328ms step_avg:60.62ms step:1920/2225 train_time:116387ms step_avg:60.62ms step:1921/2225 train_time:116449ms step_avg:60.62ms step:1922/2225 train_time:116510ms step_avg:60.62ms step:1923/2225 train_time:116572ms step_avg:60.62ms step:1924/2225 train_time:116632ms step_avg:60.62ms step:1925/2225 train_time:116694ms step_avg:60.62ms step:1926/2225 train_time:116753ms step_avg:60.62ms step:1927/2225 train_time:116815ms step_avg:60.62ms step:1928/2225 train_time:116875ms step_avg:60.62ms step:1929/2225 train_time:116938ms step_avg:60.62ms step:1930/2225 train_time:116998ms step_avg:60.62ms step:1931/2225 train_time:117061ms step_avg:60.62ms step:1932/2225 train_time:117121ms step_avg:60.62ms step:1933/2225 train_time:117183ms step_avg:60.62ms step:1934/2225 train_time:117243ms step_avg:60.62ms step:1935/2225 train_time:117305ms step_avg:60.62ms step:1936/2225 train_time:117365ms step_avg:60.62ms step:1937/2225 train_time:117427ms step_avg:60.62ms step:1938/2225 train_time:117487ms step_avg:60.62ms step:1939/2225 train_time:117550ms step_avg:60.62ms step:1940/2225 train_time:117610ms step_avg:60.62ms step:1941/2225 train_time:117672ms step_avg:60.62ms step:1942/2225 train_time:117732ms step_avg:60.62ms step:1943/2225 train_time:117794ms step_avg:60.63ms step:1944/2225 train_time:117854ms step_avg:60.62ms step:1945/2225 train_time:117915ms step_avg:60.62ms step:1946/2225 train_time:117975ms step_avg:60.62ms step:1947/2225 train_time:118038ms step_avg:60.63ms step:1948/2225 train_time:118099ms step_avg:60.63ms step:1949/2225 train_time:118161ms step_avg:60.63ms step:1950/2225 train_time:118221ms step_avg:60.63ms step:1951/2225 train_time:118284ms step_avg:60.63ms step:1952/2225 train_time:118344ms step_avg:60.63ms step:1953/2225 train_time:118406ms step_avg:60.63ms step:1954/2225 train_time:118466ms step_avg:60.63ms step:1955/2225 train_time:118528ms step_avg:60.63ms step:1956/2225 train_time:118588ms step_avg:60.63ms step:1957/2225 train_time:118650ms step_avg:60.63ms step:1958/2225 train_time:118710ms step_avg:60.63ms step:1959/2225 train_time:118772ms step_avg:60.63ms step:1960/2225 train_time:118832ms step_avg:60.63ms step:1961/2225 train_time:118893ms step_avg:60.63ms step:1962/2225 train_time:118953ms step_avg:60.63ms step:1963/2225 train_time:119014ms step_avg:60.63ms step:1964/2225 train_time:119074ms step_avg:60.63ms step:1965/2225 train_time:119136ms step_avg:60.63ms step:1966/2225 train_time:119197ms step_avg:60.63ms step:1967/2225 train_time:119259ms step_avg:60.63ms step:1968/2225 train_time:119320ms step_avg:60.63ms step:1969/2225 train_time:119382ms step_avg:60.63ms step:1970/2225 train_time:119442ms step_avg:60.63ms step:1971/2225 train_time:119504ms step_avg:60.63ms step:1972/2225 train_time:119565ms step_avg:60.63ms step:1973/2225 train_time:119627ms step_avg:60.63ms step:1974/2225 train_time:119686ms step_avg:60.63ms step:1975/2225 train_time:119748ms step_avg:60.63ms step:1976/2225 train_time:119808ms step_avg:60.63ms step:1977/2225 train_time:119869ms step_avg:60.63ms step:1978/2225 train_time:119929ms step_avg:60.63ms step:1979/2225 train_time:119991ms step_avg:60.63ms step:1980/2225 train_time:120050ms step_avg:60.63ms step:1981/2225 train_time:120112ms step_avg:60.63ms step:1982/2225 train_time:120172ms step_avg:60.63ms step:1983/2225 train_time:120234ms step_avg:60.63ms step:1984/2225 train_time:120294ms step_avg:60.63ms step:1985/2225 train_time:120356ms step_avg:60.63ms step:1986/2225 train_time:120417ms step_avg:60.63ms step:1987/2225 train_time:120478ms step_avg:60.63ms step:1988/2225 train_time:120539ms step_avg:60.63ms step:1989/2225 train_time:120601ms step_avg:60.63ms step:1990/2225 train_time:120662ms step_avg:60.63ms step:1991/2225 train_time:120724ms step_avg:60.64ms step:1992/2225 train_time:120785ms step_avg:60.63ms step:1993/2225 train_time:120847ms step_avg:60.64ms step:1994/2225 train_time:120906ms step_avg:60.64ms step:1995/2225 train_time:120969ms step_avg:60.64ms step:1996/2225 train_time:121028ms step_avg:60.64ms step:1997/2225 train_time:121090ms step_avg:60.64ms step:1998/2225 train_time:121150ms step_avg:60.64ms step:1999/2225 train_time:121212ms step_avg:60.64ms step:2000/2225 train_time:121272ms step_avg:60.64ms step:2000/2225 val_loss:3.3201 train_time:121334ms step_avg:60.67ms step:2001/2225 train_time:121354ms step_avg:60.65ms step:2002/2225 train_time:121395ms step_avg:60.64ms step:2003/2225 train_time:121462ms step_avg:60.64ms step:2004/2225 train_time:121524ms step_avg:60.64ms step:2005/2225 train_time:121587ms step_avg:60.64ms step:2006/2225 train_time:121646ms step_avg:60.64ms step:2007/2225 train_time:121709ms step_avg:60.64ms step:2008/2225 train_time:121769ms step_avg:60.64ms step:2009/2225 train_time:121830ms step_avg:60.64ms step:2010/2225 train_time:121889ms step_avg:60.64ms step:2011/2225 train_time:121950ms step_avg:60.64ms step:2012/2225 train_time:122010ms step_avg:60.64ms step:2013/2225 train_time:122071ms step_avg:60.64ms step:2014/2225 train_time:122131ms step_avg:60.64ms step:2015/2225 train_time:122193ms step_avg:60.64ms step:2016/2225 train_time:122252ms step_avg:60.64ms step:2017/2225 train_time:122316ms step_avg:60.64ms step:2018/2225 train_time:122378ms step_avg:60.64ms step:2019/2225 train_time:122441ms step_avg:60.64ms step:2020/2225 train_time:122501ms step_avg:60.64ms step:2021/2225 train_time:122564ms step_avg:60.65ms step:2022/2225 train_time:122624ms step_avg:60.64ms step:2023/2225 train_time:122686ms step_avg:60.65ms step:2024/2225 train_time:122746ms step_avg:60.65ms step:2025/2225 train_time:122808ms step_avg:60.65ms step:2026/2225 train_time:122868ms step_avg:60.65ms step:2027/2225 train_time:122929ms step_avg:60.65ms step:2028/2225 train_time:122989ms step_avg:60.65ms step:2029/2225 train_time:123050ms step_avg:60.65ms step:2030/2225 train_time:123110ms step_avg:60.65ms step:2031/2225 train_time:123172ms step_avg:60.65ms step:2032/2225 train_time:123232ms step_avg:60.65ms step:2033/2225 train_time:123295ms step_avg:60.65ms step:2034/2225 train_time:123357ms step_avg:60.65ms step:2035/2225 train_time:123419ms step_avg:60.65ms step:2036/2225 train_time:123480ms step_avg:60.65ms step:2037/2225 train_time:123542ms step_avg:60.65ms step:2038/2225 train_time:123602ms step_avg:60.65ms step:2039/2225 train_time:123664ms step_avg:60.65ms step:2040/2225 train_time:123723ms step_avg:60.65ms step:2041/2225 train_time:123785ms step_avg:60.65ms step:2042/2225 train_time:123844ms step_avg:60.65ms step:2043/2225 train_time:123906ms step_avg:60.65ms step:2044/2225 train_time:123965ms step_avg:60.65ms step:2045/2225 train_time:124027ms step_avg:60.65ms step:2046/2225 train_time:124087ms step_avg:60.65ms step:2047/2225 train_time:124149ms step_avg:60.65ms step:2048/2225 train_time:124210ms step_avg:60.65ms step:2049/2225 train_time:124273ms step_avg:60.65ms step:2050/2225 train_time:124334ms step_avg:60.65ms step:2051/2225 train_time:124396ms step_avg:60.65ms step:2052/2225 train_time:124458ms step_avg:60.65ms step:2053/2225 train_time:124520ms step_avg:60.65ms step:2054/2225 train_time:124580ms step_avg:60.65ms step:2055/2225 train_time:124643ms step_avg:60.65ms step:2056/2225 train_time:124702ms step_avg:60.65ms step:2057/2225 train_time:124764ms step_avg:60.65ms step:2058/2225 train_time:124823ms step_avg:60.65ms step:2059/2225 train_time:124885ms step_avg:60.65ms step:2060/2225 train_time:124944ms step_avg:60.65ms step:2061/2225 train_time:125005ms step_avg:60.65ms step:2062/2225 train_time:125065ms step_avg:60.65ms step:2063/2225 train_time:125127ms step_avg:60.65ms step:2064/2225 train_time:125188ms step_avg:60.65ms step:2065/2225 train_time:125251ms step_avg:60.65ms step:2066/2225 train_time:125312ms step_avg:60.65ms step:2067/2225 train_time:125375ms step_avg:60.66ms step:2068/2225 train_time:125435ms step_avg:60.66ms step:2069/2225 train_time:125498ms step_avg:60.66ms step:2070/2225 train_time:125558ms step_avg:60.66ms step:2071/2225 train_time:125620ms step_avg:60.66ms step:2072/2225 train_time:125679ms step_avg:60.66ms step:2073/2225 train_time:125741ms step_avg:60.66ms step:2074/2225 train_time:125801ms step_avg:60.66ms step:2075/2225 train_time:125863ms step_avg:60.66ms step:2076/2225 train_time:125923ms step_avg:60.66ms step:2077/2225 train_time:125984ms step_avg:60.66ms step:2078/2225 train_time:126043ms step_avg:60.66ms step:2079/2225 train_time:126105ms step_avg:60.66ms step:2080/2225 train_time:126164ms step_avg:60.66ms step:2081/2225 train_time:126226ms step_avg:60.66ms step:2082/2225 train_time:126287ms step_avg:60.66ms step:2083/2225 train_time:126349ms step_avg:60.66ms step:2084/2225 train_time:126410ms step_avg:60.66ms step:2085/2225 train_time:126473ms step_avg:60.66ms step:2086/2225 train_time:126533ms step_avg:60.66ms step:2087/2225 train_time:126596ms step_avg:60.66ms step:2088/2225 train_time:126657ms step_avg:60.66ms step:2089/2225 train_time:126719ms step_avg:60.66ms step:2090/2225 train_time:126778ms step_avg:60.66ms step:2091/2225 train_time:126840ms step_avg:60.66ms step:2092/2225 train_time:126900ms step_avg:60.66ms step:2093/2225 train_time:126962ms step_avg:60.66ms step:2094/2225 train_time:127021ms step_avg:60.66ms step:2095/2225 train_time:127083ms step_avg:60.66ms step:2096/2225 train_time:127142ms step_avg:60.66ms step:2097/2225 train_time:127204ms step_avg:60.66ms step:2098/2225 train_time:127264ms step_avg:60.66ms step:2099/2225 train_time:127326ms step_avg:60.66ms step:2100/2225 train_time:127386ms step_avg:60.66ms step:2101/2225 train_time:127449ms step_avg:60.66ms step:2102/2225 train_time:127509ms step_avg:60.66ms step:2103/2225 train_time:127572ms step_avg:60.66ms step:2104/2225 train_time:127632ms step_avg:60.66ms step:2105/2225 train_time:127695ms step_avg:60.66ms step:2106/2225 train_time:127756ms step_avg:60.66ms step:2107/2225 train_time:127818ms step_avg:60.66ms step:2108/2225 train_time:127878ms step_avg:60.66ms step:2109/2225 train_time:127940ms step_avg:60.66ms step:2110/2225 train_time:128000ms step_avg:60.66ms step:2111/2225 train_time:128062ms step_avg:60.66ms step:2112/2225 train_time:128121ms step_avg:60.66ms step:2113/2225 train_time:128182ms step_avg:60.66ms step:2114/2225 train_time:128242ms step_avg:60.66ms step:2115/2225 train_time:128304ms step_avg:60.66ms step:2116/2225 train_time:128363ms step_avg:60.66ms step:2117/2225 train_time:128425ms step_avg:60.66ms step:2118/2225 train_time:128485ms step_avg:60.66ms step:2119/2225 train_time:128548ms step_avg:60.66ms step:2120/2225 train_time:128609ms step_avg:60.66ms step:2121/2225 train_time:128671ms step_avg:60.67ms step:2122/2225 train_time:128732ms step_avg:60.67ms step:2123/2225 train_time:128794ms step_avg:60.67ms step:2124/2225 train_time:128855ms step_avg:60.67ms step:2125/2225 train_time:128917ms step_avg:60.67ms step:2126/2225 train_time:128977ms step_avg:60.67ms step:2127/2225 train_time:129039ms step_avg:60.67ms step:2128/2225 train_time:129099ms step_avg:60.67ms step:2129/2225 train_time:129160ms step_avg:60.67ms step:2130/2225 train_time:129220ms step_avg:60.67ms step:2131/2225 train_time:129281ms step_avg:60.67ms step:2132/2225 train_time:129341ms step_avg:60.67ms step:2133/2225 train_time:129404ms step_avg:60.67ms step:2134/2225 train_time:129463ms step_avg:60.67ms step:2135/2225 train_time:129525ms step_avg:60.67ms step:2136/2225 train_time:129584ms step_avg:60.67ms step:2137/2225 train_time:129646ms step_avg:60.67ms step:2138/2225 train_time:129706ms step_avg:60.67ms step:2139/2225 train_time:129768ms step_avg:60.67ms step:2140/2225 train_time:129830ms step_avg:60.67ms step:2141/2225 train_time:129892ms step_avg:60.67ms step:2142/2225 train_time:129952ms step_avg:60.67ms step:2143/2225 train_time:130015ms step_avg:60.67ms step:2144/2225 train_time:130076ms step_avg:60.67ms step:2145/2225 train_time:130138ms step_avg:60.67ms step:2146/2225 train_time:130198ms step_avg:60.67ms step:2147/2225 train_time:130260ms step_avg:60.67ms step:2148/2225 train_time:130319ms step_avg:60.67ms step:2149/2225 train_time:130381ms step_avg:60.67ms step:2150/2225 train_time:130441ms step_avg:60.67ms step:2151/2225 train_time:130503ms step_avg:60.67ms step:2152/2225 train_time:130563ms step_avg:60.67ms step:2153/2225 train_time:130625ms step_avg:60.67ms step:2154/2225 train_time:130685ms step_avg:60.67ms step:2155/2225 train_time:130746ms step_avg:60.67ms step:2156/2225 train_time:130806ms step_avg:60.67ms step:2157/2225 train_time:130869ms step_avg:60.67ms step:2158/2225 train_time:130930ms step_avg:60.67ms step:2159/2225 train_time:130992ms step_avg:60.67ms step:2160/2225 train_time:131052ms step_avg:60.67ms step:2161/2225 train_time:131114ms step_avg:60.67ms step:2162/2225 train_time:131175ms step_avg:60.67ms step:2163/2225 train_time:131237ms step_avg:60.67ms step:2164/2225 train_time:131297ms step_avg:60.67ms step:2165/2225 train_time:131359ms step_avg:60.67ms step:2166/2225 train_time:131419ms step_avg:60.67ms step:2167/2225 train_time:131481ms step_avg:60.67ms step:2168/2225 train_time:131541ms step_avg:60.67ms step:2169/2225 train_time:131603ms step_avg:60.67ms step:2170/2225 train_time:131663ms step_avg:60.67ms step:2171/2225 train_time:131725ms step_avg:60.67ms step:2172/2225 train_time:131784ms step_avg:60.67ms step:2173/2225 train_time:131846ms step_avg:60.67ms step:2174/2225 train_time:131906ms step_avg:60.67ms step:2175/2225 train_time:131968ms step_avg:60.68ms step:2176/2225 train_time:132029ms step_avg:60.68ms step:2177/2225 train_time:132091ms step_avg:60.68ms step:2178/2225 train_time:132152ms step_avg:60.68ms step:2179/2225 train_time:132214ms step_avg:60.68ms step:2180/2225 train_time:132275ms step_avg:60.68ms step:2181/2225 train_time:132337ms step_avg:60.68ms step:2182/2225 train_time:132397ms step_avg:60.68ms step:2183/2225 train_time:132460ms step_avg:60.68ms step:2184/2225 train_time:132520ms step_avg:60.68ms step:2185/2225 train_time:132582ms step_avg:60.68ms step:2186/2225 train_time:132642ms step_avg:60.68ms step:2187/2225 train_time:132704ms step_avg:60.68ms step:2188/2225 train_time:132764ms step_avg:60.68ms step:2189/2225 train_time:132825ms step_avg:60.68ms step:2190/2225 train_time:132885ms step_avg:60.68ms step:2191/2225 train_time:132947ms step_avg:60.68ms step:2192/2225 train_time:133007ms step_avg:60.68ms step:2193/2225 train_time:133070ms step_avg:60.68ms step:2194/2225 train_time:133131ms step_avg:60.68ms step:2195/2225 train_time:133193ms step_avg:60.68ms step:2196/2225 train_time:133254ms step_avg:60.68ms step:2197/2225 train_time:133316ms step_avg:60.68ms step:2198/2225 train_time:133377ms step_avg:60.68ms step:2199/2225 train_time:133440ms step_avg:60.68ms step:2200/2225 train_time:133499ms step_avg:60.68ms step:2201/2225 train_time:133561ms step_avg:60.68ms step:2202/2225 train_time:133621ms step_avg:60.68ms step:2203/2225 train_time:133683ms step_avg:60.68ms step:2204/2225 train_time:133742ms step_avg:60.68ms step:2205/2225 train_time:133804ms step_avg:60.68ms step:2206/2225 train_time:133864ms step_avg:60.68ms step:2207/2225 train_time:133926ms step_avg:60.68ms step:2208/2225 train_time:133986ms step_avg:60.68ms step:2209/2225 train_time:134049ms step_avg:60.68ms step:2210/2225 train_time:134110ms step_avg:60.68ms step:2211/2225 train_time:134173ms step_avg:60.68ms step:2212/2225 train_time:134233ms step_avg:60.68ms step:2213/2225 train_time:134296ms step_avg:60.68ms step:2214/2225 train_time:134356ms step_avg:60.68ms step:2215/2225 train_time:134419ms step_avg:60.69ms step:2216/2225 train_time:134478ms step_avg:60.69ms step:2217/2225 train_time:134541ms step_avg:60.69ms step:2218/2225 train_time:134600ms step_avg:60.69ms step:2219/2225 train_time:134662ms step_avg:60.69ms step:2220/2225 train_time:134721ms step_avg:60.69ms step:2221/2225 train_time:134783ms step_avg:60.69ms step:2222/2225 train_time:134843ms step_avg:60.69ms step:2223/2225 train_time:134905ms step_avg:60.69ms step:2224/2225 train_time:134965ms step_avg:60.69ms step:2225/2225 train_time:135026ms step_avg:60.69ms step:2225/2225 val_loss:3.2785 train_time:135088ms step_avg:60.71ms peak memory allocated: 29244 MiB reserved: 43976 MiB