diff --git a/.gitignore b/.gitignore index c0d3296a0..19c517d6a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ fineweb10B/ pylog124M/ __pycache__/ logs/ +.DS_Store diff --git a/records/082725_FA3/17e712ee-7cf8-44c9-a784-3762e61b174c.txt b/records/082725_FA3/17e712ee-7cf8-44c9-a784-3762e61b174c.txt deleted file mode 100644 index b5371a4da..000000000 --- a/records/082725_FA3/17e712ee-7cf8-44c9-a784-3762e61b174c.txt +++ /dev/null @@ -1,2808 +0,0 @@ -import os -import sys -with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time -import copy -import glob -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F -import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np -import triton -import triton.language as tl -from flash_attn_interface import flash_attn_func -import torch._dynamo as dynamo -dynamo.config.recompile_limit = 64 - -# ----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, 
grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + 
tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, 
mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
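-
-    Reference update rule (a minimal sketch of the logic in step() below, with the
-    distributed reduce-scatter/all-gather machinery stripped out):
-
-        buf.lerp_(grad, 1 - momentum)     # SGD-momentum accumulation
-        g = grad.lerp_(buf, momentum)     # Nesterov-style blend
-        u = newton_schulz_triton(g)       # orthogonalize the update
-        p.mul_(1 - lr * weight_decay)     # decoupled weight decay
-        p.add_(u, alpha=-lr * max(1, p.size(-2) / p.size(-1)) ** 0.5)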
- """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) - grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size - for base_i in range(0, len(params), world_size): - if base_i + rank < len(params): - grad = params[base_i + rank].grad - # This gives strange dynamo warnings - reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) - - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -class DistAdam(torch.optim.Optimizer): - def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # DistributedAdam implementation by @vagrawal - - @torch.compile - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - grad_slices = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - for base_i in range(len(params)): - grad = params[base_i].grad - rank_size = grad.shape[0] // world_size - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for base in range(len(params)): - reduce_scatter_futures[idx].wait() - p = params[base] - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - g_slice = grad_slices[idx] - # State init - if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = 
num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# 
----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - f.seek(256 * 4) - nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of 
shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. - tokens, pos = _load_data_shard(next(file_iter)), 0 - finder = EOSBatchFinder(tokens, world_size=world_size) - continue - - bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] - buf = torch.stack(bufs, dim=0) - _inputs = buf[:, :-1] - _targets = buf[:, 1:] - else: - batch_span = num_tokens_global - start_pos_local = pos + rank * (batch_size_local * seq_len) - end_pos_local = start_pos_local + (batch_size_local * seq_len) - - buf = tokens[start_pos_local: end_pos_local + 1] - - _inputs = buf[:-1].view(batch_size_local, seq_len) - _targets = buf[1:].view(batch_size_local, seq_len) - - new_params = yield ( - _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), - _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) - ) - - pos += batch_span - - if new_params is not None: - # makes it possible for generator to recieve new (batch_size, seq_len) via .send() - new_batch_size, new_seq_len = new_params - assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" - batch_size = new_batch_size - seq_len = new_seq_len - - -# ----------------------------------------------------------------------------- -# int main - -@dataclass -class Hyperparameters: - # data - train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on - val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on - val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons - train_seq_len: int = 1024 * 2 - train_batch_size: int = 24 * 8 - val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. - # optimization - num_iterations: int = 1695 # number of iterations to run - cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate - # evaluation and logging - run_id: str = str(uuid.uuid4()) - val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end - save_checkpoint: bool = False - # attention masking - bandwidth: int = 128 - ws_schedule: tuple = (3, 7, 11) - -args = Hyperparameters() - -data_path = os.environ.get("DATA_PATH", ".") -args.train_files = os.path.join(data_path, args.train_files) -args.val_files = os.path.join(data_path, args.val_files) - -# torchrun sets these env variables -rank = int(os.environ["RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) -assert 8 % world_size == 0, "world_size must be a divisor of 8" -grad_accum_steps = 8 // world_size -assert torch.cuda.is_available() -device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) -torch.cuda.set_device(device) -dist.init_process_group(backend="nccl", device_id=device) -dist.barrier() -master_process = (rank == 0) # this process will do logging, checkpointing etc. - -# begin logging -logfile = None -if master_process: - run_id = args.run_id - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{run_id}.txt" - print(logfile) -def print0(s, console=False): - if master_process: - with open(logfile, "a") as f: - if console: - print(s) - print(s, file=f) - -# begin by printing this file (the Python code) -print0(code) -print0("="*100) -# log information about the hardware/software environment this is running on -print0(f"Running Python {sys.version}") -print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") -print0(f"Running Triton version {triton.__version__}") - -def nvidia_smi(): - import subprocess # avoid top level import - return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout -print0(nvidia_smi()) -print0("="*100) - -model: nn.Module = GPT( - vocab_size=50257, - num_layers=12, - num_heads=6, - model_dim=768, - max_seq_len=max(args.train_seq_len, args.val_seq_len) -).cuda() -for m in model.modules(): - if isinstance(m, nn.Embedding): - m.bfloat16() -for param in model.parameters(): - dist.broadcast(param.detach(), 0) - -# collect the parameters to optimize -hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] -embed_params = [p for n, p in model.named_parameters() if "embed" in n] -scalar_params = [p for p in model.parameters() if p.ndim < 2] -head_params = [model.lm_head.weight] - -# init the optimizer(s) -# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence
-# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
-optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
-optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
-optimizers = [optimizer1, optimizer2]
-for opt in optimizers:
-    for group in opt.param_groups:
-        group["initial_lr"] = group["lr"]
-
-# learning rate schedule: stable then decay
-def get_lr_and_ws(step: int):
-    x = step / (1 + args.num_iterations) # progress in training
-    assert 0 <= x < 1
-    lr = 1.0
-    if x >= 1 - args.cooldown_frac:
-        w = (1 - x) / args.cooldown_frac
-        lr = w * 1.0 + (1 - w) * 0.1
-    ws_idx = int(len(args.ws_schedule) * x) # widen the attention window as training progresses
-    return lr, args.ws_schedule[ws_idx]
-
-model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
-
-########################################
-#            Warmup kernels            #
-########################################
-
-# Warmup the training kernels, then re-initialize the state so we aren't cheating
-warmup_steps = 60
-initial_state = dict(model=copy.deepcopy(model.state_dict()),
-                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
-for step in range(warmup_steps):
-    inputs, targets = next(train_loader)
-    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
-    model(inputs, targets, ws, ws // 2).backward()
-    for opt in optimizers:
-        opt.step()
-    model.zero_grad(set_to_none=True)
-model.load_state_dict(initial_state["model"])
-for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
-    opt.load_state_dict(opt_state)
-del train_loader, initial_state
-
-########################################
-#        Training and validation       #
-########################################
-
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
-training_time_ms = 0
-# start the clock
-torch.cuda.synchronize()
-t0 = time.perf_counter()
-# begin training
-train_steps = args.num_iterations
-for step in range(train_steps + 1):
-    last_step = (step == train_steps)
-    lr, ws = get_lr_and_ws(step)
-
-    # --------------- VALIDATION SECTION -----------------
-    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
-        # stop the clock
-        torch.cuda.synchronize()
-        training_time_ms += 1000 * (time.perf_counter() - t0)
-        model.eval()
-        assert args.val_tokens % (world_size * args.val_seq_len) == 0
-        val_steps = args.val_tokens // (world_size * args.val_seq_len)
-        val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False)
-        val_loss = 0
-        with torch.no_grad():
-            for _ in range(val_steps):
-                inputs, targets = next(val_loader)
-                val_loss += model(inputs, targets, ws, ws // 2)
-        val_loss /= val_steps
-        del val_loader
-        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
-        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
-        model.train()
-        # start the clock again
-        torch.cuda.synchronize()
-        t0 = time.perf_counter()
-
-    if last_step:
-        if master_process and args.save_checkpoint:
-            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
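-            # the checkpoint also bundles the source string captured at startup (code), so each saved record is self-describing
-            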
os.makedirs(f"logs/{run_id}", exist_ok=True) - torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") - # the last step only has the validation loop, so break to avoid training - break - - # --------------- TRAINING SECTION ----------------- - for _ in range(grad_accum_steps): - inputs, targets = next(train_loader) - model(inputs, targets, ws, ws // 2).backward() - # set optimization hyperparameters - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lr - for group in optimizer2.param_groups: - frac = min(step / 300, 1) # momentum warmup for muon - group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 - # step the optimizers - for opt in optimizers: - opt.step() - # null the gradients - model.zero_grad(set_to_none=True) - # logging - approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) - print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) - -print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) -dist.destroy_process_group() -==================================================================================================== -Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] -Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 -Running Triton version 3.4.0 -Wed Aug 27 04:15:50 2025 -+---------------------------------------------------------------------------------------+ -| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | -|-----------------------------------------+----------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+======================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | -| N/A 30C P0 115W / 700W | 5858MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | -| N/A 32C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | -| N/A 33C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | -| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | -| N/A 30C P0 110W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | -| N/A 34C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | -| N/A 32C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | -| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ - -+---------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=======================================================================================| -+---------------------------------------------------------------------------------------+ - -==================================================================================================== -step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms -step:1/1695 train_time:524ms step_avg:524.12ms -step:2/1695 train_time:549ms step_avg:274.51ms -step:3/1695 train_time:617ms step_avg:205.57ms -step:4/1695 train_time:709ms step_avg:177.26ms -step:5/1695 train_time:803ms step_avg:160.52ms -step:6/1695 train_time:897ms step_avg:149.44ms -step:7/1695 train_time:989ms step_avg:141.33ms -step:8/1695 train_time:1082ms step_avg:135.31ms -step:9/1695 train_time:1177ms step_avg:130.74ms -step:10/1695 train_time:1270ms step_avg:127.00ms -step:11/1695 train_time:1364ms step_avg:123.98ms -step:12/1695 train_time:1460ms step_avg:121.70ms -step:13/1695 train_time:1558ms step_avg:119.84ms -step:14/1695 train_time:1653ms step_avg:118.08ms -step:15/1695 train_time:1748ms step_avg:116.53ms -step:16/1695 train_time:1843ms step_avg:115.16ms -step:17/1695 train_time:1937ms step_avg:113.93ms -step:18/1695 train_time:2030ms step_avg:112.80ms -step:19/1695 train_time:2124ms step_avg:111.79ms -step:20/1695 train_time:2218ms step_avg:110.90ms -step:21/1695 train_time:2312ms step_avg:110.11ms -step:22/1695 train_time:2407ms step_avg:109.43ms 
-step:23/1695 train_time:2504ms step_avg:108.87ms -step:24/1695 train_time:2600ms step_avg:108.35ms -step:25/1695 train_time:2696ms step_avg:107.83ms -step:26/1695 train_time:2790ms step_avg:107.30ms -step:27/1695 train_time:2884ms step_avg:106.81ms -step:28/1695 train_time:2979ms step_avg:106.39ms -step:29/1695 train_time:3073ms step_avg:105.96ms -step:30/1695 train_time:3167ms step_avg:105.56ms -step:31/1695 train_time:3261ms step_avg:105.19ms -step:32/1695 train_time:3356ms step_avg:104.89ms -step:33/1695 train_time:3451ms step_avg:104.57ms -step:34/1695 train_time:3546ms step_avg:104.31ms -step:35/1695 train_time:3644ms step_avg:104.12ms -step:36/1695 train_time:3740ms step_avg:103.89ms -step:37/1695 train_time:3835ms step_avg:103.64ms -step:38/1695 train_time:3929ms step_avg:103.38ms -step:39/1695 train_time:4023ms step_avg:103.16ms -step:40/1695 train_time:4117ms step_avg:102.94ms -step:41/1695 train_time:4211ms step_avg:102.71ms -step:42/1695 train_time:4305ms step_avg:102.51ms -step:43/1695 train_time:4401ms step_avg:102.34ms -step:44/1695 train_time:4497ms step_avg:102.20ms -step:45/1695 train_time:4592ms step_avg:102.04ms -step:46/1695 train_time:4686ms step_avg:101.87ms -step:47/1695 train_time:4782ms step_avg:101.74ms -step:48/1695 train_time:4877ms step_avg:101.60ms -step:49/1695 train_time:4971ms step_avg:101.45ms -step:50/1695 train_time:5065ms step_avg:101.30ms -step:51/1695 train_time:5160ms step_avg:101.17ms -step:52/1695 train_time:5254ms step_avg:101.03ms -step:53/1695 train_time:5347ms step_avg:100.89ms -step:54/1695 train_time:5442ms step_avg:100.79ms -step:55/1695 train_time:5538ms step_avg:100.69ms -step:56/1695 train_time:5632ms step_avg:100.57ms -step:57/1695 train_time:5726ms step_avg:100.46ms -step:58/1695 train_time:5822ms step_avg:100.38ms -step:59/1695 train_time:5918ms step_avg:100.30ms -step:60/1695 train_time:6012ms step_avg:100.20ms -step:61/1695 train_time:6106ms step_avg:100.10ms -step:62/1695 train_time:6201ms step_avg:100.02ms -step:63/1695 train_time:6296ms step_avg:99.94ms -step:64/1695 train_time:6390ms step_avg:99.84ms -step:65/1695 train_time:6485ms step_avg:99.77ms -step:66/1695 train_time:6579ms step_avg:99.68ms -step:67/1695 train_time:6673ms step_avg:99.59ms -step:68/1695 train_time:6767ms step_avg:99.52ms -step:69/1695 train_time:6863ms step_avg:99.47ms -step:70/1695 train_time:6958ms step_avg:99.40ms -step:71/1695 train_time:7052ms step_avg:99.32ms -step:72/1695 train_time:7146ms step_avg:99.25ms -step:73/1695 train_time:7241ms step_avg:99.19ms -step:74/1695 train_time:7337ms step_avg:99.15ms -step:75/1695 train_time:7431ms step_avg:99.08ms -step:76/1695 train_time:7526ms step_avg:99.02ms -step:77/1695 train_time:7621ms step_avg:98.97ms -step:78/1695 train_time:7716ms step_avg:98.92ms -step:79/1695 train_time:7809ms step_avg:98.85ms -step:80/1695 train_time:7905ms step_avg:98.81ms -step:81/1695 train_time:8000ms step_avg:98.76ms -step:82/1695 train_time:8094ms step_avg:98.70ms -step:83/1695 train_time:8189ms step_avg:98.66ms -step:84/1695 train_time:8283ms step_avg:98.60ms -step:85/1695 train_time:8378ms step_avg:98.56ms -step:86/1695 train_time:8471ms step_avg:98.50ms -step:87/1695 train_time:8566ms step_avg:98.45ms -step:88/1695 train_time:8661ms step_avg:98.42ms -step:89/1695 train_time:8755ms step_avg:98.37ms -step:90/1695 train_time:8849ms step_avg:98.32ms -step:91/1695 train_time:8944ms step_avg:98.29ms -step:92/1695 train_time:9039ms step_avg:98.25ms -step:93/1695 train_time:9133ms step_avg:98.21ms -step:94/1695 train_time:9227ms 
step_avg:98.16ms -step:95/1695 train_time:9322ms step_avg:98.13ms -step:96/1695 train_time:9417ms step_avg:98.10ms -step:97/1695 train_time:9511ms step_avg:98.05ms -step:98/1695 train_time:9606ms step_avg:98.02ms -step:99/1695 train_time:9702ms step_avg:98.00ms -step:100/1695 train_time:9797ms step_avg:97.97ms -step:101/1695 train_time:9891ms step_avg:97.93ms -step:102/1695 train_time:9985ms step_avg:97.89ms -step:103/1695 train_time:10079ms step_avg:97.85ms -step:104/1695 train_time:10174ms step_avg:97.82ms -step:105/1695 train_time:10268ms step_avg:97.79ms -step:106/1695 train_time:10362ms step_avg:97.76ms -step:107/1695 train_time:10456ms step_avg:97.72ms -step:108/1695 train_time:10550ms step_avg:97.69ms -step:109/1695 train_time:10645ms step_avg:97.66ms -step:110/1695 train_time:10740ms step_avg:97.63ms -step:111/1695 train_time:10835ms step_avg:97.61ms -step:112/1695 train_time:10929ms step_avg:97.58ms -step:113/1695 train_time:11023ms step_avg:97.55ms -step:114/1695 train_time:11118ms step_avg:97.53ms -step:115/1695 train_time:11213ms step_avg:97.50ms -step:116/1695 train_time:11307ms step_avg:97.47ms -step:117/1695 train_time:11402ms step_avg:97.45ms -step:118/1695 train_time:11496ms step_avg:97.42ms -step:119/1695 train_time:11589ms step_avg:97.39ms -step:120/1695 train_time:11684ms step_avg:97.37ms -step:121/1695 train_time:11779ms step_avg:97.34ms -step:122/1695 train_time:11873ms step_avg:97.32ms -step:123/1695 train_time:11967ms step_avg:97.29ms -step:124/1695 train_time:12062ms step_avg:97.28ms -step:125/1695 train_time:12157ms step_avg:97.25ms -step:125/1695 val_loss:4.3113 train_time:12248ms step_avg:97.99ms -step:126/1695 train_time:12274ms step_avg:97.41ms -step:127/1695 train_time:12351ms step_avg:97.25ms -step:128/1695 train_time:12451ms step_avg:97.28ms -step:129/1695 train_time:12547ms step_avg:97.26ms -step:130/1695 train_time:12640ms step_avg:97.23ms -step:131/1695 train_time:12734ms step_avg:97.21ms -step:132/1695 train_time:12828ms step_avg:97.18ms -step:133/1695 train_time:12921ms step_avg:97.15ms -step:134/1695 train_time:13014ms step_avg:97.12ms -step:135/1695 train_time:13108ms step_avg:97.09ms -step:136/1695 train_time:13201ms step_avg:97.07ms -step:137/1695 train_time:13297ms step_avg:97.05ms -step:138/1695 train_time:13394ms step_avg:97.06ms -step:139/1695 train_time:13490ms step_avg:97.05ms -step:140/1695 train_time:13584ms step_avg:97.03ms -step:141/1695 train_time:13678ms step_avg:97.00ms -step:142/1695 train_time:13772ms step_avg:96.98ms -step:143/1695 train_time:13865ms step_avg:96.96ms -step:144/1695 train_time:13958ms step_avg:96.93ms -step:145/1695 train_time:14052ms step_avg:96.91ms -step:146/1695 train_time:14144ms step_avg:96.88ms -step:147/1695 train_time:14238ms step_avg:96.86ms -step:148/1695 train_time:14333ms step_avg:96.84ms -step:149/1695 train_time:14429ms step_avg:96.84ms -step:150/1695 train_time:14523ms step_avg:96.82ms -step:151/1695 train_time:14618ms step_avg:96.81ms -step:152/1695 train_time:14713ms step_avg:96.79ms -step:153/1695 train_time:14806ms step_avg:96.77ms -step:154/1695 train_time:14900ms step_avg:96.75ms -step:155/1695 train_time:14995ms step_avg:96.74ms -step:156/1695 train_time:15088ms step_avg:96.72ms -step:157/1695 train_time:15181ms step_avg:96.69ms -step:158/1695 train_time:15274ms step_avg:96.67ms -step:159/1695 train_time:15369ms step_avg:96.66ms -step:160/1695 train_time:15464ms step_avg:96.65ms -step:161/1695 train_time:15558ms step_avg:96.63ms -step:162/1695 train_time:15653ms step_avg:96.62ms -step:163/1695 
train_time:15748ms step_avg:96.61ms -step:164/1695 train_time:15841ms step_avg:96.59ms -step:165/1695 train_time:15936ms step_avg:96.58ms -step:166/1695 train_time:16030ms step_avg:96.57ms -step:167/1695 train_time:16124ms step_avg:96.55ms -step:168/1695 train_time:16217ms step_avg:96.53ms -step:169/1695 train_time:16311ms step_avg:96.52ms -step:170/1695 train_time:16406ms step_avg:96.50ms -step:171/1695 train_time:16500ms step_avg:96.49ms -step:172/1695 train_time:16595ms step_avg:96.48ms -step:173/1695 train_time:16964ms step_avg:98.06ms -step:174/1695 train_time:17033ms step_avg:97.89ms -step:175/1695 train_time:17126ms step_avg:97.86ms -step:176/1695 train_time:17219ms step_avg:97.83ms -step:177/1695 train_time:17312ms step_avg:97.81ms -step:178/1695 train_time:17405ms step_avg:97.78ms -step:179/1695 train_time:17499ms step_avg:97.76ms -step:180/1695 train_time:17591ms step_avg:97.73ms -step:181/1695 train_time:17684ms step_avg:97.70ms -step:182/1695 train_time:17777ms step_avg:97.68ms -step:183/1695 train_time:17876ms step_avg:97.68ms -step:184/1695 train_time:17973ms step_avg:97.68ms -step:185/1695 train_time:18068ms step_avg:97.67ms -step:186/1695 train_time:18162ms step_avg:97.64ms -step:187/1695 train_time:18256ms step_avg:97.63ms -step:188/1695 train_time:18350ms step_avg:97.61ms -step:189/1695 train_time:18443ms step_avg:97.58ms -step:190/1695 train_time:18537ms step_avg:97.56ms -step:191/1695 train_time:18630ms step_avg:97.54ms -step:192/1695 train_time:18724ms step_avg:97.52ms -step:193/1695 train_time:18818ms step_avg:97.50ms -step:194/1695 train_time:18914ms step_avg:97.49ms -step:195/1695 train_time:19010ms step_avg:97.48ms -step:196/1695 train_time:19105ms step_avg:97.47ms -step:197/1695 train_time:19198ms step_avg:97.45ms -step:198/1695 train_time:19292ms step_avg:97.44ms -step:199/1695 train_time:19387ms step_avg:97.42ms -step:200/1695 train_time:19480ms step_avg:97.40ms -step:201/1695 train_time:19574ms step_avg:97.39ms -step:202/1695 train_time:19669ms step_avg:97.37ms -step:203/1695 train_time:19763ms step_avg:97.35ms -step:204/1695 train_time:19857ms step_avg:97.34ms -step:205/1695 train_time:19952ms step_avg:97.32ms -step:206/1695 train_time:20047ms step_avg:97.31ms -step:207/1695 train_time:20140ms step_avg:97.30ms -step:208/1695 train_time:20235ms step_avg:97.28ms -step:209/1695 train_time:20329ms step_avg:97.27ms -step:210/1695 train_time:20423ms step_avg:97.25ms -step:211/1695 train_time:20516ms step_avg:97.23ms -step:212/1695 train_time:20611ms step_avg:97.22ms -step:213/1695 train_time:20706ms step_avg:97.21ms -step:214/1695 train_time:20799ms step_avg:97.19ms -step:215/1695 train_time:20894ms step_avg:97.18ms -step:216/1695 train_time:20988ms step_avg:97.17ms -step:217/1695 train_time:21082ms step_avg:97.15ms -step:218/1695 train_time:21176ms step_avg:97.14ms -step:219/1695 train_time:21271ms step_avg:97.13ms -step:220/1695 train_time:21365ms step_avg:97.12ms -step:221/1695 train_time:21459ms step_avg:97.10ms -step:222/1695 train_time:21554ms step_avg:97.09ms -step:223/1695 train_time:21648ms step_avg:97.07ms -step:224/1695 train_time:21741ms step_avg:97.06ms -step:225/1695 train_time:21835ms step_avg:97.05ms -step:226/1695 train_time:21931ms step_avg:97.04ms -step:227/1695 train_time:22024ms step_avg:97.02ms -step:228/1695 train_time:22118ms step_avg:97.01ms -step:229/1695 train_time:22213ms step_avg:97.00ms -step:230/1695 train_time:22308ms step_avg:96.99ms -step:231/1695 train_time:22401ms step_avg:96.98ms -step:232/1695 train_time:22496ms step_avg:96.96ms 
-step:233/1695 train_time:22589ms step_avg:96.95ms -step:234/1695 train_time:22683ms step_avg:96.93ms -step:235/1695 train_time:22776ms step_avg:96.92ms -step:236/1695 train_time:22872ms step_avg:96.91ms -step:237/1695 train_time:22967ms step_avg:96.91ms -step:238/1695 train_time:23061ms step_avg:96.89ms -step:239/1695 train_time:23155ms step_avg:96.88ms -step:240/1695 train_time:23249ms step_avg:96.87ms -step:241/1695 train_time:23343ms step_avg:96.86ms -step:242/1695 train_time:23437ms step_avg:96.85ms -step:243/1695 train_time:23532ms step_avg:96.84ms -step:244/1695 train_time:23626ms step_avg:96.83ms -step:245/1695 train_time:23719ms step_avg:96.81ms -step:246/1695 train_time:23814ms step_avg:96.81ms -step:247/1695 train_time:23909ms step_avg:96.80ms -step:248/1695 train_time:24003ms step_avg:96.79ms -step:249/1695 train_time:24097ms step_avg:96.78ms -step:250/1695 train_time:24191ms step_avg:96.77ms -step:250/1695 val_loss:3.9807 train_time:24284ms step_avg:97.14ms -step:251/1695 train_time:24310ms step_avg:96.85ms -step:252/1695 train_time:24384ms step_avg:96.76ms -step:253/1695 train_time:24484ms step_avg:96.78ms -step:254/1695 train_time:24580ms step_avg:96.77ms -step:255/1695 train_time:24673ms step_avg:96.76ms -step:256/1695 train_time:24766ms step_avg:96.74ms -step:257/1695 train_time:24859ms step_avg:96.73ms -step:258/1695 train_time:24953ms step_avg:96.72ms -step:259/1695 train_time:25046ms step_avg:96.70ms -step:260/1695 train_time:25139ms step_avg:96.69ms -step:261/1695 train_time:25232ms step_avg:96.68ms -step:262/1695 train_time:25328ms step_avg:96.67ms -step:263/1695 train_time:25425ms step_avg:96.67ms -step:264/1695 train_time:25521ms step_avg:96.67ms -step:265/1695 train_time:25616ms step_avg:96.67ms -step:266/1695 train_time:25710ms step_avg:96.65ms -step:267/1695 train_time:25804ms step_avg:96.64ms -step:268/1695 train_time:25898ms step_avg:96.64ms -step:269/1695 train_time:25992ms step_avg:96.62ms -step:270/1695 train_time:26085ms step_avg:96.61ms -step:271/1695 train_time:26179ms step_avg:96.60ms -step:272/1695 train_time:26273ms step_avg:96.59ms -step:273/1695 train_time:26368ms step_avg:96.59ms -step:274/1695 train_time:26465ms step_avg:96.59ms -step:275/1695 train_time:26560ms step_avg:96.58ms -step:276/1695 train_time:26655ms step_avg:96.58ms -step:277/1695 train_time:26749ms step_avg:96.57ms -step:278/1695 train_time:26843ms step_avg:96.56ms -step:279/1695 train_time:26936ms step_avg:96.55ms -step:280/1695 train_time:27030ms step_avg:96.53ms -step:281/1695 train_time:27124ms step_avg:96.53ms -step:282/1695 train_time:27218ms step_avg:96.52ms -step:283/1695 train_time:27312ms step_avg:96.51ms -step:284/1695 train_time:27407ms step_avg:96.50ms -step:285/1695 train_time:27502ms step_avg:96.50ms -step:286/1695 train_time:27596ms step_avg:96.49ms -step:287/1695 train_time:27690ms step_avg:96.48ms -step:288/1695 train_time:27785ms step_avg:96.47ms -step:289/1695 train_time:27879ms step_avg:96.47ms -step:290/1695 train_time:27971ms step_avg:96.45ms -step:291/1695 train_time:28065ms step_avg:96.44ms -step:292/1695 train_time:28160ms step_avg:96.44ms -step:293/1695 train_time:28254ms step_avg:96.43ms -step:294/1695 train_time:28348ms step_avg:96.42ms -step:295/1695 train_time:28443ms step_avg:96.42ms -step:296/1695 train_time:28538ms step_avg:96.41ms -step:297/1695 train_time:28632ms step_avg:96.40ms -step:298/1695 train_time:28726ms step_avg:96.40ms -step:299/1695 train_time:28820ms step_avg:96.39ms -step:300/1695 train_time:28913ms step_avg:96.38ms -step:301/1695 
train_time:29006ms step_avg:96.37ms -step:302/1695 train_time:29100ms step_avg:96.36ms -step:303/1695 train_time:29194ms step_avg:96.35ms -step:304/1695 train_time:29288ms step_avg:96.34ms -step:305/1695 train_time:29382ms step_avg:96.33ms -step:306/1695 train_time:29477ms step_avg:96.33ms -step:307/1695 train_time:29571ms step_avg:96.32ms -step:308/1695 train_time:29666ms step_avg:96.32ms -step:309/1695 train_time:29761ms step_avg:96.31ms -step:310/1695 train_time:29855ms step_avg:96.31ms -step:311/1695 train_time:29948ms step_avg:96.30ms -step:312/1695 train_time:30042ms step_avg:96.29ms -step:313/1695 train_time:30135ms step_avg:96.28ms -step:314/1695 train_time:30229ms step_avg:96.27ms -step:315/1695 train_time:30322ms step_avg:96.26ms -step:316/1695 train_time:30416ms step_avg:96.25ms -step:317/1695 train_time:30510ms step_avg:96.24ms -step:318/1695 train_time:30605ms step_avg:96.24ms -step:319/1695 train_time:30700ms step_avg:96.24ms -step:320/1695 train_time:30795ms step_avg:96.23ms -step:321/1695 train_time:30888ms step_avg:96.23ms -step:322/1695 train_time:30983ms step_avg:96.22ms -step:323/1695 train_time:31077ms step_avg:96.21ms -step:324/1695 train_time:31170ms step_avg:96.20ms -step:325/1695 train_time:31264ms step_avg:96.20ms -step:326/1695 train_time:31360ms step_avg:96.20ms -step:327/1695 train_time:31453ms step_avg:96.19ms -step:328/1695 train_time:31547ms step_avg:96.18ms -step:329/1695 train_time:31642ms step_avg:96.18ms -step:330/1695 train_time:31737ms step_avg:96.17ms -step:331/1695 train_time:31832ms step_avg:96.17ms -step:332/1695 train_time:31926ms step_avg:96.16ms -step:333/1695 train_time:32020ms step_avg:96.16ms -step:334/1695 train_time:32114ms step_avg:96.15ms -step:335/1695 train_time:32207ms step_avg:96.14ms -step:336/1695 train_time:32302ms step_avg:96.14ms -step:337/1695 train_time:32395ms step_avg:96.13ms -step:338/1695 train_time:32488ms step_avg:96.12ms -step:339/1695 train_time:32582ms step_avg:96.11ms -step:340/1695 train_time:32677ms step_avg:96.11ms -step:341/1695 train_time:32771ms step_avg:96.10ms -step:342/1695 train_time:32866ms step_avg:96.10ms -step:343/1695 train_time:32961ms step_avg:96.10ms -step:344/1695 train_time:33055ms step_avg:96.09ms -step:345/1695 train_time:33378ms step_avg:96.75ms -step:346/1695 train_time:33470ms step_avg:96.73ms -step:347/1695 train_time:33563ms step_avg:96.72ms -step:348/1695 train_time:33655ms step_avg:96.71ms -step:349/1695 train_time:33748ms step_avg:96.70ms -step:350/1695 train_time:33841ms step_avg:96.69ms -step:351/1695 train_time:33934ms step_avg:96.68ms -step:352/1695 train_time:34027ms step_avg:96.67ms -step:353/1695 train_time:34120ms step_avg:96.66ms -step:354/1695 train_time:34213ms step_avg:96.65ms -step:355/1695 train_time:34312ms step_avg:96.65ms -step:356/1695 train_time:34409ms step_avg:96.65ms -step:357/1695 train_time:34506ms step_avg:96.66ms -step:358/1695 train_time:34602ms step_avg:96.65ms -step:359/1695 train_time:34695ms step_avg:96.64ms -step:360/1695 train_time:34788ms step_avg:96.63ms -step:361/1695 train_time:34881ms step_avg:96.62ms -step:362/1695 train_time:34974ms step_avg:96.61ms -step:363/1695 train_time:35068ms step_avg:96.60ms -step:364/1695 train_time:35161ms step_avg:96.60ms -step:365/1695 train_time:35256ms step_avg:96.59ms -step:366/1695 train_time:35351ms step_avg:96.59ms -step:367/1695 train_time:35448ms step_avg:96.59ms -step:368/1695 train_time:35544ms step_avg:96.59ms -step:369/1695 train_time:35639ms step_avg:96.58ms -step:370/1695 train_time:35733ms step_avg:96.58ms 
-step:371/1695 train_time:35826ms step_avg:96.57ms -step:372/1695 train_time:35920ms step_avg:96.56ms -step:373/1695 train_time:36013ms step_avg:96.55ms -step:374/1695 train_time:36106ms step_avg:96.54ms -step:375/1695 train_time:36199ms step_avg:96.53ms -step:375/1695 val_loss:3.8148 train_time:36291ms step_avg:96.78ms -step:376/1695 train_time:36317ms step_avg:96.59ms -step:377/1695 train_time:36395ms step_avg:96.54ms -step:378/1695 train_time:36491ms step_avg:96.54ms -step:379/1695 train_time:36586ms step_avg:96.53ms -step:380/1695 train_time:36680ms step_avg:96.53ms -step:381/1695 train_time:36773ms step_avg:96.52ms -step:382/1695 train_time:36866ms step_avg:96.51ms -step:383/1695 train_time:36960ms step_avg:96.50ms -step:384/1695 train_time:37054ms step_avg:96.49ms -step:385/1695 train_time:37147ms step_avg:96.49ms -step:386/1695 train_time:37241ms step_avg:96.48ms -step:387/1695 train_time:37337ms step_avg:96.48ms -step:388/1695 train_time:37433ms step_avg:96.48ms -step:389/1695 train_time:37529ms step_avg:96.47ms -step:390/1695 train_time:37623ms step_avg:96.47ms -step:391/1695 train_time:37716ms step_avg:96.46ms -step:392/1695 train_time:37809ms step_avg:96.45ms -step:393/1695 train_time:37902ms step_avg:96.44ms -step:394/1695 train_time:37996ms step_avg:96.44ms -step:395/1695 train_time:38089ms step_avg:96.43ms -step:396/1695 train_time:38183ms step_avg:96.42ms -step:397/1695 train_time:38277ms step_avg:96.42ms -step:398/1695 train_time:38371ms step_avg:96.41ms -step:399/1695 train_time:38466ms step_avg:96.41ms -step:400/1695 train_time:38561ms step_avg:96.40ms -step:401/1695 train_time:38655ms step_avg:96.40ms -step:402/1695 train_time:38749ms step_avg:96.39ms -step:403/1695 train_time:38843ms step_avg:96.38ms -step:404/1695 train_time:38936ms step_avg:96.38ms -step:405/1695 train_time:39030ms step_avg:96.37ms -step:406/1695 train_time:39124ms step_avg:96.36ms -step:407/1695 train_time:39217ms step_avg:96.36ms -step:408/1695 train_time:39311ms step_avg:96.35ms -step:409/1695 train_time:39406ms step_avg:96.35ms -step:410/1695 train_time:39501ms step_avg:96.34ms -step:411/1695 train_time:39594ms step_avg:96.34ms -step:412/1695 train_time:39689ms step_avg:96.33ms -step:413/1695 train_time:39783ms step_avg:96.33ms -step:414/1695 train_time:39876ms step_avg:96.32ms -step:415/1695 train_time:39970ms step_avg:96.31ms -step:416/1695 train_time:40063ms step_avg:96.31ms -step:417/1695 train_time:40157ms step_avg:96.30ms -step:418/1695 train_time:40251ms step_avg:96.29ms -step:419/1695 train_time:40346ms step_avg:96.29ms -step:420/1695 train_time:40440ms step_avg:96.29ms -step:421/1695 train_time:40534ms step_avg:96.28ms -step:422/1695 train_time:40629ms step_avg:96.28ms -step:423/1695 train_time:40723ms step_avg:96.27ms -step:424/1695 train_time:40816ms step_avg:96.26ms -step:425/1695 train_time:40910ms step_avg:96.26ms -step:426/1695 train_time:41005ms step_avg:96.25ms -step:427/1695 train_time:41099ms step_avg:96.25ms -step:428/1695 train_time:41192ms step_avg:96.24ms -step:429/1695 train_time:41288ms step_avg:96.24ms -step:430/1695 train_time:41383ms step_avg:96.24ms -step:431/1695 train_time:41476ms step_avg:96.23ms -step:432/1695 train_time:41570ms step_avg:96.23ms -step:433/1695 train_time:41664ms step_avg:96.22ms -step:434/1695 train_time:41758ms step_avg:96.22ms -step:435/1695 train_time:41851ms step_avg:96.21ms -step:436/1695 train_time:41946ms step_avg:96.21ms -step:437/1695 train_time:42039ms step_avg:96.20ms -step:438/1695 train_time:42132ms step_avg:96.19ms -step:439/1695 
train_time:42227ms step_avg:96.19ms -step:440/1695 train_time:42322ms step_avg:96.19ms -step:441/1695 train_time:42416ms step_avg:96.18ms -step:442/1695 train_time:42510ms step_avg:96.18ms -step:443/1695 train_time:42605ms step_avg:96.17ms -step:444/1695 train_time:42698ms step_avg:96.17ms -step:445/1695 train_time:42792ms step_avg:96.16ms -step:446/1695 train_time:42887ms step_avg:96.16ms -step:447/1695 train_time:42982ms step_avg:96.16ms -step:448/1695 train_time:43075ms step_avg:96.15ms -step:449/1695 train_time:43169ms step_avg:96.14ms -step:450/1695 train_time:43264ms step_avg:96.14ms -step:451/1695 train_time:43358ms step_avg:96.14ms -step:452/1695 train_time:43452ms step_avg:96.13ms -step:453/1695 train_time:43548ms step_avg:96.13ms -step:454/1695 train_time:43642ms step_avg:96.13ms -step:455/1695 train_time:43735ms step_avg:96.12ms -step:456/1695 train_time:43828ms step_avg:96.11ms -step:457/1695 train_time:43923ms step_avg:96.11ms -step:458/1695 train_time:44018ms step_avg:96.11ms -step:459/1695 train_time:44112ms step_avg:96.10ms -step:460/1695 train_time:44205ms step_avg:96.10ms -step:461/1695 train_time:44300ms step_avg:96.09ms -step:462/1695 train_time:44393ms step_avg:96.09ms -step:463/1695 train_time:44487ms step_avg:96.08ms -step:464/1695 train_time:44582ms step_avg:96.08ms -step:465/1695 train_time:44676ms step_avg:96.08ms -step:466/1695 train_time:44770ms step_avg:96.07ms -step:467/1695 train_time:44864ms step_avg:96.07ms -step:468/1695 train_time:44959ms step_avg:96.07ms -step:469/1695 train_time:45053ms step_avg:96.06ms -step:470/1695 train_time:45147ms step_avg:96.06ms -step:471/1695 train_time:45242ms step_avg:96.05ms -step:472/1695 train_time:45335ms step_avg:96.05ms -step:473/1695 train_time:45429ms step_avg:96.04ms -step:474/1695 train_time:45523ms step_avg:96.04ms -step:475/1695 train_time:45616ms step_avg:96.03ms -step:476/1695 train_time:45710ms step_avg:96.03ms -step:477/1695 train_time:45805ms step_avg:96.03ms -step:478/1695 train_time:45899ms step_avg:96.02ms -step:479/1695 train_time:45992ms step_avg:96.02ms -step:480/1695 train_time:46087ms step_avg:96.01ms -step:481/1695 train_time:46181ms step_avg:96.01ms -step:482/1695 train_time:46275ms step_avg:96.01ms -step:483/1695 train_time:46369ms step_avg:96.00ms -step:484/1695 train_time:46464ms step_avg:96.00ms -step:485/1695 train_time:46559ms step_avg:96.00ms -step:486/1695 train_time:46653ms step_avg:95.99ms -step:487/1695 train_time:46747ms step_avg:95.99ms -step:488/1695 train_time:46842ms step_avg:95.99ms -step:489/1695 train_time:46935ms step_avg:95.98ms -step:490/1695 train_time:47029ms step_avg:95.98ms -step:491/1695 train_time:47122ms step_avg:95.97ms -step:492/1695 train_time:47216ms step_avg:95.97ms -step:493/1695 train_time:47309ms step_avg:95.96ms -step:494/1695 train_time:47403ms step_avg:95.96ms -step:495/1695 train_time:47496ms step_avg:95.95ms -step:496/1695 train_time:47591ms step_avg:95.95ms -step:497/1695 train_time:47686ms step_avg:95.95ms -step:498/1695 train_time:47780ms step_avg:95.94ms -step:499/1695 train_time:47874ms step_avg:95.94ms -step:500/1695 train_time:47968ms step_avg:95.94ms -step:500/1695 val_loss:3.7151 train_time:48060ms step_avg:96.12ms -step:501/1695 train_time:48087ms step_avg:95.98ms -step:502/1695 train_time:48163ms step_avg:95.94ms -step:503/1695 train_time:48261ms step_avg:95.95ms -step:504/1695 train_time:48355ms step_avg:95.94ms -step:505/1695 train_time:48448ms step_avg:95.94ms -step:506/1695 train_time:48542ms step_avg:95.93ms -step:507/1695 train_time:48634ms 
step_avg:95.93ms -step:508/1695 train_time:48728ms step_avg:95.92ms -step:509/1695 train_time:48820ms step_avg:95.91ms -step:510/1695 train_time:48913ms step_avg:95.91ms -step:511/1695 train_time:49007ms step_avg:95.90ms -step:512/1695 train_time:49103ms step_avg:95.90ms -step:513/1695 train_time:49198ms step_avg:95.90ms -step:514/1695 train_time:49293ms step_avg:95.90ms -step:515/1695 train_time:49388ms step_avg:95.90ms -step:516/1695 train_time:49482ms step_avg:95.90ms -step:517/1695 train_time:49575ms step_avg:95.89ms -step:518/1695 train_time:49669ms step_avg:95.89ms -step:519/1695 train_time:50009ms step_avg:96.36ms -step:520/1695 train_time:50195ms step_avg:96.53ms -step:521/1695 train_time:50287ms step_avg:96.52ms -step:522/1695 train_time:50380ms step_avg:96.51ms -step:523/1695 train_time:50473ms step_avg:96.51ms -step:524/1695 train_time:50566ms step_avg:96.50ms -step:525/1695 train_time:50658ms step_avg:96.49ms -step:526/1695 train_time:50751ms step_avg:96.49ms -step:527/1695 train_time:50845ms step_avg:96.48ms -step:528/1695 train_time:50937ms step_avg:96.47ms -step:529/1695 train_time:51035ms step_avg:96.47ms -step:530/1695 train_time:51134ms step_avg:96.48ms -step:531/1695 train_time:51232ms step_avg:96.48ms -step:532/1695 train_time:51328ms step_avg:96.48ms -step:533/1695 train_time:51422ms step_avg:96.48ms -step:534/1695 train_time:51515ms step_avg:96.47ms -step:535/1695 train_time:51609ms step_avg:96.47ms -step:536/1695 train_time:51703ms step_avg:96.46ms -step:537/1695 train_time:51795ms step_avg:96.45ms -step:538/1695 train_time:51889ms step_avg:96.45ms -step:539/1695 train_time:51982ms step_avg:96.44ms -step:540/1695 train_time:52077ms step_avg:96.44ms -step:541/1695 train_time:52173ms step_avg:96.44ms -step:542/1695 train_time:52270ms step_avg:96.44ms -step:543/1695 train_time:52365ms step_avg:96.44ms -step:544/1695 train_time:52458ms step_avg:96.43ms -step:545/1695 train_time:52552ms step_avg:96.43ms -step:546/1695 train_time:52646ms step_avg:96.42ms -step:547/1695 train_time:52739ms step_avg:96.42ms -step:548/1695 train_time:52833ms step_avg:96.41ms -step:549/1695 train_time:52926ms step_avg:96.40ms -step:550/1695 train_time:53020ms step_avg:96.40ms -step:551/1695 train_time:53114ms step_avg:96.40ms -step:552/1695 train_time:53209ms step_avg:96.39ms -step:553/1695 train_time:53303ms step_avg:96.39ms -step:554/1695 train_time:53397ms step_avg:96.38ms -step:555/1695 train_time:53491ms step_avg:96.38ms -step:556/1695 train_time:53584ms step_avg:96.37ms -step:557/1695 train_time:53678ms step_avg:96.37ms -step:558/1695 train_time:53771ms step_avg:96.36ms -step:559/1695 train_time:53865ms step_avg:96.36ms -step:560/1695 train_time:53959ms step_avg:96.35ms -step:561/1695 train_time:54053ms step_avg:96.35ms -step:562/1695 train_time:54147ms step_avg:96.35ms -step:563/1695 train_time:54242ms step_avg:96.34ms -step:564/1695 train_time:54336ms step_avg:96.34ms -step:565/1695 train_time:54430ms step_avg:96.34ms -step:566/1695 train_time:54524ms step_avg:96.33ms -step:567/1695 train_time:54618ms step_avg:96.33ms -step:568/1695 train_time:54713ms step_avg:96.33ms -step:569/1695 train_time:54809ms step_avg:96.32ms -step:570/1695 train_time:54905ms step_avg:96.33ms -step:571/1695 train_time:55001ms step_avg:96.32ms -step:572/1695 train_time:55096ms step_avg:96.32ms -step:573/1695 train_time:55193ms step_avg:96.32ms -step:574/1695 train_time:55289ms step_avg:96.32ms -step:575/1695 train_time:55385ms step_avg:96.32ms -step:576/1695 train_time:55481ms step_avg:96.32ms -step:577/1695 
train_time:55576ms step_avg:96.32ms -step:578/1695 train_time:55672ms step_avg:96.32ms -step:579/1695 train_time:55769ms step_avg:96.32ms -step:580/1695 train_time:55865ms step_avg:96.32ms -step:581/1695 train_time:55961ms step_avg:96.32ms -step:582/1695 train_time:56056ms step_avg:96.32ms -step:583/1695 train_time:56152ms step_avg:96.32ms -step:584/1695 train_time:56248ms step_avg:96.32ms -step:585/1695 train_time:56345ms step_avg:96.32ms -step:586/1695 train_time:56442ms step_avg:96.32ms -step:587/1695 train_time:56537ms step_avg:96.32ms -step:588/1695 train_time:56633ms step_avg:96.31ms -step:589/1695 train_time:56729ms step_avg:96.31ms -step:590/1695 train_time:56824ms step_avg:96.31ms -step:591/1695 train_time:56919ms step_avg:96.31ms -step:592/1695 train_time:57015ms step_avg:96.31ms -step:593/1695 train_time:57111ms step_avg:96.31ms -step:594/1695 train_time:57208ms step_avg:96.31ms -step:595/1695 train_time:57304ms step_avg:96.31ms -step:596/1695 train_time:57401ms step_avg:96.31ms -step:597/1695 train_time:57496ms step_avg:96.31ms -step:598/1695 train_time:57592ms step_avg:96.31ms -step:599/1695 train_time:57689ms step_avg:96.31ms -step:600/1695 train_time:57784ms step_avg:96.31ms -step:601/1695 train_time:57880ms step_avg:96.31ms -step:602/1695 train_time:57976ms step_avg:96.31ms -step:603/1695 train_time:58071ms step_avg:96.30ms -step:604/1695 train_time:58168ms step_avg:96.30ms -step:605/1695 train_time:58264ms step_avg:96.30ms -step:606/1695 train_time:58360ms step_avg:96.30ms -step:607/1695 train_time:58455ms step_avg:96.30ms -step:608/1695 train_time:58550ms step_avg:96.30ms -step:609/1695 train_time:58647ms step_avg:96.30ms -step:610/1695 train_time:58744ms step_avg:96.30ms -step:611/1695 train_time:58841ms step_avg:96.30ms -step:612/1695 train_time:58936ms step_avg:96.30ms -step:613/1695 train_time:59032ms step_avg:96.30ms -step:614/1695 train_time:59128ms step_avg:96.30ms -step:615/1695 train_time:59224ms step_avg:96.30ms -step:616/1695 train_time:59320ms step_avg:96.30ms -step:617/1695 train_time:59415ms step_avg:96.30ms -step:618/1695 train_time:59511ms step_avg:96.30ms -step:619/1695 train_time:59607ms step_avg:96.30ms -step:620/1695 train_time:59703ms step_avg:96.29ms -step:621/1695 train_time:59798ms step_avg:96.29ms -step:622/1695 train_time:59894ms step_avg:96.29ms -step:623/1695 train_time:59991ms step_avg:96.29ms -step:624/1695 train_time:60087ms step_avg:96.29ms -step:625/1695 train_time:60183ms step_avg:96.29ms -step:625/1695 val_loss:3.6179 train_time:60276ms step_avg:96.44ms -step:626/1695 train_time:60301ms step_avg:96.33ms -step:627/1695 train_time:60381ms step_avg:96.30ms -step:628/1695 train_time:60478ms step_avg:96.30ms -step:629/1695 train_time:60574ms step_avg:96.30ms -step:630/1695 train_time:60669ms step_avg:96.30ms -step:631/1695 train_time:60764ms step_avg:96.30ms -step:632/1695 train_time:60858ms step_avg:96.29ms -step:633/1695 train_time:60953ms step_avg:96.29ms -step:634/1695 train_time:61048ms step_avg:96.29ms -step:635/1695 train_time:61143ms step_avg:96.29ms -step:636/1695 train_time:61240ms step_avg:96.29ms -step:637/1695 train_time:61337ms step_avg:96.29ms -step:638/1695 train_time:61435ms step_avg:96.29ms -step:639/1695 train_time:61532ms step_avg:96.29ms -step:640/1695 train_time:61628ms step_avg:96.29ms -step:641/1695 train_time:61725ms step_avg:96.30ms -step:642/1695 train_time:61821ms step_avg:96.29ms -step:643/1695 train_time:61915ms step_avg:96.29ms -step:644/1695 train_time:62011ms step_avg:96.29ms -step:645/1695 train_time:62106ms 
step_avg:96.29ms -step:646/1695 train_time:62201ms step_avg:96.29ms -step:647/1695 train_time:62297ms step_avg:96.29ms -step:648/1695 train_time:62394ms step_avg:96.29ms -step:649/1695 train_time:62492ms step_avg:96.29ms -step:650/1695 train_time:62589ms step_avg:96.29ms -step:651/1695 train_time:62685ms step_avg:96.29ms -step:652/1695 train_time:62781ms step_avg:96.29ms -step:653/1695 train_time:62876ms step_avg:96.29ms -step:654/1695 train_time:62971ms step_avg:96.29ms -step:655/1695 train_time:63067ms step_avg:96.29ms -step:656/1695 train_time:63164ms step_avg:96.29ms -step:657/1695 train_time:63261ms step_avg:96.29ms -step:658/1695 train_time:63357ms step_avg:96.29ms -step:659/1695 train_time:63453ms step_avg:96.29ms -step:660/1695 train_time:63550ms step_avg:96.29ms -step:661/1695 train_time:63647ms step_avg:96.29ms -step:662/1695 train_time:63743ms step_avg:96.29ms -step:663/1695 train_time:63838ms step_avg:96.29ms -step:664/1695 train_time:63933ms step_avg:96.28ms -step:665/1695 train_time:64029ms step_avg:96.28ms -step:666/1695 train_time:64126ms step_avg:96.28ms -step:667/1695 train_time:64222ms step_avg:96.29ms -step:668/1695 train_time:64318ms step_avg:96.28ms -step:669/1695 train_time:64413ms step_avg:96.28ms -step:670/1695 train_time:64511ms step_avg:96.28ms -step:671/1695 train_time:64608ms step_avg:96.29ms -step:672/1695 train_time:64705ms step_avg:96.29ms -step:673/1695 train_time:64801ms step_avg:96.29ms -step:674/1695 train_time:64896ms step_avg:96.28ms -step:675/1695 train_time:64993ms step_avg:96.29ms -step:676/1695 train_time:65089ms step_avg:96.29ms -step:677/1695 train_time:65186ms step_avg:96.29ms -step:678/1695 train_time:65283ms step_avg:96.29ms -step:679/1695 train_time:65378ms step_avg:96.29ms -step:680/1695 train_time:65474ms step_avg:96.29ms -step:681/1695 train_time:65571ms step_avg:96.29ms -step:682/1695 train_time:65667ms step_avg:96.29ms -step:683/1695 train_time:65763ms step_avg:96.29ms -step:684/1695 train_time:65859ms step_avg:96.28ms -step:685/1695 train_time:65954ms step_avg:96.28ms -step:686/1695 train_time:66050ms step_avg:96.28ms -step:687/1695 train_time:66146ms step_avg:96.28ms -step:688/1695 train_time:66242ms step_avg:96.28ms -step:689/1695 train_time:66338ms step_avg:96.28ms -step:690/1695 train_time:66433ms step_avg:96.28ms -step:691/1695 train_time:66794ms step_avg:96.66ms -step:692/1695 train_time:66958ms step_avg:96.76ms -step:693/1695 train_time:67053ms step_avg:96.76ms -step:694/1695 train_time:67148ms step_avg:96.76ms -step:695/1695 train_time:67244ms step_avg:96.75ms -step:696/1695 train_time:67338ms step_avg:96.75ms -step:697/1695 train_time:67433ms step_avg:96.75ms -step:698/1695 train_time:67528ms step_avg:96.75ms -step:699/1695 train_time:67624ms step_avg:96.74ms -step:700/1695 train_time:67718ms step_avg:96.74ms -step:701/1695 train_time:67820ms step_avg:96.75ms -step:702/1695 train_time:67919ms step_avg:96.75ms -step:703/1695 train_time:68017ms step_avg:96.75ms -step:704/1695 train_time:68115ms step_avg:96.75ms -step:705/1695 train_time:68211ms step_avg:96.75ms -step:706/1695 train_time:68306ms step_avg:96.75ms -step:707/1695 train_time:68400ms step_avg:96.75ms -step:708/1695 train_time:68495ms step_avg:96.74ms -step:709/1695 train_time:68590ms step_avg:96.74ms -step:710/1695 train_time:68686ms step_avg:96.74ms -step:711/1695 train_time:68784ms step_avg:96.74ms -step:712/1695 train_time:68882ms step_avg:96.74ms -step:713/1695 train_time:68978ms step_avg:96.74ms -step:714/1695 train_time:69074ms step_avg:96.74ms -step:715/1695 
train_time:69170ms step_avg:96.74ms -step:716/1695 train_time:69266ms step_avg:96.74ms -step:717/1695 train_time:69361ms step_avg:96.74ms -step:718/1695 train_time:69455ms step_avg:96.73ms -step:719/1695 train_time:69551ms step_avg:96.73ms -step:720/1695 train_time:69646ms step_avg:96.73ms -step:721/1695 train_time:69742ms step_avg:96.73ms -step:722/1695 train_time:69838ms step_avg:96.73ms -step:723/1695 train_time:69935ms step_avg:96.73ms -step:724/1695 train_time:70032ms step_avg:96.73ms -step:725/1695 train_time:70129ms step_avg:96.73ms -step:726/1695 train_time:70225ms step_avg:96.73ms -step:727/1695 train_time:70322ms step_avg:96.73ms -step:728/1695 train_time:70417ms step_avg:96.73ms -step:729/1695 train_time:70512ms step_avg:96.72ms -step:730/1695 train_time:70608ms step_avg:96.72ms -step:731/1695 train_time:70704ms step_avg:96.72ms -step:732/1695 train_time:70801ms step_avg:96.72ms -step:733/1695 train_time:70896ms step_avg:96.72ms -step:734/1695 train_time:70993ms step_avg:96.72ms -step:735/1695 train_time:71090ms step_avg:96.72ms -step:736/1695 train_time:71187ms step_avg:96.72ms -step:737/1695 train_time:71284ms step_avg:96.72ms -step:738/1695 train_time:71380ms step_avg:96.72ms -step:739/1695 train_time:71475ms step_avg:96.72ms -step:740/1695 train_time:71570ms step_avg:96.72ms -step:741/1695 train_time:71667ms step_avg:96.72ms -step:742/1695 train_time:71765ms step_avg:96.72ms -step:743/1695 train_time:71861ms step_avg:96.72ms -step:744/1695 train_time:71957ms step_avg:96.72ms -step:745/1695 train_time:72053ms step_avg:96.71ms -step:746/1695 train_time:72150ms step_avg:96.72ms -step:747/1695 train_time:72247ms step_avg:96.72ms -step:748/1695 train_time:72343ms step_avg:96.72ms -step:749/1695 train_time:72439ms step_avg:96.71ms -step:750/1695 train_time:72534ms step_avg:96.71ms -step:750/1695 val_loss:3.5645 train_time:72628ms step_avg:96.84ms -step:751/1695 train_time:72654ms step_avg:96.74ms -step:752/1695 train_time:72734ms step_avg:96.72ms -step:753/1695 train_time:72834ms step_avg:96.72ms -step:754/1695 train_time:72930ms step_avg:96.72ms -step:755/1695 train_time:73026ms step_avg:96.72ms -step:756/1695 train_time:73121ms step_avg:96.72ms -step:757/1695 train_time:73215ms step_avg:96.72ms -step:758/1695 train_time:73311ms step_avg:96.72ms -step:759/1695 train_time:73407ms step_avg:96.72ms -step:760/1695 train_time:73501ms step_avg:96.71ms -step:761/1695 train_time:73598ms step_avg:96.71ms -step:762/1695 train_time:73696ms step_avg:96.71ms -step:763/1695 train_time:73795ms step_avg:96.72ms -step:764/1695 train_time:73893ms step_avg:96.72ms -step:765/1695 train_time:73990ms step_avg:96.72ms -step:766/1695 train_time:74086ms step_avg:96.72ms -step:767/1695 train_time:74182ms step_avg:96.72ms -step:768/1695 train_time:74276ms step_avg:96.71ms -step:769/1695 train_time:74372ms step_avg:96.71ms -step:770/1695 train_time:74468ms step_avg:96.71ms -step:771/1695 train_time:74564ms step_avg:96.71ms -step:772/1695 train_time:74661ms step_avg:96.71ms -step:773/1695 train_time:74758ms step_avg:96.71ms -step:774/1695 train_time:74854ms step_avg:96.71ms -step:775/1695 train_time:74951ms step_avg:96.71ms -step:776/1695 train_time:75048ms step_avg:96.71ms -step:777/1695 train_time:75143ms step_avg:96.71ms -step:778/1695 train_time:75238ms step_avg:96.71ms -step:779/1695 train_time:75333ms step_avg:96.71ms -step:780/1695 train_time:75429ms step_avg:96.70ms -step:781/1695 train_time:75526ms step_avg:96.70ms -step:782/1695 train_time:75622ms step_avg:96.70ms -step:783/1695 train_time:75718ms 
step_avg:96.70ms -step:784/1695 train_time:75815ms step_avg:96.70ms -step:785/1695 train_time:75911ms step_avg:96.70ms -step:786/1695 train_time:76008ms step_avg:96.70ms -step:787/1695 train_time:76104ms step_avg:96.70ms -step:788/1695 train_time:76199ms step_avg:96.70ms -step:789/1695 train_time:76295ms step_avg:96.70ms -step:790/1695 train_time:76390ms step_avg:96.70ms -step:791/1695 train_time:76486ms step_avg:96.70ms -step:792/1695 train_time:76582ms step_avg:96.69ms -step:793/1695 train_time:76678ms step_avg:96.69ms -step:794/1695 train_time:76774ms step_avg:96.69ms -step:795/1695 train_time:76871ms step_avg:96.69ms -step:796/1695 train_time:76969ms step_avg:96.69ms -step:797/1695 train_time:77067ms step_avg:96.70ms -step:798/1695 train_time:77163ms step_avg:96.70ms -step:799/1695 train_time:77259ms step_avg:96.69ms -step:800/1695 train_time:77355ms step_avg:96.69ms -step:801/1695 train_time:77450ms step_avg:96.69ms -step:802/1695 train_time:77546ms step_avg:96.69ms -step:803/1695 train_time:77642ms step_avg:96.69ms -step:804/1695 train_time:77738ms step_avg:96.69ms -step:805/1695 train_time:77834ms step_avg:96.69ms -step:806/1695 train_time:77931ms step_avg:96.69ms -step:807/1695 train_time:78029ms step_avg:96.69ms -step:808/1695 train_time:78126ms step_avg:96.69ms -step:809/1695 train_time:78222ms step_avg:96.69ms -step:810/1695 train_time:78317ms step_avg:96.69ms -step:811/1695 train_time:78413ms step_avg:96.69ms -step:812/1695 train_time:78509ms step_avg:96.69ms -step:813/1695 train_time:78604ms step_avg:96.68ms -step:814/1695 train_time:78699ms step_avg:96.68ms -step:815/1695 train_time:78795ms step_avg:96.68ms -step:816/1695 train_time:78891ms step_avg:96.68ms -step:817/1695 train_time:78988ms step_avg:96.68ms -step:818/1695 train_time:79085ms step_avg:96.68ms -step:819/1695 train_time:79182ms step_avg:96.68ms -step:820/1695 train_time:79278ms step_avg:96.68ms -step:821/1695 train_time:79373ms step_avg:96.68ms -step:822/1695 train_time:79469ms step_avg:96.68ms -step:823/1695 train_time:79565ms step_avg:96.68ms -step:824/1695 train_time:79660ms step_avg:96.67ms -step:825/1695 train_time:79755ms step_avg:96.67ms -step:826/1695 train_time:79852ms step_avg:96.67ms -step:827/1695 train_time:79949ms step_avg:96.67ms -step:828/1695 train_time:80045ms step_avg:96.67ms -step:829/1695 train_time:80142ms step_avg:96.67ms -step:830/1695 train_time:80237ms step_avg:96.67ms -step:831/1695 train_time:80333ms step_avg:96.67ms -step:832/1695 train_time:80430ms step_avg:96.67ms -step:833/1695 train_time:80527ms step_avg:96.67ms -step:834/1695 train_time:80624ms step_avg:96.67ms -step:835/1695 train_time:80719ms step_avg:96.67ms -step:836/1695 train_time:80815ms step_avg:96.67ms -step:837/1695 train_time:80911ms step_avg:96.67ms -step:838/1695 train_time:81007ms step_avg:96.67ms -step:839/1695 train_time:81103ms step_avg:96.67ms -step:840/1695 train_time:81199ms step_avg:96.67ms -step:841/1695 train_time:81295ms step_avg:96.66ms -step:842/1695 train_time:81392ms step_avg:96.66ms -step:843/1695 train_time:81488ms step_avg:96.66ms -step:844/1695 train_time:81583ms step_avg:96.66ms -step:845/1695 train_time:81678ms step_avg:96.66ms -step:846/1695 train_time:81773ms step_avg:96.66ms -step:847/1695 train_time:81869ms step_avg:96.66ms -step:848/1695 train_time:81965ms step_avg:96.66ms -step:849/1695 train_time:82062ms step_avg:96.66ms -step:850/1695 train_time:82158ms step_avg:96.66ms -step:851/1695 train_time:82254ms step_avg:96.66ms -step:852/1695 train_time:82350ms step_avg:96.66ms -step:853/1695 
train_time:82447ms step_avg:96.66ms -step:854/1695 train_time:82542ms step_avg:96.65ms -step:855/1695 train_time:82637ms step_avg:96.65ms -step:856/1695 train_time:82733ms step_avg:96.65ms -step:857/1695 train_time:82829ms step_avg:96.65ms -step:858/1695 train_time:82925ms step_avg:96.65ms -step:859/1695 train_time:83021ms step_avg:96.65ms -step:860/1695 train_time:83116ms step_avg:96.65ms -step:861/1695 train_time:83212ms step_avg:96.65ms -step:862/1695 train_time:83309ms step_avg:96.65ms -step:863/1695 train_time:83635ms step_avg:96.91ms -step:864/1695 train_time:83834ms step_avg:97.03ms -step:865/1695 train_time:83928ms step_avg:97.03ms -step:866/1695 train_time:84024ms step_avg:97.03ms -step:867/1695 train_time:84118ms step_avg:97.02ms -step:868/1695 train_time:84213ms step_avg:97.02ms -step:869/1695 train_time:84308ms step_avg:97.02ms -step:870/1695 train_time:84404ms step_avg:97.02ms -step:871/1695 train_time:84498ms step_avg:97.01ms -step:872/1695 train_time:84593ms step_avg:97.01ms -step:873/1695 train_time:84695ms step_avg:97.02ms -step:874/1695 train_time:84795ms step_avg:97.02ms -step:875/1695 train_time:84893ms step_avg:97.02ms -step:875/1695 val_loss:3.5224 train_time:84986ms step_avg:97.13ms -step:876/1695 train_time:85012ms step_avg:97.05ms -step:877/1695 train_time:85093ms step_avg:97.03ms -step:878/1695 train_time:85195ms step_avg:97.03ms -step:879/1695 train_time:85293ms step_avg:97.03ms -step:880/1695 train_time:85390ms step_avg:97.03ms -step:881/1695 train_time:85485ms step_avg:97.03ms -step:882/1695 train_time:85580ms step_avg:97.03ms -step:883/1695 train_time:85675ms step_avg:97.03ms -step:884/1695 train_time:85769ms step_avg:97.02ms -step:885/1695 train_time:85864ms step_avg:97.02ms -step:886/1695 train_time:85959ms step_avg:97.02ms -step:887/1695 train_time:86057ms step_avg:97.02ms -step:888/1695 train_time:86155ms step_avg:97.02ms -step:889/1695 train_time:86253ms step_avg:97.02ms -step:890/1695 train_time:86350ms step_avg:97.02ms -step:891/1695 train_time:86447ms step_avg:97.02ms -step:892/1695 train_time:86542ms step_avg:97.02ms -step:893/1695 train_time:86637ms step_avg:97.02ms -step:894/1695 train_time:86732ms step_avg:97.02ms -step:895/1695 train_time:86827ms step_avg:97.01ms -step:896/1695 train_time:86922ms step_avg:97.01ms -step:897/1695 train_time:87018ms step_avg:97.01ms -step:898/1695 train_time:87115ms step_avg:97.01ms -step:899/1695 train_time:87213ms step_avg:97.01ms -step:900/1695 train_time:87311ms step_avg:97.01ms -step:901/1695 train_time:87409ms step_avg:97.01ms -step:902/1695 train_time:87506ms step_avg:97.01ms -step:903/1695 train_time:87601ms step_avg:97.01ms -step:904/1695 train_time:87696ms step_avg:97.01ms -step:905/1695 train_time:87792ms step_avg:97.01ms -step:906/1695 train_time:87887ms step_avg:97.01ms -step:907/1695 train_time:87983ms step_avg:97.00ms -step:908/1695 train_time:88079ms step_avg:97.00ms -step:909/1695 train_time:88175ms step_avg:97.00ms -step:910/1695 train_time:88271ms step_avg:97.00ms -step:911/1695 train_time:88367ms step_avg:97.00ms -step:912/1695 train_time:88463ms step_avg:97.00ms -step:913/1695 train_time:88558ms step_avg:97.00ms -step:914/1695 train_time:88654ms step_avg:97.00ms -step:915/1695 train_time:88750ms step_avg:96.99ms -step:916/1695 train_time:88846ms step_avg:96.99ms -step:917/1695 train_time:88942ms step_avg:96.99ms -step:918/1695 train_time:89038ms step_avg:96.99ms -step:919/1695 train_time:89134ms step_avg:96.99ms -step:920/1695 train_time:89231ms step_avg:96.99ms -step:921/1695 train_time:89327ms 
step_avg:96.99ms -step:922/1695 train_time:89422ms step_avg:96.99ms -step:923/1695 train_time:89518ms step_avg:96.99ms -step:924/1695 train_time:89614ms step_avg:96.98ms -step:925/1695 train_time:89709ms step_avg:96.98ms -step:926/1695 train_time:89806ms step_avg:96.98ms -step:927/1695 train_time:89901ms step_avg:96.98ms -step:928/1695 train_time:89997ms step_avg:96.98ms -step:929/1695 train_time:90093ms step_avg:96.98ms -step:930/1695 train_time:90190ms step_avg:96.98ms -step:931/1695 train_time:90287ms step_avg:96.98ms -step:932/1695 train_time:90384ms step_avg:96.98ms -step:933/1695 train_time:90479ms step_avg:96.98ms -step:934/1695 train_time:90575ms step_avg:96.98ms -step:935/1695 train_time:90671ms step_avg:96.97ms -step:936/1695 train_time:90767ms step_avg:96.97ms -step:937/1695 train_time:90863ms step_avg:96.97ms -step:938/1695 train_time:90958ms step_avg:96.97ms -step:939/1695 train_time:91055ms step_avg:96.97ms -step:940/1695 train_time:91150ms step_avg:96.97ms -step:941/1695 train_time:91246ms step_avg:96.97ms -step:942/1695 train_time:91342ms step_avg:96.97ms -step:943/1695 train_time:91438ms step_avg:96.96ms -step:944/1695 train_time:91533ms step_avg:96.96ms -step:945/1695 train_time:91630ms step_avg:96.96ms -step:946/1695 train_time:91727ms step_avg:96.96ms -step:947/1695 train_time:91824ms step_avg:96.96ms -step:948/1695 train_time:91920ms step_avg:96.96ms -step:949/1695 train_time:92015ms step_avg:96.96ms -step:950/1695 train_time:92112ms step_avg:96.96ms -step:951/1695 train_time:92208ms step_avg:96.96ms -step:952/1695 train_time:92305ms step_avg:96.96ms -step:953/1695 train_time:92401ms step_avg:96.96ms -step:954/1695 train_time:92496ms step_avg:96.96ms -step:955/1695 train_time:92592ms step_avg:96.96ms -step:956/1695 train_time:92689ms step_avg:96.95ms -step:957/1695 train_time:92785ms step_avg:96.95ms -step:958/1695 train_time:92881ms step_avg:96.95ms -step:959/1695 train_time:92976ms step_avg:96.95ms -step:960/1695 train_time:93072ms step_avg:96.95ms -step:961/1695 train_time:93168ms step_avg:96.95ms -step:962/1695 train_time:93265ms step_avg:96.95ms -step:963/1695 train_time:93362ms step_avg:96.95ms -step:964/1695 train_time:93458ms step_avg:96.95ms -step:965/1695 train_time:93555ms step_avg:96.95ms -step:966/1695 train_time:93652ms step_avg:96.95ms -step:967/1695 train_time:93748ms step_avg:96.95ms -step:968/1695 train_time:93844ms step_avg:96.95ms -step:969/1695 train_time:93939ms step_avg:96.94ms -step:970/1695 train_time:94035ms step_avg:96.94ms -step:971/1695 train_time:94132ms step_avg:96.94ms -step:972/1695 train_time:94229ms step_avg:96.94ms -step:973/1695 train_time:94326ms step_avg:96.94ms -step:974/1695 train_time:94422ms step_avg:96.94ms -step:975/1695 train_time:94517ms step_avg:96.94ms -step:976/1695 train_time:94614ms step_avg:96.94ms -step:977/1695 train_time:94712ms step_avg:96.94ms -step:978/1695 train_time:94808ms step_avg:96.94ms -step:979/1695 train_time:94904ms step_avg:96.94ms -step:980/1695 train_time:95000ms step_avg:96.94ms -step:981/1695 train_time:95095ms step_avg:96.94ms -step:982/1695 train_time:95192ms step_avg:96.94ms -step:983/1695 train_time:95289ms step_avg:96.94ms -step:984/1695 train_time:95386ms step_avg:96.94ms -step:985/1695 train_time:95482ms step_avg:96.94ms -step:986/1695 train_time:95577ms step_avg:96.93ms -step:987/1695 train_time:95673ms step_avg:96.93ms -step:988/1695 train_time:95770ms step_avg:96.93ms -step:989/1695 train_time:95866ms step_avg:96.93ms -step:990/1695 train_time:95963ms step_avg:96.93ms -step:991/1695 
train_time:96059ms step_avg:96.93ms -step:992/1695 train_time:96154ms step_avg:96.93ms -step:993/1695 train_time:96251ms step_avg:96.93ms -step:994/1695 train_time:96348ms step_avg:96.93ms -step:995/1695 train_time:96445ms step_avg:96.93ms -step:996/1695 train_time:96539ms step_avg:96.93ms -step:997/1695 train_time:96634ms step_avg:96.93ms -step:998/1695 train_time:96732ms step_avg:96.93ms -step:999/1695 train_time:96829ms step_avg:96.93ms -step:1000/1695 train_time:96926ms step_avg:96.93ms -step:1000/1695 val_loss:3.4843 train_time:97020ms step_avg:97.02ms -step:1001/1695 train_time:97046ms step_avg:96.95ms -step:1002/1695 train_time:97126ms step_avg:96.93ms -step:1003/1695 train_time:97224ms step_avg:96.93ms -step:1004/1695 train_time:97319ms step_avg:96.93ms -step:1005/1695 train_time:97415ms step_avg:96.93ms -step:1006/1695 train_time:97511ms step_avg:96.93ms -step:1007/1695 train_time:97607ms step_avg:96.93ms -step:1008/1695 train_time:97701ms step_avg:96.93ms -step:1009/1695 train_time:97796ms step_avg:96.92ms -step:1010/1695 train_time:97892ms step_avg:96.92ms -step:1011/1695 train_time:97989ms step_avg:96.92ms -step:1012/1695 train_time:98087ms step_avg:96.92ms -step:1013/1695 train_time:98185ms step_avg:96.93ms -step:1014/1695 train_time:98281ms step_avg:96.92ms -step:1015/1695 train_time:98377ms step_avg:96.92ms -step:1016/1695 train_time:98474ms step_avg:96.92ms -step:1017/1695 train_time:98570ms step_avg:96.92ms -step:1018/1695 train_time:98665ms step_avg:96.92ms -step:1019/1695 train_time:98760ms step_avg:96.92ms -step:1020/1695 train_time:98855ms step_avg:96.92ms -step:1021/1695 train_time:98951ms step_avg:96.92ms -step:1022/1695 train_time:99049ms step_avg:96.92ms -step:1023/1695 train_time:99145ms step_avg:96.92ms -step:1024/1695 train_time:99241ms step_avg:96.92ms -step:1025/1695 train_time:99337ms step_avg:96.91ms -step:1026/1695 train_time:99434ms step_avg:96.91ms -step:1027/1695 train_time:99531ms step_avg:96.91ms -step:1028/1695 train_time:99627ms step_avg:96.91ms -step:1029/1695 train_time:99724ms step_avg:96.91ms -step:1030/1695 train_time:99817ms step_avg:96.91ms -step:1031/1695 train_time:99913ms step_avg:96.91ms -step:1032/1695 train_time:100009ms step_avg:96.91ms -step:1033/1695 train_time:100105ms step_avg:96.91ms -step:1034/1695 train_time:100202ms step_avg:96.91ms -step:1035/1695 train_time:100298ms step_avg:96.91ms -step:1036/1695 train_time:100628ms step_avg:97.13ms -step:1037/1695 train_time:100810ms step_avg:97.21ms -step:1038/1695 train_time:100904ms step_avg:97.21ms -step:1039/1695 train_time:100998ms step_avg:97.21ms -step:1040/1695 train_time:101093ms step_avg:97.20ms -step:1041/1695 train_time:101188ms step_avg:97.20ms -step:1042/1695 train_time:101283ms step_avg:97.20ms -step:1043/1695 train_time:101377ms step_avg:97.20ms -step:1044/1695 train_time:101472ms step_avg:97.20ms -step:1045/1695 train_time:101567ms step_avg:97.19ms -step:1046/1695 train_time:101666ms step_avg:97.20ms -step:1047/1695 train_time:101765ms step_avg:97.20ms -step:1048/1695 train_time:101862ms step_avg:97.20ms -step:1049/1695 train_time:101959ms step_avg:97.20ms -step:1050/1695 train_time:102054ms step_avg:97.19ms -step:1051/1695 train_time:102149ms step_avg:97.19ms -step:1052/1695 train_time:102244ms step_avg:97.19ms -step:1053/1695 train_time:102339ms step_avg:97.19ms -step:1054/1695 train_time:102434ms step_avg:97.19ms -step:1055/1695 train_time:102529ms step_avg:97.18ms -step:1056/1695 train_time:102626ms step_avg:97.18ms -step:1057/1695 train_time:102724ms step_avg:97.18ms 
-step:1058/1695 train_time:102821ms step_avg:97.18ms -step:1059/1695 train_time:102917ms step_avg:97.18ms -step:1060/1695 train_time:103015ms step_avg:97.18ms -step:1061/1695 train_time:103111ms step_avg:97.18ms -step:1062/1695 train_time:103207ms step_avg:97.18ms -step:1063/1695 train_time:103302ms step_avg:97.18ms -step:1064/1695 train_time:103397ms step_avg:97.18ms -step:1065/1695 train_time:103493ms step_avg:97.18ms -step:1066/1695 train_time:103589ms step_avg:97.18ms -step:1067/1695 train_time:103684ms step_avg:97.17ms -step:1068/1695 train_time:103781ms step_avg:97.17ms -step:1069/1695 train_time:103877ms step_avg:97.17ms -step:1070/1695 train_time:103974ms step_avg:97.17ms -step:1071/1695 train_time:104070ms step_avg:97.17ms -step:1072/1695 train_time:104166ms step_avg:97.17ms -step:1073/1695 train_time:104262ms step_avg:97.17ms -step:1074/1695 train_time:104357ms step_avg:97.17ms -step:1075/1695 train_time:104453ms step_avg:97.17ms -step:1076/1695 train_time:104549ms step_avg:97.16ms -step:1077/1695 train_time:104644ms step_avg:97.16ms -step:1078/1695 train_time:104739ms step_avg:97.16ms -step:1079/1695 train_time:104836ms step_avg:97.16ms -step:1080/1695 train_time:104933ms step_avg:97.16ms -step:1081/1695 train_time:105029ms step_avg:97.16ms -step:1082/1695 train_time:105125ms step_avg:97.16ms -step:1083/1695 train_time:105221ms step_avg:97.16ms -step:1084/1695 train_time:105316ms step_avg:97.15ms -step:1085/1695 train_time:105412ms step_avg:97.15ms -step:1086/1695 train_time:105508ms step_avg:97.15ms -step:1087/1695 train_time:105604ms step_avg:97.15ms -step:1088/1695 train_time:105699ms step_avg:97.15ms -step:1089/1695 train_time:105795ms step_avg:97.15ms -step:1090/1695 train_time:105893ms step_avg:97.15ms -step:1091/1695 train_time:105990ms step_avg:97.15ms -step:1092/1695 train_time:106086ms step_avg:97.15ms -step:1093/1695 train_time:106181ms step_avg:97.15ms -step:1094/1695 train_time:106277ms step_avg:97.15ms -step:1095/1695 train_time:106373ms step_avg:97.14ms -step:1096/1695 train_time:106469ms step_avg:97.14ms -step:1097/1695 train_time:106565ms step_avg:97.14ms -step:1098/1695 train_time:106661ms step_avg:97.14ms -step:1099/1695 train_time:106756ms step_avg:97.14ms -step:1100/1695 train_time:106854ms step_avg:97.14ms -step:1101/1695 train_time:106950ms step_avg:97.14ms -step:1102/1695 train_time:107046ms step_avg:97.14ms -step:1103/1695 train_time:107142ms step_avg:97.14ms -step:1104/1695 train_time:107237ms step_avg:97.13ms -step:1105/1695 train_time:107333ms step_avg:97.13ms -step:1106/1695 train_time:107429ms step_avg:97.13ms -step:1107/1695 train_time:107526ms step_avg:97.13ms -step:1108/1695 train_time:107622ms step_avg:97.13ms -step:1109/1695 train_time:107718ms step_avg:97.13ms -step:1110/1695 train_time:107814ms step_avg:97.13ms -step:1111/1695 train_time:107912ms step_avg:97.13ms -step:1112/1695 train_time:108009ms step_avg:97.13ms -step:1113/1695 train_time:108105ms step_avg:97.13ms -step:1114/1695 train_time:108200ms step_avg:97.13ms -step:1115/1695 train_time:108296ms step_avg:97.13ms -step:1116/1695 train_time:108393ms step_avg:97.13ms -step:1117/1695 train_time:108490ms step_avg:97.13ms -step:1118/1695 train_time:108587ms step_avg:97.13ms -step:1119/1695 train_time:108683ms step_avg:97.13ms -step:1120/1695 train_time:108778ms step_avg:97.12ms -step:1121/1695 train_time:108875ms step_avg:97.12ms -step:1122/1695 train_time:108970ms step_avg:97.12ms -step:1123/1695 train_time:109068ms step_avg:97.12ms -step:1124/1695 train_time:109164ms step_avg:97.12ms 
-step:1125/1695 train_time:109260ms step_avg:97.12ms -step:1125/1695 val_loss:3.4352 train_time:109353ms step_avg:97.20ms -step:1126/1695 train_time:109379ms step_avg:97.14ms -step:1127/1695 train_time:109456ms step_avg:97.12ms -step:1128/1695 train_time:109554ms step_avg:97.12ms -step:1129/1695 train_time:109650ms step_avg:97.12ms -step:1130/1695 train_time:109745ms step_avg:97.12ms -step:1131/1695 train_time:109840ms step_avg:97.12ms -step:1132/1695 train_time:109934ms step_avg:97.12ms -step:1133/1695 train_time:110031ms step_avg:97.11ms -step:1134/1695 train_time:110129ms step_avg:97.12ms -step:1135/1695 train_time:110228ms step_avg:97.12ms -step:1136/1695 train_time:110328ms step_avg:97.12ms -step:1137/1695 train_time:110431ms step_avg:97.12ms -step:1138/1695 train_time:110532ms step_avg:97.13ms -step:1139/1695 train_time:110631ms step_avg:97.13ms -step:1140/1695 train_time:110729ms step_avg:97.13ms -step:1141/1695 train_time:110826ms step_avg:97.13ms -step:1142/1695 train_time:110923ms step_avg:97.13ms -step:1143/1695 train_time:111020ms step_avg:97.13ms -step:1144/1695 train_time:111116ms step_avg:97.13ms -step:1145/1695 train_time:111214ms step_avg:97.13ms -step:1146/1695 train_time:111312ms step_avg:97.13ms -step:1147/1695 train_time:111412ms step_avg:97.13ms -step:1148/1695 train_time:111511ms step_avg:97.14ms -step:1149/1695 train_time:111611ms step_avg:97.14ms -step:1150/1695 train_time:111710ms step_avg:97.14ms -step:1151/1695 train_time:111808ms step_avg:97.14ms -step:1152/1695 train_time:111907ms step_avg:97.14ms -step:1153/1695 train_time:112006ms step_avg:97.14ms -step:1154/1695 train_time:112104ms step_avg:97.14ms -step:1155/1695 train_time:112201ms step_avg:97.14ms -step:1156/1695 train_time:112299ms step_avg:97.14ms -step:1157/1695 train_time:112396ms step_avg:97.14ms -step:1158/1695 train_time:112495ms step_avg:97.15ms -step:1159/1695 train_time:112593ms step_avg:97.15ms -step:1160/1695 train_time:112691ms step_avg:97.15ms -step:1161/1695 train_time:112790ms step_avg:97.15ms -step:1162/1695 train_time:112888ms step_avg:97.15ms -step:1163/1695 train_time:112986ms step_avg:97.15ms -step:1164/1695 train_time:113084ms step_avg:97.15ms -step:1165/1695 train_time:113181ms step_avg:97.15ms -step:1166/1695 train_time:113279ms step_avg:97.15ms -step:1167/1695 train_time:113376ms step_avg:97.15ms -step:1168/1695 train_time:113473ms step_avg:97.15ms -step:1169/1695 train_time:113571ms step_avg:97.15ms -step:1170/1695 train_time:113669ms step_avg:97.15ms -step:1171/1695 train_time:113767ms step_avg:97.15ms -step:1172/1695 train_time:113864ms step_avg:97.15ms -step:1173/1695 train_time:113962ms step_avg:97.15ms -step:1174/1695 train_time:114060ms step_avg:97.15ms -step:1175/1695 train_time:114157ms step_avg:97.16ms -step:1176/1695 train_time:114255ms step_avg:97.16ms -step:1177/1695 train_time:114352ms step_avg:97.16ms -step:1178/1695 train_time:114451ms step_avg:97.16ms -step:1179/1695 train_time:114550ms step_avg:97.16ms -step:1180/1695 train_time:114648ms step_avg:97.16ms -step:1181/1695 train_time:114747ms step_avg:97.16ms -step:1182/1695 train_time:114844ms step_avg:97.16ms -step:1183/1695 train_time:114942ms step_avg:97.16ms -step:1184/1695 train_time:115039ms step_avg:97.16ms -step:1185/1695 train_time:115136ms step_avg:97.16ms -step:1186/1695 train_time:115234ms step_avg:97.16ms -step:1187/1695 train_time:115331ms step_avg:97.16ms -step:1188/1695 train_time:115430ms step_avg:97.16ms -step:1189/1695 train_time:115529ms step_avg:97.16ms -step:1190/1695 train_time:115627ms 
step_avg:97.17ms -step:1191/1695 train_time:115725ms step_avg:97.17ms -step:1192/1695 train_time:115823ms step_avg:97.17ms -step:1193/1695 train_time:115922ms step_avg:97.17ms -step:1194/1695 train_time:116019ms step_avg:97.17ms -step:1195/1695 train_time:116117ms step_avg:97.17ms -step:1196/1695 train_time:116214ms step_avg:97.17ms -step:1197/1695 train_time:116311ms step_avg:97.17ms -step:1198/1695 train_time:116409ms step_avg:97.17ms -step:1199/1695 train_time:116507ms step_avg:97.17ms -step:1200/1695 train_time:116604ms step_avg:97.17ms -step:1201/1695 train_time:116702ms step_avg:97.17ms -step:1202/1695 train_time:116799ms step_avg:97.17ms -step:1203/1695 train_time:116897ms step_avg:97.17ms -step:1204/1695 train_time:116995ms step_avg:97.17ms -step:1205/1695 train_time:117093ms step_avg:97.17ms -step:1206/1695 train_time:117191ms step_avg:97.17ms -step:1207/1695 train_time:117289ms step_avg:97.17ms -step:1208/1695 train_time:117624ms step_avg:97.37ms -step:1209/1695 train_time:117814ms step_avg:97.45ms -step:1210/1695 train_time:117909ms step_avg:97.45ms -step:1211/1695 train_time:118006ms step_avg:97.44ms -step:1212/1695 train_time:118103ms step_avg:97.44ms -step:1213/1695 train_time:118199ms step_avg:97.44ms -step:1214/1695 train_time:118295ms step_avg:97.44ms -step:1215/1695 train_time:118393ms step_avg:97.44ms -step:1216/1695 train_time:118490ms step_avg:97.44ms -step:1217/1695 train_time:118587ms step_avg:97.44ms -step:1218/1695 train_time:118689ms step_avg:97.45ms -step:1219/1695 train_time:118792ms step_avg:97.45ms -step:1220/1695 train_time:118892ms step_avg:97.45ms -step:1221/1695 train_time:118991ms step_avg:97.45ms -step:1222/1695 train_time:119090ms step_avg:97.46ms -step:1223/1695 train_time:119190ms step_avg:97.46ms -step:1224/1695 train_time:119287ms step_avg:97.46ms -step:1225/1695 train_time:119384ms step_avg:97.46ms -step:1226/1695 train_time:119481ms step_avg:97.46ms -step:1227/1695 train_time:119577ms step_avg:97.45ms -step:1228/1695 train_time:119673ms step_avg:97.45ms -step:1229/1695 train_time:119773ms step_avg:97.46ms -step:1230/1695 train_time:119873ms step_avg:97.46ms -step:1231/1695 train_time:119972ms step_avg:97.46ms -step:1232/1695 train_time:120070ms step_avg:97.46ms -step:1233/1695 train_time:120168ms step_avg:97.46ms -step:1234/1695 train_time:120268ms step_avg:97.46ms -step:1235/1695 train_time:120366ms step_avg:97.46ms -step:1236/1695 train_time:120464ms step_avg:97.46ms -step:1237/1695 train_time:120561ms step_avg:97.46ms -step:1238/1695 train_time:120658ms step_avg:97.46ms -step:1239/1695 train_time:120756ms step_avg:97.46ms -step:1240/1695 train_time:120853ms step_avg:97.46ms -step:1241/1695 train_time:120952ms step_avg:97.46ms -step:1242/1695 train_time:121051ms step_avg:97.46ms -step:1243/1695 train_time:121150ms step_avg:97.47ms -step:1244/1695 train_time:121249ms step_avg:97.47ms -step:1245/1695 train_time:121347ms step_avg:97.47ms -step:1246/1695 train_time:121444ms step_avg:97.47ms -step:1247/1695 train_time:121542ms step_avg:97.47ms -step:1248/1695 train_time:121640ms step_avg:97.47ms -step:1249/1695 train_time:121738ms step_avg:97.47ms -step:1250/1695 train_time:121835ms step_avg:97.47ms -step:1250/1695 val_loss:3.3886 train_time:121930ms step_avg:97.54ms -step:1251/1695 train_time:121956ms step_avg:97.49ms -step:1252/1695 train_time:122037ms step_avg:97.47ms -step:1253/1695 train_time:122135ms step_avg:97.47ms -step:1254/1695 train_time:122231ms step_avg:97.47ms -step:1255/1695 train_time:122327ms step_avg:97.47ms -step:1256/1695 
train_time:122424ms step_avg:97.47ms -step:1257/1695 train_time:122520ms step_avg:97.47ms -step:1258/1695 train_time:122616ms step_avg:97.47ms -step:1259/1695 train_time:122713ms step_avg:97.47ms -step:1260/1695 train_time:122809ms step_avg:97.47ms -step:1261/1695 train_time:122913ms step_avg:97.47ms -step:1262/1695 train_time:123013ms step_avg:97.47ms -step:1263/1695 train_time:123111ms step_avg:97.48ms -step:1264/1695 train_time:123209ms step_avg:97.48ms -step:1265/1695 train_time:123306ms step_avg:97.48ms -step:1266/1695 train_time:123403ms step_avg:97.47ms -step:1267/1695 train_time:123499ms step_avg:97.47ms -step:1268/1695 train_time:123596ms step_avg:97.47ms -step:1269/1695 train_time:123693ms step_avg:97.47ms -step:1270/1695 train_time:123789ms step_avg:97.47ms -step:1271/1695 train_time:123889ms step_avg:97.47ms -step:1272/1695 train_time:123988ms step_avg:97.47ms -step:1273/1695 train_time:124086ms step_avg:97.48ms -step:1274/1695 train_time:124185ms step_avg:97.48ms -step:1275/1695 train_time:124284ms step_avg:97.48ms -step:1276/1695 train_time:124381ms step_avg:97.48ms -step:1277/1695 train_time:124479ms step_avg:97.48ms -step:1278/1695 train_time:124576ms step_avg:97.48ms -step:1279/1695 train_time:124673ms step_avg:97.48ms -step:1280/1695 train_time:124770ms step_avg:97.48ms -step:1281/1695 train_time:124868ms step_avg:97.48ms -step:1282/1695 train_time:124966ms step_avg:97.48ms -step:1283/1695 train_time:125065ms step_avg:97.48ms -step:1284/1695 train_time:125165ms step_avg:97.48ms -step:1285/1695 train_time:125264ms step_avg:97.48ms -step:1286/1695 train_time:125363ms step_avg:97.48ms -step:1287/1695 train_time:125460ms step_avg:97.48ms -step:1288/1695 train_time:125558ms step_avg:97.48ms -step:1289/1695 train_time:125656ms step_avg:97.48ms -step:1290/1695 train_time:125755ms step_avg:97.48ms -step:1291/1695 train_time:125853ms step_avg:97.48ms -step:1292/1695 train_time:125950ms step_avg:97.48ms -step:1293/1695 train_time:126048ms step_avg:97.49ms -step:1294/1695 train_time:126147ms step_avg:97.49ms -step:1295/1695 train_time:126245ms step_avg:97.49ms -step:1296/1695 train_time:126344ms step_avg:97.49ms -step:1297/1695 train_time:126440ms step_avg:97.49ms -step:1298/1695 train_time:126538ms step_avg:97.49ms -step:1299/1695 train_time:126635ms step_avg:97.49ms -step:1300/1695 train_time:126731ms step_avg:97.49ms -step:1301/1695 train_time:126829ms step_avg:97.49ms -step:1302/1695 train_time:126927ms step_avg:97.49ms -step:1303/1695 train_time:127025ms step_avg:97.49ms -step:1304/1695 train_time:127124ms step_avg:97.49ms -step:1305/1695 train_time:127222ms step_avg:97.49ms -step:1306/1695 train_time:127320ms step_avg:97.49ms -step:1307/1695 train_time:127418ms step_avg:97.49ms -step:1308/1695 train_time:127515ms step_avg:97.49ms -step:1309/1695 train_time:127612ms step_avg:97.49ms -step:1310/1695 train_time:127710ms step_avg:97.49ms -step:1311/1695 train_time:127807ms step_avg:97.49ms -step:1312/1695 train_time:127905ms step_avg:97.49ms -step:1313/1695 train_time:128004ms step_avg:97.49ms -step:1314/1695 train_time:128104ms step_avg:97.49ms -step:1315/1695 train_time:128203ms step_avg:97.49ms -step:1316/1695 train_time:128301ms step_avg:97.49ms -step:1317/1695 train_time:128399ms step_avg:97.49ms -step:1318/1695 train_time:128498ms step_avg:97.49ms -step:1319/1695 train_time:128596ms step_avg:97.50ms -step:1320/1695 train_time:128695ms step_avg:97.50ms -step:1321/1695 train_time:128792ms step_avg:97.50ms -step:1322/1695 train_time:128889ms step_avg:97.50ms -step:1323/1695 
train_time:128986ms step_avg:97.50ms -step:1324/1695 train_time:129085ms step_avg:97.50ms -step:1325/1695 train_time:129184ms step_avg:97.50ms -step:1326/1695 train_time:129282ms step_avg:97.50ms -step:1327/1695 train_time:129379ms step_avg:97.50ms -step:1328/1695 train_time:129477ms step_avg:97.50ms -step:1329/1695 train_time:129574ms step_avg:97.50ms -step:1330/1695 train_time:129672ms step_avg:97.50ms -step:1331/1695 train_time:129769ms step_avg:97.50ms -step:1332/1695 train_time:129866ms step_avg:97.50ms -step:1333/1695 train_time:129964ms step_avg:97.50ms -step:1334/1695 train_time:130063ms step_avg:97.50ms -step:1335/1695 train_time:130162ms step_avg:97.50ms -step:1336/1695 train_time:130260ms step_avg:97.50ms -step:1337/1695 train_time:130358ms step_avg:97.50ms -step:1338/1695 train_time:130455ms step_avg:97.50ms -step:1339/1695 train_time:130553ms step_avg:97.50ms -step:1340/1695 train_time:130652ms step_avg:97.50ms -step:1341/1695 train_time:130749ms step_avg:97.50ms -step:1342/1695 train_time:130846ms step_avg:97.50ms -step:1343/1695 train_time:130944ms step_avg:97.50ms -step:1344/1695 train_time:131043ms step_avg:97.50ms -step:1345/1695 train_time:131142ms step_avg:97.50ms -step:1346/1695 train_time:131241ms step_avg:97.50ms -step:1347/1695 train_time:131340ms step_avg:97.51ms -step:1348/1695 train_time:131438ms step_avg:97.51ms -step:1349/1695 train_time:131537ms step_avg:97.51ms -step:1350/1695 train_time:131636ms step_avg:97.51ms -step:1351/1695 train_time:131734ms step_avg:97.51ms -step:1352/1695 train_time:131832ms step_avg:97.51ms -step:1353/1695 train_time:131930ms step_avg:97.51ms -step:1354/1695 train_time:132028ms step_avg:97.51ms -step:1355/1695 train_time:132126ms step_avg:97.51ms -step:1356/1695 train_time:132223ms step_avg:97.51ms -step:1357/1695 train_time:132321ms step_avg:97.51ms -step:1358/1695 train_time:132419ms step_avg:97.51ms -step:1359/1695 train_time:132517ms step_avg:97.51ms -step:1360/1695 train_time:132614ms step_avg:97.51ms -step:1361/1695 train_time:132711ms step_avg:97.51ms -step:1362/1695 train_time:132808ms step_avg:97.51ms -step:1363/1695 train_time:132905ms step_avg:97.51ms -step:1364/1695 train_time:133004ms step_avg:97.51ms -step:1365/1695 train_time:133102ms step_avg:97.51ms -step:1366/1695 train_time:133200ms step_avg:97.51ms -step:1367/1695 train_time:133297ms step_avg:97.51ms -step:1368/1695 train_time:133393ms step_avg:97.51ms -step:1369/1695 train_time:133491ms step_avg:97.51ms -step:1370/1695 train_time:133589ms step_avg:97.51ms -step:1371/1695 train_time:133687ms step_avg:97.51ms -step:1372/1695 train_time:133785ms step_avg:97.51ms -step:1373/1695 train_time:133884ms step_avg:97.51ms -step:1374/1695 train_time:133982ms step_avg:97.51ms -step:1375/1695 train_time:134080ms step_avg:97.51ms -step:1375/1695 val_loss:3.3495 train_time:134174ms step_avg:97.58ms -step:1376/1695 train_time:134203ms step_avg:97.53ms -step:1377/1695 train_time:134283ms step_avg:97.52ms -step:1378/1695 train_time:134384ms step_avg:97.52ms -step:1379/1695 train_time:134483ms step_avg:97.52ms -step:1380/1695 train_time:134581ms step_avg:97.52ms -step:1381/1695 train_time:134941ms step_avg:97.71ms -step:1382/1695 train_time:135109ms step_avg:97.76ms -step:1383/1695 train_time:135205ms step_avg:97.76ms -step:1384/1695 train_time:135302ms step_avg:97.76ms -step:1385/1695 train_time:135398ms step_avg:97.76ms -step:1386/1695 train_time:135494ms step_avg:97.76ms -step:1387/1695 train_time:135591ms step_avg:97.76ms -step:1388/1695 train_time:135686ms step_avg:97.76ms 
-step:1389/1695 train_time:135783ms step_avg:97.76ms -step:1390/1695 train_time:135880ms step_avg:97.76ms -step:1391/1695 train_time:135981ms step_avg:97.76ms -step:1392/1695 train_time:136087ms step_avg:97.76ms -step:1393/1695 train_time:136186ms step_avg:97.76ms -step:1394/1695 train_time:136284ms step_avg:97.76ms -step:1395/1695 train_time:136381ms step_avg:97.76ms -step:1396/1695 train_time:136480ms step_avg:97.76ms -step:1397/1695 train_time:136577ms step_avg:97.76ms -step:1398/1695 train_time:136674ms step_avg:97.76ms -step:1399/1695 train_time:136770ms step_avg:97.76ms -step:1400/1695 train_time:136866ms step_avg:97.76ms -step:1401/1695 train_time:136964ms step_avg:97.76ms -step:1402/1695 train_time:137065ms step_avg:97.76ms -step:1403/1695 train_time:137165ms step_avg:97.77ms -step:1404/1695 train_time:137264ms step_avg:97.77ms -step:1405/1695 train_time:137362ms step_avg:97.77ms -step:1406/1695 train_time:137460ms step_avg:97.77ms -step:1407/1695 train_time:137558ms step_avg:97.77ms -step:1408/1695 train_time:137656ms step_avg:97.77ms -step:1409/1695 train_time:137752ms step_avg:97.77ms -step:1410/1695 train_time:137849ms step_avg:97.77ms -step:1411/1695 train_time:137946ms step_avg:97.76ms -step:1412/1695 train_time:138046ms step_avg:97.77ms -step:1413/1695 train_time:138144ms step_avg:97.77ms -step:1414/1695 train_time:138244ms step_avg:97.77ms -step:1415/1695 train_time:138342ms step_avg:97.77ms -step:1416/1695 train_time:138440ms step_avg:97.77ms -step:1417/1695 train_time:138538ms step_avg:97.77ms -step:1418/1695 train_time:138637ms step_avg:97.77ms -step:1419/1695 train_time:138735ms step_avg:97.77ms -step:1420/1695 train_time:138832ms step_avg:97.77ms -step:1421/1695 train_time:138928ms step_avg:97.77ms -step:1422/1695 train_time:139025ms step_avg:97.77ms -step:1423/1695 train_time:139123ms step_avg:97.77ms -step:1424/1695 train_time:139221ms step_avg:97.77ms -step:1425/1695 train_time:139320ms step_avg:97.77ms -step:1426/1695 train_time:139418ms step_avg:97.77ms -step:1427/1695 train_time:139515ms step_avg:97.77ms -step:1428/1695 train_time:139612ms step_avg:97.77ms -step:1429/1695 train_time:139709ms step_avg:97.77ms -step:1430/1695 train_time:139807ms step_avg:97.77ms -step:1431/1695 train_time:139905ms step_avg:97.77ms -step:1432/1695 train_time:140004ms step_avg:97.77ms -step:1433/1695 train_time:140104ms step_avg:97.77ms -step:1434/1695 train_time:140202ms step_avg:97.77ms -step:1435/1695 train_time:140300ms step_avg:97.77ms -step:1436/1695 train_time:140397ms step_avg:97.77ms -step:1437/1695 train_time:140496ms step_avg:97.77ms -step:1438/1695 train_time:140594ms step_avg:97.77ms -step:1439/1695 train_time:140691ms step_avg:97.77ms -step:1440/1695 train_time:140788ms step_avg:97.77ms -step:1441/1695 train_time:140885ms step_avg:97.77ms -step:1442/1695 train_time:140982ms step_avg:97.77ms -step:1443/1695 train_time:141080ms step_avg:97.77ms -step:1444/1695 train_time:141178ms step_avg:97.77ms -step:1445/1695 train_time:141275ms step_avg:97.77ms -step:1446/1695 train_time:141373ms step_avg:97.77ms -step:1447/1695 train_time:141471ms step_avg:97.77ms -step:1448/1695 train_time:141569ms step_avg:97.77ms -step:1449/1695 train_time:141667ms step_avg:97.77ms -step:1450/1695 train_time:141765ms step_avg:97.77ms -step:1451/1695 train_time:141864ms step_avg:97.77ms -step:1452/1695 train_time:141962ms step_avg:97.77ms -step:1453/1695 train_time:142060ms step_avg:97.77ms -step:1454/1695 train_time:142159ms step_avg:97.77ms -step:1455/1695 train_time:142257ms step_avg:97.77ms 
-step:1456/1695 train_time:142355ms step_avg:97.77ms -step:1457/1695 train_time:142452ms step_avg:97.77ms -step:1458/1695 train_time:142550ms step_avg:97.77ms -step:1459/1695 train_time:142648ms step_avg:97.77ms -step:1460/1695 train_time:142746ms step_avg:97.77ms -step:1461/1695 train_time:142844ms step_avg:97.77ms -step:1462/1695 train_time:142943ms step_avg:97.77ms -step:1463/1695 train_time:143041ms step_avg:97.77ms -step:1464/1695 train_time:143141ms step_avg:97.77ms -step:1465/1695 train_time:143241ms step_avg:97.78ms -step:1466/1695 train_time:143339ms step_avg:97.78ms -step:1467/1695 train_time:143438ms step_avg:97.78ms -step:1468/1695 train_time:143537ms step_avg:97.78ms -step:1469/1695 train_time:143635ms step_avg:97.78ms -step:1470/1695 train_time:143732ms step_avg:97.78ms -step:1471/1695 train_time:143829ms step_avg:97.78ms -step:1472/1695 train_time:143927ms step_avg:97.78ms -step:1473/1695 train_time:144025ms step_avg:97.78ms -step:1474/1695 train_time:144123ms step_avg:97.78ms -step:1475/1695 train_time:144221ms step_avg:97.78ms -step:1476/1695 train_time:144319ms step_avg:97.78ms -step:1477/1695 train_time:144418ms step_avg:97.78ms -step:1478/1695 train_time:144515ms step_avg:97.78ms -step:1479/1695 train_time:144613ms step_avg:97.78ms -step:1480/1695 train_time:144711ms step_avg:97.78ms -step:1481/1695 train_time:144809ms step_avg:97.78ms -step:1482/1695 train_time:144906ms step_avg:97.78ms -step:1483/1695 train_time:145004ms step_avg:97.78ms -step:1484/1695 train_time:145101ms step_avg:97.78ms -step:1485/1695 train_time:145200ms step_avg:97.78ms -step:1486/1695 train_time:145298ms step_avg:97.78ms -step:1487/1695 train_time:145397ms step_avg:97.78ms -step:1488/1695 train_time:145495ms step_avg:97.78ms -step:1489/1695 train_time:145593ms step_avg:97.78ms -step:1490/1695 train_time:145690ms step_avg:97.78ms -step:1491/1695 train_time:145787ms step_avg:97.78ms -step:1492/1695 train_time:145884ms step_avg:97.78ms -step:1493/1695 train_time:145982ms step_avg:97.78ms -step:1494/1695 train_time:146079ms step_avg:97.78ms -step:1495/1695 train_time:146177ms step_avg:97.78ms -step:1496/1695 train_time:146274ms step_avg:97.78ms -step:1497/1695 train_time:146372ms step_avg:97.78ms -step:1498/1695 train_time:146469ms step_avg:97.78ms -step:1499/1695 train_time:146567ms step_avg:97.78ms -step:1500/1695 train_time:146665ms step_avg:97.78ms -step:1500/1695 val_loss:3.3162 train_time:146761ms step_avg:97.84ms -step:1501/1695 train_time:146787ms step_avg:97.79ms -step:1502/1695 train_time:146870ms step_avg:97.78ms -step:1503/1695 train_time:146968ms step_avg:97.78ms -step:1504/1695 train_time:147065ms step_avg:97.78ms -step:1505/1695 train_time:147162ms step_avg:97.78ms -step:1506/1695 train_time:147259ms step_avg:97.78ms -step:1507/1695 train_time:147355ms step_avg:97.78ms -step:1508/1695 train_time:147452ms step_avg:97.78ms -step:1509/1695 train_time:147548ms step_avg:97.78ms -step:1510/1695 train_time:147645ms step_avg:97.78ms -step:1511/1695 train_time:147745ms step_avg:97.78ms -step:1512/1695 train_time:147846ms step_avg:97.78ms -step:1513/1695 train_time:147945ms step_avg:97.78ms -step:1514/1695 train_time:148043ms step_avg:97.78ms -step:1515/1695 train_time:148141ms step_avg:97.78ms -step:1516/1695 train_time:148240ms step_avg:97.78ms -step:1517/1695 train_time:148337ms step_avg:97.78ms -step:1518/1695 train_time:148433ms step_avg:97.78ms -step:1519/1695 train_time:148529ms step_avg:97.78ms -step:1520/1695 train_time:148627ms step_avg:97.78ms -step:1521/1695 train_time:148726ms 
step_avg:97.78ms -step:1522/1695 train_time:148825ms step_avg:97.78ms -step:1523/1695 train_time:148924ms step_avg:97.78ms -step:1524/1695 train_time:149022ms step_avg:97.78ms -step:1525/1695 train_time:149121ms step_avg:97.78ms -step:1526/1695 train_time:149219ms step_avg:97.78ms -step:1527/1695 train_time:149317ms step_avg:97.78ms -step:1528/1695 train_time:149413ms step_avg:97.78ms -step:1529/1695 train_time:149511ms step_avg:97.78ms -step:1530/1695 train_time:149608ms step_avg:97.78ms -step:1531/1695 train_time:149705ms step_avg:97.78ms -step:1532/1695 train_time:149804ms step_avg:97.78ms -step:1533/1695 train_time:149903ms step_avg:97.78ms -step:1534/1695 train_time:150001ms step_avg:97.78ms -step:1535/1695 train_time:150100ms step_avg:97.78ms -step:1536/1695 train_time:150198ms step_avg:97.79ms -step:1537/1695 train_time:150296ms step_avg:97.79ms -step:1538/1695 train_time:150393ms step_avg:97.78ms -step:1539/1695 train_time:150491ms step_avg:97.78ms -step:1540/1695 train_time:150588ms step_avg:97.78ms -step:1541/1695 train_time:150685ms step_avg:97.78ms -step:1542/1695 train_time:150783ms step_avg:97.78ms -step:1543/1695 train_time:150881ms step_avg:97.78ms -step:1544/1695 train_time:150981ms step_avg:97.79ms -step:1545/1695 train_time:151081ms step_avg:97.79ms -step:1546/1695 train_time:151181ms step_avg:97.79ms -step:1547/1695 train_time:151279ms step_avg:97.79ms -step:1548/1695 train_time:151378ms step_avg:97.79ms -step:1549/1695 train_time:151478ms step_avg:97.79ms -step:1550/1695 train_time:151576ms step_avg:97.79ms -step:1551/1695 train_time:151674ms step_avg:97.79ms -step:1552/1695 train_time:152071ms step_avg:97.98ms -step:1553/1695 train_time:152147ms step_avg:97.97ms -step:1554/1695 train_time:152242ms step_avg:97.97ms -step:1555/1695 train_time:152339ms step_avg:97.97ms -step:1556/1695 train_time:152435ms step_avg:97.97ms -step:1557/1695 train_time:152532ms step_avg:97.97ms -step:1558/1695 train_time:152628ms step_avg:97.96ms -step:1559/1695 train_time:152724ms step_avg:97.96ms -step:1560/1695 train_time:152821ms step_avg:97.96ms -step:1561/1695 train_time:152919ms step_avg:97.96ms -step:1562/1695 train_time:153018ms step_avg:97.96ms -step:1563/1695 train_time:153125ms step_avg:97.97ms -step:1564/1695 train_time:153224ms step_avg:97.97ms -step:1565/1695 train_time:153322ms step_avg:97.97ms -step:1566/1695 train_time:153420ms step_avg:97.97ms -step:1567/1695 train_time:153517ms step_avg:97.97ms -step:1568/1695 train_time:153616ms step_avg:97.97ms -step:1569/1695 train_time:153713ms step_avg:97.97ms -step:1570/1695 train_time:153810ms step_avg:97.97ms -step:1571/1695 train_time:153906ms step_avg:97.97ms -step:1572/1695 train_time:154004ms step_avg:97.97ms -step:1573/1695 train_time:154103ms step_avg:97.97ms -step:1574/1695 train_time:154203ms step_avg:97.97ms -step:1575/1695 train_time:154302ms step_avg:97.97ms -step:1576/1695 train_time:154400ms step_avg:97.97ms -step:1577/1695 train_time:154498ms step_avg:97.97ms -step:1578/1695 train_time:154596ms step_avg:97.97ms -step:1579/1695 train_time:154693ms step_avg:97.97ms -step:1580/1695 train_time:154791ms step_avg:97.97ms -step:1581/1695 train_time:154888ms step_avg:97.97ms -step:1582/1695 train_time:154985ms step_avg:97.97ms -step:1583/1695 train_time:155083ms step_avg:97.97ms -step:1584/1695 train_time:155182ms step_avg:97.97ms -step:1585/1695 train_time:155280ms step_avg:97.97ms -step:1586/1695 train_time:155380ms step_avg:97.97ms -step:1587/1695 train_time:155478ms step_avg:97.97ms -step:1588/1695 train_time:155575ms 
step_avg:97.97ms -step:1589/1695 train_time:155673ms step_avg:97.97ms -step:1590/1695 train_time:155771ms step_avg:97.97ms -step:1591/1695 train_time:155868ms step_avg:97.97ms -step:1592/1695 train_time:155965ms step_avg:97.97ms -step:1593/1695 train_time:156063ms step_avg:97.97ms -step:1594/1695 train_time:156160ms step_avg:97.97ms -step:1595/1695 train_time:156258ms step_avg:97.97ms -step:1596/1695 train_time:156357ms step_avg:97.97ms -step:1597/1695 train_time:156456ms step_avg:97.97ms -step:1598/1695 train_time:156554ms step_avg:97.97ms -step:1599/1695 train_time:156651ms step_avg:97.97ms -step:1600/1695 train_time:156748ms step_avg:97.97ms -step:1601/1695 train_time:156845ms step_avg:97.97ms -step:1602/1695 train_time:156943ms step_avg:97.97ms -step:1603/1695 train_time:157041ms step_avg:97.97ms -step:1604/1695 train_time:157140ms step_avg:97.97ms -step:1605/1695 train_time:157239ms step_avg:97.97ms -step:1606/1695 train_time:157337ms step_avg:97.97ms -step:1607/1695 train_time:157436ms step_avg:97.97ms -step:1608/1695 train_time:157535ms step_avg:97.97ms -step:1609/1695 train_time:157632ms step_avg:97.97ms -step:1610/1695 train_time:157730ms step_avg:97.97ms -step:1611/1695 train_time:157827ms step_avg:97.97ms -step:1612/1695 train_time:157925ms step_avg:97.97ms -step:1613/1695 train_time:158022ms step_avg:97.97ms -step:1614/1695 train_time:158119ms step_avg:97.97ms -step:1615/1695 train_time:158218ms step_avg:97.97ms -step:1616/1695 train_time:158317ms step_avg:97.97ms -step:1617/1695 train_time:158415ms step_avg:97.97ms -step:1618/1695 train_time:158513ms step_avg:97.97ms -step:1619/1695 train_time:158611ms step_avg:97.97ms -step:1620/1695 train_time:158708ms step_avg:97.97ms -step:1621/1695 train_time:158805ms step_avg:97.97ms -step:1622/1695 train_time:158903ms step_avg:97.97ms -step:1623/1695 train_time:159001ms step_avg:97.97ms -step:1624/1695 train_time:159099ms step_avg:97.97ms -step:1625/1695 train_time:159197ms step_avg:97.97ms -step:1625/1695 val_loss:3.2895 train_time:159292ms step_avg:98.03ms -step:1626/1695 train_time:159319ms step_avg:97.98ms -step:1627/1695 train_time:159403ms step_avg:97.97ms -step:1628/1695 train_time:159501ms step_avg:97.97ms -step:1629/1695 train_time:159598ms step_avg:97.97ms -step:1630/1695 train_time:159696ms step_avg:97.97ms -step:1631/1695 train_time:159793ms step_avg:97.97ms -step:1632/1695 train_time:159890ms step_avg:97.97ms -step:1633/1695 train_time:159986ms step_avg:97.97ms -step:1634/1695 train_time:160083ms step_avg:97.97ms -step:1635/1695 train_time:160179ms step_avg:97.97ms -step:1636/1695 train_time:160280ms step_avg:97.97ms -step:1637/1695 train_time:160382ms step_avg:97.97ms -step:1638/1695 train_time:160482ms step_avg:97.97ms -step:1639/1695 train_time:160580ms step_avg:97.97ms -step:1640/1695 train_time:160678ms step_avg:97.97ms -step:1641/1695 train_time:160774ms step_avg:97.97ms -step:1642/1695 train_time:160871ms step_avg:97.97ms -step:1643/1695 train_time:160969ms step_avg:97.97ms -step:1644/1695 train_time:161066ms step_avg:97.97ms -step:1645/1695 train_time:161162ms step_avg:97.97ms -step:1646/1695 train_time:161261ms step_avg:97.97ms -step:1647/1695 train_time:161362ms step_avg:97.97ms -step:1648/1695 train_time:161461ms step_avg:97.97ms -step:1649/1695 train_time:161559ms step_avg:97.97ms -step:1650/1695 train_time:161657ms step_avg:97.97ms -step:1651/1695 train_time:161755ms step_avg:97.97ms -step:1652/1695 train_time:161852ms step_avg:97.97ms -step:1653/1695 train_time:161951ms step_avg:97.97ms -step:1654/1695 
train_time:162049ms step_avg:97.97ms -step:1655/1695 train_time:162146ms step_avg:97.97ms -step:1656/1695 train_time:162244ms step_avg:97.97ms -step:1657/1695 train_time:162342ms step_avg:97.97ms -step:1658/1695 train_time:162440ms step_avg:97.97ms -step:1659/1695 train_time:162538ms step_avg:97.97ms -step:1660/1695 train_time:162636ms step_avg:97.97ms -step:1661/1695 train_time:162734ms step_avg:97.97ms -step:1662/1695 train_time:162831ms step_avg:97.97ms -step:1663/1695 train_time:162928ms step_avg:97.97ms -step:1664/1695 train_time:163026ms step_avg:97.97ms -step:1665/1695 train_time:163122ms step_avg:97.97ms -step:1666/1695 train_time:163221ms step_avg:97.97ms -step:1667/1695 train_time:163320ms step_avg:97.97ms -step:1668/1695 train_time:163418ms step_avg:97.97ms -step:1669/1695 train_time:163518ms step_avg:97.97ms -step:1670/1695 train_time:163617ms step_avg:97.97ms -step:1671/1695 train_time:163715ms step_avg:97.97ms -step:1672/1695 train_time:163812ms step_avg:97.97ms -step:1673/1695 train_time:163911ms step_avg:97.97ms -step:1674/1695 train_time:164008ms step_avg:97.97ms -step:1675/1695 train_time:164105ms step_avg:97.97ms -step:1676/1695 train_time:164202ms step_avg:97.97ms -step:1677/1695 train_time:164300ms step_avg:97.97ms -step:1678/1695 train_time:164397ms step_avg:97.97ms -step:1679/1695 train_time:164495ms step_avg:97.97ms -step:1680/1695 train_time:164593ms step_avg:97.97ms -step:1681/1695 train_time:164691ms step_avg:97.97ms -step:1682/1695 train_time:164789ms step_avg:97.97ms -step:1683/1695 train_time:164886ms step_avg:97.97ms -step:1684/1695 train_time:164984ms step_avg:97.97ms -step:1685/1695 train_time:165081ms step_avg:97.97ms -step:1686/1695 train_time:165179ms step_avg:97.97ms -step:1687/1695 train_time:165278ms step_avg:97.97ms -step:1688/1695 train_time:165377ms step_avg:97.97ms -step:1689/1695 train_time:165475ms step_avg:97.97ms -step:1690/1695 train_time:165573ms step_avg:97.97ms -step:1691/1695 train_time:165672ms step_avg:97.97ms -step:1692/1695 train_time:165770ms step_avg:97.97ms -step:1693/1695 train_time:165868ms step_avg:97.97ms -step:1694/1695 train_time:165966ms step_avg:97.97ms -step:1695/1695 train_time:166062ms step_avg:97.97ms -step:1695/1695 val_loss:3.2782 train_time:166157ms step_avg:98.03ms -peak memory allocated: 34361 MiB reserved: 49576 MiB diff --git a/records/082725_FA3/1d46fee6-b32c-48de-bd61-0a326442ec4e.txt b/records/082725_FA3/1d46fee6-b32c-48de-bd61-0a326442ec4e.txt deleted file mode 100644 index 32ec95b7e..000000000 --- a/records/082725_FA3/1d46fee6-b32c-48de-bd61-0a326442ec4e.txt +++ /dev/null @@ -1,2808 +0,0 @@ -import os -import sys -with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time -import copy -import glob -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F -import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np -import triton -import triton.language as tl -from flash_attn_interface import flash_attn_func -import torch._dynamo as dynamo -dynamo.config.recompile_limit = 64 - -# 
----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - 
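-    # [Editor's worked example; the numbers are illustrative, not from a run]
-    # For one 768x768 output with BLOCK_SIZE_M = BLOCK_SIZE_N = 128 we get
-    # num_pid_m = num_pid_n = 6, so each matrix in the batch owns 36 programs.
-    # pid 40 then maps to batch_idx = 40 // 36 = 1 with local pid 4, and
-    # tl.swizzle2d regroups the 2D block grid so programs within the same
-    # GROUP_SIZE_M block-rows are scheduled near each other for L2 reuse.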
batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - 
configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. 
- rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) - grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size - for base_i in range(0, len(params), world_size): - if base_i + rank < len(params): - grad = params[base_i + rank].grad - # This gives strange dynamo warnings - reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) - - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -class DistAdam(torch.optim.Optimizer): - def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # DistributedAdam implementation by @vagrawal - - @torch.compile - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - grad_slices = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - for base_i in range(len(params)): - grad = params[base_i].grad - rank_size = grad.shape[0] // world_size - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for base in range(len(params)): - reduce_scatter_futures[idx].wait() - p = params[base] - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - g_slice = grad_slices[idx] - # State init - if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - 
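-                    # [Editor's gloss] ZeRO-style sharding: both Adam moment
-                    # buffers (this one and exp_avg_sq below) are sized to this
-                    # rank's row slice of p, so optimizer state costs each rank
-                    # 2/world_size of the parameter memory instead of 2x on
-                    # every rank; the full p is reassembled afterwards with
-                    # all_gather_into_tensor.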
state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# ----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - 
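-        # [Editor's summary of the shard layout implied by these reads] Each
-        # .bin file is a 256-int32 (1024-byte) header carrying the magic number
-        # 20240520, a version field, and the token count, followed directly by
-        # num_tokens uint16 token ids; the seek below jumps past the header and
-        # readinto() fills the pinned buffer without an intermediate bytes copy.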
f.seek(256 * 4) - nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. 
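-                # [Editor's worked example of the EOS alignment enforced by
-                # EOSBatchFinder above; values are illustrative] With
-                # eos_idx = [9, 25, 60] and seq_len = 16, the first sequence
-                # starts at 10 (just past an EOS); it needs tokens through
-                # position 10 + 16 = 26, and the first EOS at or beyond that is
-                # 60, so the next sequence starts at 61. Every sequence thus
-                # begins at a document start and no two sequences overlap.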
- tokens, pos = _load_data_shard(next(file_iter)), 0 - finder = EOSBatchFinder(tokens, world_size=world_size) - continue - - bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] - buf = torch.stack(bufs, dim=0) - _inputs = buf[:, :-1] - _targets = buf[:, 1:] - else: - batch_span = num_tokens_global - start_pos_local = pos + rank * (batch_size_local * seq_len) - end_pos_local = start_pos_local + (batch_size_local * seq_len) - - buf = tokens[start_pos_local: end_pos_local + 1] - - _inputs = buf[:-1].view(batch_size_local, seq_len) - _targets = buf[1:].view(batch_size_local, seq_len) - - new_params = yield ( - _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), - _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) - ) - - pos += batch_span - - if new_params is not None: - # makes it possible for generator to receive new (batch_size, seq_len) via .send() - new_batch_size, new_seq_len = new_params - assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" - batch_size = new_batch_size - seq_len = new_seq_len - - -# ----------------------------------------------------------------------------- -# int main - -@dataclass -class Hyperparameters: - # data - train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on - val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on - val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons - train_seq_len: int = 1024 * 2 - train_batch_size: int = 24 * 8 - val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. - # optimization - num_iterations: int = 1695 # number of iterations to run - cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate - # evaluation and logging - run_id: str = str(uuid.uuid4()) - val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end - save_checkpoint: bool = False - # attention masking - bandwidth: int = 128 - ws_schedule: tuple = (3, 7, 11) - -args = Hyperparameters() - -data_path = os.environ.get("DATA_PATH", ".") -args.train_files = os.path.join(data_path, args.train_files) -args.val_files = os.path.join(data_path, args.val_files) - -# torchrun sets these env variables -rank = int(os.environ["RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) -assert 8 % world_size == 0, "world_size must be a divisor of 8" -grad_accum_steps = 8 // world_size -assert torch.cuda.is_available() -device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) -torch.cuda.set_device(device) -dist.init_process_group(backend="nccl", device_id=device) -dist.barrier() -master_process = (rank == 0) # this process will do logging, checkpointing etc. 
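-
-# [Editor's sanity check, not part of the original record] Token budget implied
-# by the hyperparameters above: every optimizer step consumes grad_accum_steps
-# micro-batches of train_batch_size sequences x train_seq_len tokens, so on the
-# canonical 8-GPU run (grad_accum_steps == 1) that is 192 * 2048 = 393,216
-# tokens per step, or roughly 0.67B tokens over the 1695 iterations.
-tokens_per_step = grad_accum_steps * args.train_batch_size * args.train_seq_len
-total_train_tokens = tokens_per_step * args.num_iterations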
- -# begin logging -logfile = None -if master_process: - run_id = args.run_id - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{run_id}.txt" - print(logfile) -def print0(s, console=False): - if master_process: - with open(logfile, "a") as f: - if console: - print(s) - print(s, file=f) - -# begin by printing this file (the Python code) -print0(code) -print0("="*100) -# log information about the hardware/software environment this is running on -print0(f"Running Python {sys.version}") -print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") -print0(f"Running Triton version {triton.__version__}") - -def nvidia_smi(): - import subprocess # avoid top level import - return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout -print0(nvidia_smi()) -print0("="*100) - -model: nn.Module = GPT( - vocab_size=50257, - num_layers=12, - num_heads=6, - model_dim=768, - max_seq_len=max(args.train_seq_len, args.val_seq_len) -).cuda() -for m in model.modules(): - if isinstance(m, nn.Embedding): - m.bfloat16() -for param in model.parameters(): - dist.broadcast(param.detach(), 0) - -# collect the parameters to optimize -hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] -embed_params = [p for n, p in model.named_parameters() if "embed" in n] -scalar_params = [p for p in model.parameters() if p.ndim < 2] -head_params = [model.lm_head.weight] - -# init the optimizer(s) -# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence -# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 -optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) -optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) -optimizers = [optimizer1, optimizer2] -for opt in optimizers: - for group in opt.param_groups: - group["initial_lr"] = group["lr"] - -# learning rate schedule: stable then decay -def get_lr_and_ws(step: int): - x = step / (1 + args.num_iterations) # progress in training - assert 0 <= x < 1 - lr = 1.0 - if x >= 1 - args.cooldown_frac: - w = (1 - x) / args.cooldown_frac - lr = w * 1.0 + (1 - w) * 0.1 - ws_idx = int(len(args.ws_schedule) * x) - return lr, args.ws_schedule[ws_idx] - -model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) - -######################################## -# Warmup kernels # -######################################## - -# Warmup the training kernels, then re-initialize the state so we aren't cheating -warmup_steps = 60 -initial_state = dict(model=copy.deepcopy(model.state_dict()), - optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state -train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -for step in range(warmup_steps): - inputs, targets = next(train_loader) - ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each - model(inputs, targets, ws, ws // 2).backward() - for opt in optimizers: - opt.step() - model.zero_grad(set_to_none=True) -model.load_state_dict(initial_state["model"]) -for opt, opt_state in zip(optimizers, initial_state["optimizers"]): - opt.load_state_dict(opt_state) -del train_loader, initial_state - -######################################## -# Training and validation # -######################################## - 
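-
-# [Editor's gloss, derived from get_lr_and_ws above] The lr multiplier holds at
-# 1.0 for the first 55% of training, then decays linearly to 0.1 over the final
-# cooldown_frac = 0.45. The window index is int(3 * x), so training walks
-# ws_schedule = (3, 7, 11) in thirds; with bandwidth = 128 the long/short
-# attention windows passed into the loop below are:
-#   x < 1/3: ws = 3  -> long 384,  short 128 tokens
-#   x < 2/3: ws = 7  -> long 896,  short 384 tokens
-#   x < 1:   ws = 11 -> long 1408, short 640 tokens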
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -training_time_ms = 0 -# start the clock -torch.cuda.synchronize() -t0 = time.perf_counter() -# begin training -train_steps = args.num_iterations -for step in range(train_steps + 1): - last_step = (step == train_steps) - lr, ws = get_lr_and_ws(step) - - # --------------- VALIDATION SECTION ----------------- - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - # stop the clock - torch.cuda.synchronize() - training_time_ms += 1000 * (time.perf_counter() - t0) - model.eval() - assert args.val_tokens % (world_size * args.val_seq_len) == 0 - val_steps = args.val_tokens // (world_size * args.val_seq_len) - val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) - val_loss = 0 - with torch.no_grad(): - for _ in range(val_steps): - inputs, targets = next(val_loader) - val_loss += model(inputs, targets, ws, ws // 2) - val_loss /= val_steps - del val_loader - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) - print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) - model.train() - # start the clock again - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if master_process and args.save_checkpoint: - log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) - os.makedirs(f"logs/{run_id}", exist_ok=True) - torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") - # the last step only has the validation loop, so break to avoid training - break - - # --------------- TRAINING SECTION ----------------- - for _ in range(grad_accum_steps): - inputs, targets = next(train_loader) - model(inputs, targets, ws, ws // 2).backward() - # set optimization hyperparameters - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lr - for group in optimizer2.param_groups: - frac = min(step / 300, 1) # momentum warmup for muon - group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 - # step the optimizers - for opt in optimizers: - opt.step() - # null the gradients - model.zero_grad(set_to_none=True) - # logging - approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) - print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) - -print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) -dist.destroy_process_group() -==================================================================================================== -Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] -Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 -Running Triton version 3.4.0 -Wed Aug 27 04:04:43 2025 -+---------------------------------------------------------------------------------------+ -| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | -|-----------------------------------------+----------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+======================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | -| N/A 29C P0 114W / 700W | 5858MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | -| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | -| N/A 32C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | -| N/A 29C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | -| N/A 29C P0 109W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | -| N/A 33C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | -| N/A 31C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | -| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ - -+---------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=======================================================================================| -+---------------------------------------------------------------------------------------+ - -==================================================================================================== -step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms -step:1/1695 train_time:517ms step_avg:517.24ms -step:2/1695 train_time:541ms step_avg:270.47ms -step:3/1695 train_time:610ms step_avg:203.20ms -step:4/1695 train_time:702ms step_avg:175.48ms -step:5/1695 train_time:795ms step_avg:158.96ms -step:6/1695 train_time:888ms step_avg:148.00ms -step:7/1695 train_time:982ms step_avg:140.22ms -step:8/1695 train_time:1075ms step_avg:134.38ms -step:9/1695 train_time:1169ms step_avg:129.83ms -step:10/1695 train_time:1262ms step_avg:126.22ms -step:11/1695 train_time:1356ms step_avg:123.26ms -step:12/1695 train_time:1454ms step_avg:121.13ms -step:13/1695 train_time:1553ms step_avg:119.44ms -step:14/1695 train_time:1650ms step_avg:117.83ms -step:15/1695 train_time:1744ms step_avg:116.30ms -step:16/1695 train_time:1838ms step_avg:114.86ms -step:17/1695 train_time:1931ms step_avg:113.60ms -step:18/1695 train_time:2025ms step_avg:112.49ms -step:19/1695 train_time:2118ms step_avg:111.45ms -step:20/1695 train_time:2211ms step_avg:110.56ms -step:21/1695 train_time:2306ms step_avg:109.79ms -step:22/1695 train_time:2401ms step_avg:109.13ms 
-step:23/1695 train_time:2497ms step_avg:108.55ms -step:24/1695 train_time:2592ms step_avg:108.02ms -step:25/1695 train_time:2689ms step_avg:107.56ms -step:26/1695 train_time:2785ms step_avg:107.10ms -step:27/1695 train_time:2879ms step_avg:106.62ms -step:28/1695 train_time:2973ms step_avg:106.16ms -step:29/1695 train_time:3067ms step_avg:105.75ms -step:30/1695 train_time:3160ms step_avg:105.34ms -step:31/1695 train_time:3253ms step_avg:104.95ms -step:32/1695 train_time:3349ms step_avg:104.66ms -step:33/1695 train_time:3445ms step_avg:104.39ms -step:34/1695 train_time:3540ms step_avg:104.10ms -step:35/1695 train_time:3634ms step_avg:103.84ms -step:36/1695 train_time:3731ms step_avg:103.64ms -step:37/1695 train_time:3827ms step_avg:103.44ms -step:38/1695 train_time:3922ms step_avg:103.22ms -step:39/1695 train_time:4016ms step_avg:102.98ms -step:40/1695 train_time:4111ms step_avg:102.77ms -step:41/1695 train_time:4205ms step_avg:102.56ms -step:42/1695 train_time:4299ms step_avg:102.35ms -step:43/1695 train_time:4393ms step_avg:102.16ms -step:44/1695 train_time:4488ms step_avg:102.01ms -step:45/1695 train_time:4584ms step_avg:101.86ms -step:46/1695 train_time:4678ms step_avg:101.68ms -step:47/1695 train_time:4773ms step_avg:101.54ms -step:48/1695 train_time:4867ms step_avg:101.40ms -step:49/1695 train_time:4962ms step_avg:101.27ms -step:50/1695 train_time:5056ms step_avg:101.12ms -step:51/1695 train_time:5150ms step_avg:100.98ms -step:52/1695 train_time:5244ms step_avg:100.85ms -step:53/1695 train_time:5338ms step_avg:100.72ms -step:54/1695 train_time:5432ms step_avg:100.60ms -step:55/1695 train_time:5528ms step_avg:100.51ms -step:56/1695 train_time:5623ms step_avg:100.42ms -step:57/1695 train_time:5717ms step_avg:100.31ms -step:58/1695 train_time:5812ms step_avg:100.20ms -step:59/1695 train_time:5908ms step_avg:100.13ms -step:60/1695 train_time:6003ms step_avg:100.04ms -step:61/1695 train_time:6096ms step_avg:99.94ms -step:62/1695 train_time:6190ms step_avg:99.84ms -step:63/1695 train_time:6284ms step_avg:99.75ms -step:64/1695 train_time:6378ms step_avg:99.66ms -step:65/1695 train_time:6472ms step_avg:99.57ms -step:66/1695 train_time:6568ms step_avg:99.51ms -step:67/1695 train_time:6663ms step_avg:99.45ms -step:68/1695 train_time:6756ms step_avg:99.36ms -step:69/1695 train_time:6851ms step_avg:99.29ms -step:70/1695 train_time:6946ms step_avg:99.23ms -step:71/1695 train_time:7040ms step_avg:99.16ms -step:72/1695 train_time:7133ms step_avg:99.07ms -step:73/1695 train_time:7228ms step_avg:99.02ms -step:74/1695 train_time:7324ms step_avg:98.97ms -step:75/1695 train_time:7418ms step_avg:98.90ms -step:76/1695 train_time:7512ms step_avg:98.84ms -step:77/1695 train_time:7607ms step_avg:98.80ms -step:78/1695 train_time:7702ms step_avg:98.75ms -step:79/1695 train_time:7796ms step_avg:98.69ms -step:80/1695 train_time:7891ms step_avg:98.63ms -step:81/1695 train_time:7985ms step_avg:98.58ms -step:82/1695 train_time:8079ms step_avg:98.52ms -step:83/1695 train_time:8172ms step_avg:98.46ms -step:84/1695 train_time:8268ms step_avg:98.42ms -step:85/1695 train_time:8363ms step_avg:98.39ms -step:86/1695 train_time:8457ms step_avg:98.33ms -step:87/1695 train_time:8551ms step_avg:98.29ms -step:88/1695 train_time:8647ms step_avg:98.26ms -step:89/1695 train_time:8742ms step_avg:98.22ms -step:90/1695 train_time:8835ms step_avg:98.17ms -step:91/1695 train_time:8931ms step_avg:98.14ms -step:92/1695 train_time:9024ms step_avg:98.09ms -step:93/1695 train_time:9118ms step_avg:98.04ms -step:94/1695 train_time:9211ms 
step_avg:97.99ms -step:95/1695 train_time:9305ms step_avg:97.95ms -step:96/1695 train_time:9399ms step_avg:97.91ms -step:97/1695 train_time:9494ms step_avg:97.87ms -step:98/1695 train_time:9588ms step_avg:97.84ms -step:99/1695 train_time:9684ms step_avg:97.81ms -step:100/1695 train_time:9777ms step_avg:97.77ms -step:101/1695 train_time:9872ms step_avg:97.75ms -step:102/1695 train_time:9968ms step_avg:97.73ms -step:103/1695 train_time:10063ms step_avg:97.70ms -step:104/1695 train_time:10157ms step_avg:97.66ms -step:105/1695 train_time:10251ms step_avg:97.63ms -step:106/1695 train_time:10345ms step_avg:97.59ms -step:107/1695 train_time:10439ms step_avg:97.56ms -step:108/1695 train_time:10533ms step_avg:97.52ms -step:109/1695 train_time:10627ms step_avg:97.50ms -step:110/1695 train_time:10721ms step_avg:97.46ms -step:111/1695 train_time:10814ms step_avg:97.42ms -step:112/1695 train_time:10909ms step_avg:97.40ms -step:113/1695 train_time:11004ms step_avg:97.38ms -step:114/1695 train_time:11098ms step_avg:97.35ms -step:115/1695 train_time:11192ms step_avg:97.33ms -step:116/1695 train_time:11288ms step_avg:97.31ms -step:117/1695 train_time:11381ms step_avg:97.28ms -step:118/1695 train_time:11476ms step_avg:97.25ms -step:119/1695 train_time:11570ms step_avg:97.23ms -step:120/1695 train_time:11664ms step_avg:97.20ms -step:121/1695 train_time:11758ms step_avg:97.17ms -step:122/1695 train_time:11852ms step_avg:97.15ms -step:123/1695 train_time:11947ms step_avg:97.13ms -step:124/1695 train_time:12042ms step_avg:97.11ms -step:125/1695 train_time:12135ms step_avg:97.08ms -step:125/1695 val_loss:4.3142 train_time:12227ms step_avg:97.82ms -step:126/1695 train_time:12252ms step_avg:97.24ms -step:127/1695 train_time:12329ms step_avg:97.08ms -step:128/1695 train_time:12428ms step_avg:97.09ms -step:129/1695 train_time:12522ms step_avg:97.07ms -step:130/1695 train_time:12616ms step_avg:97.05ms -step:131/1695 train_time:12710ms step_avg:97.02ms -step:132/1695 train_time:12803ms step_avg:96.99ms -step:133/1695 train_time:12896ms step_avg:96.96ms -step:134/1695 train_time:12990ms step_avg:96.94ms -step:135/1695 train_time:13083ms step_avg:96.91ms -step:136/1695 train_time:13177ms step_avg:96.89ms -step:137/1695 train_time:13273ms step_avg:96.88ms -step:138/1695 train_time:13370ms step_avg:96.88ms -step:139/1695 train_time:13466ms step_avg:96.88ms -step:140/1695 train_time:13560ms step_avg:96.86ms -step:141/1695 train_time:13654ms step_avg:96.84ms -step:142/1695 train_time:13748ms step_avg:96.82ms -step:143/1695 train_time:13841ms step_avg:96.79ms -step:144/1695 train_time:13935ms step_avg:96.77ms -step:145/1695 train_time:14029ms step_avg:96.75ms -step:146/1695 train_time:14122ms step_avg:96.72ms -step:147/1695 train_time:14215ms step_avg:96.70ms -step:148/1695 train_time:14312ms step_avg:96.71ms -step:149/1695 train_time:14409ms step_avg:96.70ms -step:150/1695 train_time:14503ms step_avg:96.69ms -step:151/1695 train_time:14597ms step_avg:96.67ms -step:152/1695 train_time:14692ms step_avg:96.66ms -step:153/1695 train_time:14787ms step_avg:96.65ms -step:154/1695 train_time:14881ms step_avg:96.63ms -step:155/1695 train_time:14974ms step_avg:96.61ms -step:156/1695 train_time:15068ms step_avg:96.59ms -step:157/1695 train_time:15161ms step_avg:96.56ms -step:158/1695 train_time:15255ms step_avg:96.55ms -step:159/1695 train_time:15350ms step_avg:96.54ms -step:160/1695 train_time:15446ms step_avg:96.53ms -step:161/1695 train_time:15540ms step_avg:96.52ms -step:162/1695 train_time:15635ms step_avg:96.51ms -step:163/1695 
train_time:15729ms step_avg:96.50ms -step:164/1695 train_time:15824ms step_avg:96.49ms -step:165/1695 train_time:15917ms step_avg:96.47ms -step:166/1695 train_time:16012ms step_avg:96.46ms -step:167/1695 train_time:16106ms step_avg:96.44ms -step:168/1695 train_time:16199ms step_avg:96.42ms -step:169/1695 train_time:16293ms step_avg:96.41ms -step:170/1695 train_time:16388ms step_avg:96.40ms -step:171/1695 train_time:16482ms step_avg:96.39ms -step:172/1695 train_time:16576ms step_avg:96.37ms -step:173/1695 train_time:16958ms step_avg:98.02ms -step:174/1695 train_time:17044ms step_avg:97.95ms -step:175/1695 train_time:17136ms step_avg:97.92ms -step:176/1695 train_time:17228ms step_avg:97.89ms -step:177/1695 train_time:17321ms step_avg:97.86ms -step:178/1695 train_time:17414ms step_avg:97.83ms -step:179/1695 train_time:17507ms step_avg:97.80ms -step:180/1695 train_time:17599ms step_avg:97.77ms -step:181/1695 train_time:17693ms step_avg:97.75ms -step:182/1695 train_time:17787ms step_avg:97.73ms -step:183/1695 train_time:17880ms step_avg:97.71ms -step:184/1695 train_time:17978ms step_avg:97.71ms -step:185/1695 train_time:18075ms step_avg:97.71ms -step:186/1695 train_time:18170ms step_avg:97.69ms -step:187/1695 train_time:18264ms step_avg:97.67ms -step:188/1695 train_time:18357ms step_avg:97.64ms -step:189/1695 train_time:18451ms step_avg:97.62ms -step:190/1695 train_time:18545ms step_avg:97.60ms -step:191/1695 train_time:18637ms step_avg:97.58ms -step:192/1695 train_time:18731ms step_avg:97.56ms -step:193/1695 train_time:18825ms step_avg:97.54ms -step:194/1695 train_time:18919ms step_avg:97.52ms -step:195/1695 train_time:19016ms step_avg:97.52ms -step:196/1695 train_time:19111ms step_avg:97.51ms -step:197/1695 train_time:19207ms step_avg:97.50ms -step:198/1695 train_time:19300ms step_avg:97.48ms -step:199/1695 train_time:19395ms step_avg:97.46ms -step:200/1695 train_time:19489ms step_avg:97.44ms -step:201/1695 train_time:19582ms step_avg:97.42ms -step:202/1695 train_time:19676ms step_avg:97.40ms -step:203/1695 train_time:19769ms step_avg:97.39ms -step:204/1695 train_time:19863ms step_avg:97.37ms -step:205/1695 train_time:19957ms step_avg:97.35ms -step:206/1695 train_time:20052ms step_avg:97.34ms -step:207/1695 train_time:20148ms step_avg:97.33ms -step:208/1695 train_time:20242ms step_avg:97.32ms -step:209/1695 train_time:20336ms step_avg:97.30ms -step:210/1695 train_time:20430ms step_avg:97.29ms -step:211/1695 train_time:20523ms step_avg:97.27ms -step:212/1695 train_time:20617ms step_avg:97.25ms -step:213/1695 train_time:20710ms step_avg:97.23ms -step:214/1695 train_time:20804ms step_avg:97.21ms -step:215/1695 train_time:20897ms step_avg:97.20ms -step:216/1695 train_time:20991ms step_avg:97.18ms -step:217/1695 train_time:21086ms step_avg:97.17ms -step:218/1695 train_time:21180ms step_avg:97.15ms -step:219/1695 train_time:21274ms step_avg:97.14ms -step:220/1695 train_time:21369ms step_avg:97.13ms -step:221/1695 train_time:21463ms step_avg:97.12ms -step:222/1695 train_time:21557ms step_avg:97.10ms -step:223/1695 train_time:21651ms step_avg:97.09ms -step:224/1695 train_time:21745ms step_avg:97.07ms -step:225/1695 train_time:21838ms step_avg:97.06ms -step:226/1695 train_time:21933ms step_avg:97.05ms -step:227/1695 train_time:22028ms step_avg:97.04ms -step:228/1695 train_time:22122ms step_avg:97.03ms -step:229/1695 train_time:22216ms step_avg:97.01ms -step:230/1695 train_time:22311ms step_avg:97.01ms -step:231/1695 train_time:22406ms step_avg:97.00ms -step:232/1695 train_time:22500ms step_avg:96.98ms 
-step:233/1695 train_time:22593ms step_avg:96.97ms -step:234/1695 train_time:22688ms step_avg:96.96ms -step:235/1695 train_time:22781ms step_avg:96.94ms -step:236/1695 train_time:22874ms step_avg:96.93ms -step:237/1695 train_time:22969ms step_avg:96.91ms -step:238/1695 train_time:23062ms step_avg:96.90ms -step:239/1695 train_time:23155ms step_avg:96.88ms -step:240/1695 train_time:23249ms step_avg:96.87ms -step:241/1695 train_time:23343ms step_avg:96.86ms -step:242/1695 train_time:23437ms step_avg:96.85ms -step:243/1695 train_time:23531ms step_avg:96.84ms -step:244/1695 train_time:23625ms step_avg:96.82ms -step:245/1695 train_time:23719ms step_avg:96.81ms -step:246/1695 train_time:23813ms step_avg:96.80ms -step:247/1695 train_time:23908ms step_avg:96.79ms -step:248/1695 train_time:24002ms step_avg:96.78ms -step:249/1695 train_time:24096ms step_avg:96.77ms -step:250/1695 train_time:24190ms step_avg:96.76ms -step:250/1695 val_loss:3.9738 train_time:24282ms step_avg:97.13ms -step:251/1695 train_time:24306ms step_avg:96.84ms -step:252/1695 train_time:24385ms step_avg:96.77ms -step:253/1695 train_time:24484ms step_avg:96.78ms -step:254/1695 train_time:24579ms step_avg:96.77ms -step:255/1695 train_time:24672ms step_avg:96.75ms -step:256/1695 train_time:24766ms step_avg:96.74ms -step:257/1695 train_time:24858ms step_avg:96.73ms -step:258/1695 train_time:24951ms step_avg:96.71ms -step:259/1695 train_time:25044ms step_avg:96.69ms -step:260/1695 train_time:25137ms step_avg:96.68ms -step:261/1695 train_time:25231ms step_avg:96.67ms -step:262/1695 train_time:25325ms step_avg:96.66ms -step:263/1695 train_time:25423ms step_avg:96.66ms -step:264/1695 train_time:25519ms step_avg:96.66ms -step:265/1695 train_time:25614ms step_avg:96.66ms -step:266/1695 train_time:25708ms step_avg:96.65ms -step:267/1695 train_time:25801ms step_avg:96.63ms -step:268/1695 train_time:25895ms step_avg:96.62ms -step:269/1695 train_time:25988ms step_avg:96.61ms -step:270/1695 train_time:26081ms step_avg:96.60ms -step:271/1695 train_time:26174ms step_avg:96.58ms -step:272/1695 train_time:26267ms step_avg:96.57ms -step:273/1695 train_time:26362ms step_avg:96.57ms -step:274/1695 train_time:26458ms step_avg:96.56ms -step:275/1695 train_time:26553ms step_avg:96.56ms -step:276/1695 train_time:26648ms step_avg:96.55ms -step:277/1695 train_time:26741ms step_avg:96.54ms -step:278/1695 train_time:26836ms step_avg:96.53ms -step:279/1695 train_time:26931ms step_avg:96.53ms -step:280/1695 train_time:27023ms step_avg:96.51ms -step:281/1695 train_time:27117ms step_avg:96.50ms -step:282/1695 train_time:27210ms step_avg:96.49ms -step:283/1695 train_time:27303ms step_avg:96.48ms -step:284/1695 train_time:27398ms step_avg:96.47ms -step:285/1695 train_time:27492ms step_avg:96.46ms -step:286/1695 train_time:27587ms step_avg:96.46ms -step:287/1695 train_time:27681ms step_avg:96.45ms -step:288/1695 train_time:27775ms step_avg:96.44ms -step:289/1695 train_time:27870ms step_avg:96.44ms -step:290/1695 train_time:27963ms step_avg:96.43ms -step:291/1695 train_time:28057ms step_avg:96.42ms -step:292/1695 train_time:28151ms step_avg:96.41ms -step:293/1695 train_time:28244ms step_avg:96.39ms -step:294/1695 train_time:28338ms step_avg:96.39ms -step:295/1695 train_time:28432ms step_avg:96.38ms -step:296/1695 train_time:28526ms step_avg:96.37ms -step:297/1695 train_time:28620ms step_avg:96.36ms -step:298/1695 train_time:28714ms step_avg:96.35ms -step:299/1695 train_time:28808ms step_avg:96.35ms -step:300/1695 train_time:28901ms step_avg:96.34ms -step:301/1695 
train_time:28995ms step_avg:96.33ms -step:302/1695 train_time:29089ms step_avg:96.32ms -step:303/1695 train_time:29182ms step_avg:96.31ms -step:304/1695 train_time:29276ms step_avg:96.30ms -step:305/1695 train_time:29370ms step_avg:96.30ms -step:306/1695 train_time:29464ms step_avg:96.29ms -step:307/1695 train_time:29558ms step_avg:96.28ms -step:308/1695 train_time:29653ms step_avg:96.28ms -step:309/1695 train_time:29748ms step_avg:96.27ms -step:310/1695 train_time:29841ms step_avg:96.26ms -step:311/1695 train_time:29936ms step_avg:96.26ms -step:312/1695 train_time:30031ms step_avg:96.25ms -step:313/1695 train_time:30124ms step_avg:96.24ms -step:314/1695 train_time:30218ms step_avg:96.24ms -step:315/1695 train_time:30312ms step_avg:96.23ms -step:316/1695 train_time:30405ms step_avg:96.22ms -step:317/1695 train_time:30500ms step_avg:96.21ms -step:318/1695 train_time:30595ms step_avg:96.21ms -step:319/1695 train_time:30689ms step_avg:96.20ms -step:320/1695 train_time:30783ms step_avg:96.20ms -step:321/1695 train_time:30877ms step_avg:96.19ms -step:322/1695 train_time:30972ms step_avg:96.19ms -step:323/1695 train_time:31066ms step_avg:96.18ms -step:324/1695 train_time:31159ms step_avg:96.17ms -step:325/1695 train_time:31252ms step_avg:96.16ms -step:326/1695 train_time:31346ms step_avg:96.15ms -step:327/1695 train_time:31440ms step_avg:96.15ms -step:328/1695 train_time:31535ms step_avg:96.14ms -step:329/1695 train_time:31630ms step_avg:96.14ms -step:330/1695 train_time:31723ms step_avg:96.13ms -step:331/1695 train_time:31817ms step_avg:96.12ms -step:332/1695 train_time:31912ms step_avg:96.12ms -step:333/1695 train_time:32006ms step_avg:96.11ms -step:334/1695 train_time:32099ms step_avg:96.11ms -step:335/1695 train_time:32194ms step_avg:96.10ms -step:336/1695 train_time:32288ms step_avg:96.09ms -step:337/1695 train_time:32381ms step_avg:96.09ms -step:338/1695 train_time:32475ms step_avg:96.08ms -step:339/1695 train_time:32570ms step_avg:96.08ms -step:340/1695 train_time:32664ms step_avg:96.07ms -step:341/1695 train_time:32758ms step_avg:96.06ms -step:342/1695 train_time:32852ms step_avg:96.06ms -step:343/1695 train_time:32947ms step_avg:96.05ms -step:344/1695 train_time:33040ms step_avg:96.05ms -step:345/1695 train_time:33366ms step_avg:96.71ms -step:346/1695 train_time:33470ms step_avg:96.73ms -step:347/1695 train_time:33562ms step_avg:96.72ms -step:348/1695 train_time:33655ms step_avg:96.71ms -step:349/1695 train_time:33748ms step_avg:96.70ms -step:350/1695 train_time:33840ms step_avg:96.69ms -step:351/1695 train_time:33933ms step_avg:96.68ms -step:352/1695 train_time:34026ms step_avg:96.66ms -step:353/1695 train_time:34119ms step_avg:96.65ms -step:354/1695 train_time:34212ms step_avg:96.65ms -step:355/1695 train_time:34310ms step_avg:96.65ms -step:356/1695 train_time:34408ms step_avg:96.65ms -step:357/1695 train_time:34502ms step_avg:96.65ms -step:358/1695 train_time:34597ms step_avg:96.64ms -step:359/1695 train_time:34690ms step_avg:96.63ms -step:360/1695 train_time:34783ms step_avg:96.62ms -step:361/1695 train_time:34876ms step_avg:96.61ms -step:362/1695 train_time:34970ms step_avg:96.60ms -step:363/1695 train_time:35062ms step_avg:96.59ms -step:364/1695 train_time:35155ms step_avg:96.58ms -step:365/1695 train_time:35250ms step_avg:96.58ms -step:366/1695 train_time:35345ms step_avg:96.57ms -step:367/1695 train_time:35440ms step_avg:96.57ms -step:368/1695 train_time:35535ms step_avg:96.56ms -step:369/1695 train_time:35630ms step_avg:96.56ms -step:370/1695 train_time:35723ms step_avg:96.55ms 
-step:371/1695 train_time:35817ms step_avg:96.54ms -step:372/1695 train_time:35911ms step_avg:96.53ms -step:373/1695 train_time:36004ms step_avg:96.52ms -step:374/1695 train_time:36097ms step_avg:96.52ms -step:375/1695 train_time:36191ms step_avg:96.51ms -step:375/1695 val_loss:3.8151 train_time:36283ms step_avg:96.76ms -step:376/1695 train_time:36310ms step_avg:96.57ms -step:377/1695 train_time:36385ms step_avg:96.51ms -step:378/1695 train_time:36485ms step_avg:96.52ms -step:379/1695 train_time:36582ms step_avg:96.52ms -step:380/1695 train_time:36675ms step_avg:96.51ms -step:381/1695 train_time:36768ms step_avg:96.50ms -step:382/1695 train_time:36861ms step_avg:96.50ms -step:383/1695 train_time:36955ms step_avg:96.49ms -step:384/1695 train_time:37047ms step_avg:96.48ms -step:385/1695 train_time:37140ms step_avg:96.47ms -step:386/1695 train_time:37233ms step_avg:96.46ms -step:387/1695 train_time:37328ms step_avg:96.46ms -step:388/1695 train_time:37424ms step_avg:96.45ms -step:389/1695 train_time:37521ms step_avg:96.46ms -step:390/1695 train_time:37617ms step_avg:96.45ms -step:391/1695 train_time:37711ms step_avg:96.45ms -step:392/1695 train_time:37804ms step_avg:96.44ms -step:393/1695 train_time:37897ms step_avg:96.43ms -step:394/1695 train_time:37990ms step_avg:96.42ms -step:395/1695 train_time:38083ms step_avg:96.41ms -step:396/1695 train_time:38176ms step_avg:96.40ms -step:397/1695 train_time:38270ms step_avg:96.40ms -step:398/1695 train_time:38364ms step_avg:96.39ms -step:399/1695 train_time:38460ms step_avg:96.39ms -step:400/1695 train_time:38556ms step_avg:96.39ms -step:401/1695 train_time:38650ms step_avg:96.38ms -step:402/1695 train_time:38744ms step_avg:96.38ms -step:403/1695 train_time:38837ms step_avg:96.37ms -step:404/1695 train_time:38930ms step_avg:96.36ms -step:405/1695 train_time:39024ms step_avg:96.35ms -step:406/1695 train_time:39118ms step_avg:96.35ms -step:407/1695 train_time:39211ms step_avg:96.34ms -step:408/1695 train_time:39304ms step_avg:96.33ms -step:409/1695 train_time:39398ms step_avg:96.33ms -step:410/1695 train_time:39494ms step_avg:96.33ms -step:411/1695 train_time:39588ms step_avg:96.32ms -step:412/1695 train_time:39683ms step_avg:96.32ms -step:413/1695 train_time:39776ms step_avg:96.31ms -step:414/1695 train_time:39870ms step_avg:96.30ms -step:415/1695 train_time:39963ms step_avg:96.30ms -step:416/1695 train_time:40058ms step_avg:96.29ms -step:417/1695 train_time:40152ms step_avg:96.29ms -step:418/1695 train_time:40246ms step_avg:96.28ms -step:419/1695 train_time:40339ms step_avg:96.27ms -step:420/1695 train_time:40433ms step_avg:96.27ms -step:421/1695 train_time:40528ms step_avg:96.27ms -step:422/1695 train_time:40622ms step_avg:96.26ms -step:423/1695 train_time:40716ms step_avg:96.26ms -step:424/1695 train_time:40810ms step_avg:96.25ms -step:425/1695 train_time:40903ms step_avg:96.24ms -step:426/1695 train_time:40997ms step_avg:96.24ms -step:427/1695 train_time:41091ms step_avg:96.23ms -step:428/1695 train_time:41184ms step_avg:96.22ms -step:429/1695 train_time:41278ms step_avg:96.22ms -step:430/1695 train_time:41372ms step_avg:96.21ms -step:431/1695 train_time:41466ms step_avg:96.21ms -step:432/1695 train_time:41561ms step_avg:96.21ms -step:433/1695 train_time:41656ms step_avg:96.20ms -step:434/1695 train_time:41750ms step_avg:96.20ms -step:435/1695 train_time:41843ms step_avg:96.19ms -step:436/1695 train_time:41938ms step_avg:96.19ms -step:437/1695 train_time:42032ms step_avg:96.18ms -step:438/1695 train_time:42126ms step_avg:96.18ms -step:439/1695 
train_time:42220ms step_avg:96.17ms -step:440/1695 train_time:42314ms step_avg:96.17ms -step:441/1695 train_time:42408ms step_avg:96.16ms -step:442/1695 train_time:42501ms step_avg:96.16ms -step:443/1695 train_time:42596ms step_avg:96.15ms -step:444/1695 train_time:42691ms step_avg:96.15ms -step:445/1695 train_time:42784ms step_avg:96.14ms -step:446/1695 train_time:42878ms step_avg:96.14ms -step:447/1695 train_time:42972ms step_avg:96.13ms -step:448/1695 train_time:43066ms step_avg:96.13ms -step:449/1695 train_time:43160ms step_avg:96.12ms -step:450/1695 train_time:43255ms step_avg:96.12ms -step:451/1695 train_time:43348ms step_avg:96.12ms -step:452/1695 train_time:43443ms step_avg:96.11ms -step:453/1695 train_time:43537ms step_avg:96.11ms -step:454/1695 train_time:43632ms step_avg:96.11ms -step:455/1695 train_time:43725ms step_avg:96.10ms -step:456/1695 train_time:43819ms step_avg:96.09ms -step:457/1695 train_time:43913ms step_avg:96.09ms -step:458/1695 train_time:44006ms step_avg:96.08ms -step:459/1695 train_time:44100ms step_avg:96.08ms -step:460/1695 train_time:44194ms step_avg:96.07ms -step:461/1695 train_time:44287ms step_avg:96.07ms -step:462/1695 train_time:44381ms step_avg:96.06ms -step:463/1695 train_time:44476ms step_avg:96.06ms -step:464/1695 train_time:44569ms step_avg:96.05ms -step:465/1695 train_time:44663ms step_avg:96.05ms -step:466/1695 train_time:44758ms step_avg:96.05ms -step:467/1695 train_time:44853ms step_avg:96.04ms -step:468/1695 train_time:44947ms step_avg:96.04ms -step:469/1695 train_time:45040ms step_avg:96.03ms -step:470/1695 train_time:45134ms step_avg:96.03ms -step:471/1695 train_time:45229ms step_avg:96.03ms -step:472/1695 train_time:45323ms step_avg:96.02ms -step:473/1695 train_time:45418ms step_avg:96.02ms -step:474/1695 train_time:45512ms step_avg:96.02ms -step:475/1695 train_time:45605ms step_avg:96.01ms -step:476/1695 train_time:45699ms step_avg:96.01ms -step:477/1695 train_time:45795ms step_avg:96.01ms -step:478/1695 train_time:45888ms step_avg:96.00ms -step:479/1695 train_time:45982ms step_avg:96.00ms -step:480/1695 train_time:46076ms step_avg:95.99ms -step:481/1695 train_time:46170ms step_avg:95.99ms -step:482/1695 train_time:46264ms step_avg:95.98ms -step:483/1695 train_time:46358ms step_avg:95.98ms -step:484/1695 train_time:46454ms step_avg:95.98ms -step:485/1695 train_time:46547ms step_avg:95.97ms -step:486/1695 train_time:46641ms step_avg:95.97ms -step:487/1695 train_time:46735ms step_avg:95.97ms -step:488/1695 train_time:46830ms step_avg:95.96ms -step:489/1695 train_time:46924ms step_avg:95.96ms -step:490/1695 train_time:47018ms step_avg:95.96ms -step:491/1695 train_time:47112ms step_avg:95.95ms -step:492/1695 train_time:47205ms step_avg:95.95ms -step:493/1695 train_time:47299ms step_avg:95.94ms -step:494/1695 train_time:47393ms step_avg:95.94ms -step:495/1695 train_time:47487ms step_avg:95.93ms -step:496/1695 train_time:47581ms step_avg:95.93ms -step:497/1695 train_time:47675ms step_avg:95.93ms -step:498/1695 train_time:47768ms step_avg:95.92ms -step:499/1695 train_time:47862ms step_avg:95.92ms -step:500/1695 train_time:47957ms step_avg:95.91ms -step:500/1695 val_loss:3.7158 train_time:48050ms step_avg:96.10ms -step:501/1695 train_time:48074ms step_avg:95.96ms -step:502/1695 train_time:48155ms step_avg:95.93ms -step:503/1695 train_time:48257ms step_avg:95.94ms -step:504/1695 train_time:48350ms step_avg:95.93ms -step:505/1695 train_time:48444ms step_avg:95.93ms -step:506/1695 train_time:48537ms step_avg:95.92ms -step:507/1695 train_time:48630ms 
step_avg:95.92ms -step:508/1695 train_time:48723ms step_avg:95.91ms -step:509/1695 train_time:48816ms step_avg:95.91ms -step:510/1695 train_time:48909ms step_avg:95.90ms -step:511/1695 train_time:49002ms step_avg:95.89ms -step:512/1695 train_time:49098ms step_avg:95.89ms -step:513/1695 train_time:49195ms step_avg:95.90ms -step:514/1695 train_time:49290ms step_avg:95.90ms -step:515/1695 train_time:49386ms step_avg:95.89ms -step:516/1695 train_time:49480ms step_avg:95.89ms -step:517/1695 train_time:49573ms step_avg:95.89ms -step:518/1695 train_time:49666ms step_avg:95.88ms -step:519/1695 train_time:49996ms step_avg:96.33ms -step:520/1695 train_time:50189ms step_avg:96.52ms -step:521/1695 train_time:50281ms step_avg:96.51ms -step:522/1695 train_time:50374ms step_avg:96.50ms -step:523/1695 train_time:50467ms step_avg:96.50ms -step:524/1695 train_time:50560ms step_avg:96.49ms -step:525/1695 train_time:50653ms step_avg:96.48ms -step:526/1695 train_time:50746ms step_avg:96.48ms -step:527/1695 train_time:50840ms step_avg:96.47ms -step:528/1695 train_time:50933ms step_avg:96.46ms -step:529/1695 train_time:51027ms step_avg:96.46ms -step:530/1695 train_time:51125ms step_avg:96.46ms -step:531/1695 train_time:51222ms step_avg:96.46ms -step:532/1695 train_time:51316ms step_avg:96.46ms -step:533/1695 train_time:51409ms step_avg:96.45ms -step:534/1695 train_time:51504ms step_avg:96.45ms -step:535/1695 train_time:51596ms step_avg:96.44ms -step:536/1695 train_time:51689ms step_avg:96.43ms -step:537/1695 train_time:51782ms step_avg:96.43ms -step:538/1695 train_time:51875ms step_avg:96.42ms -step:539/1695 train_time:51969ms step_avg:96.42ms -step:540/1695 train_time:52063ms step_avg:96.41ms -step:541/1695 train_time:52158ms step_avg:96.41ms -step:542/1695 train_time:52252ms step_avg:96.41ms -step:543/1695 train_time:52348ms step_avg:96.40ms -step:544/1695 train_time:52442ms step_avg:96.40ms -step:545/1695 train_time:52536ms step_avg:96.40ms -step:546/1695 train_time:52629ms step_avg:96.39ms -step:547/1695 train_time:52723ms step_avg:96.39ms -step:548/1695 train_time:52816ms step_avg:96.38ms -step:549/1695 train_time:52909ms step_avg:96.37ms -step:550/1695 train_time:53003ms step_avg:96.37ms -step:551/1695 train_time:53097ms step_avg:96.36ms -step:552/1695 train_time:53192ms step_avg:96.36ms -step:553/1695 train_time:53287ms step_avg:96.36ms -step:554/1695 train_time:53382ms step_avg:96.36ms -step:555/1695 train_time:53476ms step_avg:96.35ms -step:556/1695 train_time:53569ms step_avg:96.35ms -step:557/1695 train_time:53663ms step_avg:96.34ms -step:558/1695 train_time:53756ms step_avg:96.34ms -step:559/1695 train_time:53849ms step_avg:96.33ms -step:560/1695 train_time:53943ms step_avg:96.33ms -step:561/1695 train_time:54037ms step_avg:96.32ms -step:562/1695 train_time:54131ms step_avg:96.32ms -step:563/1695 train_time:54226ms step_avg:96.32ms -step:564/1695 train_time:54321ms step_avg:96.31ms -step:565/1695 train_time:54414ms step_avg:96.31ms -step:566/1695 train_time:54508ms step_avg:96.30ms -step:567/1695 train_time:54602ms step_avg:96.30ms -step:568/1695 train_time:54697ms step_avg:96.30ms -step:569/1695 train_time:54792ms step_avg:96.30ms -step:570/1695 train_time:54889ms step_avg:96.30ms -step:571/1695 train_time:54986ms step_avg:96.30ms -step:572/1695 train_time:55083ms step_avg:96.30ms -step:573/1695 train_time:55180ms step_avg:96.30ms -step:574/1695 train_time:55276ms step_avg:96.30ms -step:575/1695 train_time:55373ms step_avg:96.30ms -step:576/1695 train_time:55469ms step_avg:96.30ms -step:577/1695 
train_time:55565ms step_avg:96.30ms -step:578/1695 train_time:55662ms step_avg:96.30ms -step:579/1695 train_time:55757ms step_avg:96.30ms -step:580/1695 train_time:55852ms step_avg:96.30ms -step:581/1695 train_time:55948ms step_avg:96.30ms -step:582/1695 train_time:56044ms step_avg:96.30ms -step:583/1695 train_time:56140ms step_avg:96.30ms -step:584/1695 train_time:56236ms step_avg:96.29ms -step:585/1695 train_time:56331ms step_avg:96.29ms -step:586/1695 train_time:56427ms step_avg:96.29ms -step:587/1695 train_time:56523ms step_avg:96.29ms -step:588/1695 train_time:56621ms step_avg:96.29ms -step:589/1695 train_time:56716ms step_avg:96.29ms -step:590/1695 train_time:56812ms step_avg:96.29ms -step:591/1695 train_time:56908ms step_avg:96.29ms -step:592/1695 train_time:57005ms step_avg:96.29ms -step:593/1695 train_time:57101ms step_avg:96.29ms -step:594/1695 train_time:57196ms step_avg:96.29ms -step:595/1695 train_time:57291ms step_avg:96.29ms -step:596/1695 train_time:57388ms step_avg:96.29ms -step:597/1695 train_time:57484ms step_avg:96.29ms -step:598/1695 train_time:57581ms step_avg:96.29ms -step:599/1695 train_time:57678ms step_avg:96.29ms -step:600/1695 train_time:57773ms step_avg:96.29ms -step:601/1695 train_time:57869ms step_avg:96.29ms -step:602/1695 train_time:57964ms step_avg:96.29ms -step:603/1695 train_time:58060ms step_avg:96.28ms -step:604/1695 train_time:58155ms step_avg:96.28ms -step:605/1695 train_time:58251ms step_avg:96.28ms -step:606/1695 train_time:58348ms step_avg:96.28ms -step:607/1695 train_time:58444ms step_avg:96.28ms -step:608/1695 train_time:58541ms step_avg:96.28ms -step:609/1695 train_time:58636ms step_avg:96.28ms -step:610/1695 train_time:58732ms step_avg:96.28ms -step:611/1695 train_time:58827ms step_avg:96.28ms -step:612/1695 train_time:58923ms step_avg:96.28ms -step:613/1695 train_time:59020ms step_avg:96.28ms -step:614/1695 train_time:59116ms step_avg:96.28ms -step:615/1695 train_time:59211ms step_avg:96.28ms -step:616/1695 train_time:59307ms step_avg:96.28ms -step:617/1695 train_time:59404ms step_avg:96.28ms -step:618/1695 train_time:59499ms step_avg:96.28ms -step:619/1695 train_time:59595ms step_avg:96.28ms -step:620/1695 train_time:59691ms step_avg:96.28ms -step:621/1695 train_time:59788ms step_avg:96.28ms -step:622/1695 train_time:59884ms step_avg:96.28ms -step:623/1695 train_time:59982ms step_avg:96.28ms -step:624/1695 train_time:60078ms step_avg:96.28ms -step:625/1695 train_time:60173ms step_avg:96.28ms -step:625/1695 val_loss:3.6195 train_time:60266ms step_avg:96.43ms -step:626/1695 train_time:60290ms step_avg:96.31ms -step:627/1695 train_time:60370ms step_avg:96.28ms -step:628/1695 train_time:60467ms step_avg:96.29ms -step:629/1695 train_time:60563ms step_avg:96.28ms -step:630/1695 train_time:60658ms step_avg:96.28ms -step:631/1695 train_time:60753ms step_avg:96.28ms -step:632/1695 train_time:60847ms step_avg:96.28ms -step:633/1695 train_time:60943ms step_avg:96.28ms -step:634/1695 train_time:61038ms step_avg:96.27ms -step:635/1695 train_time:61133ms step_avg:96.27ms -step:636/1695 train_time:61232ms step_avg:96.28ms -step:637/1695 train_time:61331ms step_avg:96.28ms -step:638/1695 train_time:61426ms step_avg:96.28ms -step:639/1695 train_time:61522ms step_avg:96.28ms -step:640/1695 train_time:61617ms step_avg:96.28ms -step:641/1695 train_time:61712ms step_avg:96.28ms -step:642/1695 train_time:61808ms step_avg:96.27ms -step:643/1695 train_time:61902ms step_avg:96.27ms -step:644/1695 train_time:61997ms step_avg:96.27ms -step:645/1695 train_time:62092ms 
step_avg:96.27ms -step:646/1695 train_time:62188ms step_avg:96.27ms -step:647/1695 train_time:62284ms step_avg:96.27ms -step:648/1695 train_time:62382ms step_avg:96.27ms -step:649/1695 train_time:62478ms step_avg:96.27ms -step:650/1695 train_time:62575ms step_avg:96.27ms -step:651/1695 train_time:62671ms step_avg:96.27ms -step:652/1695 train_time:62767ms step_avg:96.27ms -step:653/1695 train_time:62862ms step_avg:96.27ms -step:654/1695 train_time:62957ms step_avg:96.27ms -step:655/1695 train_time:63053ms step_avg:96.26ms -step:656/1695 train_time:63148ms step_avg:96.26ms -step:657/1695 train_time:63245ms step_avg:96.26ms -step:658/1695 train_time:63341ms step_avg:96.26ms -step:659/1695 train_time:63438ms step_avg:96.26ms -step:660/1695 train_time:63535ms step_avg:96.27ms -step:661/1695 train_time:63632ms step_avg:96.27ms -step:662/1695 train_time:63727ms step_avg:96.26ms -step:663/1695 train_time:63822ms step_avg:96.26ms -step:664/1695 train_time:63918ms step_avg:96.26ms -step:665/1695 train_time:64014ms step_avg:96.26ms -step:666/1695 train_time:64111ms step_avg:96.26ms -step:667/1695 train_time:64207ms step_avg:96.26ms -step:668/1695 train_time:64303ms step_avg:96.26ms -step:669/1695 train_time:64400ms step_avg:96.26ms -step:670/1695 train_time:64497ms step_avg:96.26ms -step:671/1695 train_time:64594ms step_avg:96.27ms -step:672/1695 train_time:64691ms step_avg:96.27ms -step:673/1695 train_time:64786ms step_avg:96.26ms -step:674/1695 train_time:64881ms step_avg:96.26ms -step:675/1695 train_time:64977ms step_avg:96.26ms -step:676/1695 train_time:65073ms step_avg:96.26ms -step:677/1695 train_time:65170ms step_avg:96.26ms -step:678/1695 train_time:65266ms step_avg:96.26ms -step:679/1695 train_time:65361ms step_avg:96.26ms -step:680/1695 train_time:65458ms step_avg:96.26ms -step:681/1695 train_time:65554ms step_avg:96.26ms -step:682/1695 train_time:65651ms step_avg:96.26ms -step:683/1695 train_time:65747ms step_avg:96.26ms -step:684/1695 train_time:65843ms step_avg:96.26ms -step:685/1695 train_time:65939ms step_avg:96.26ms -step:686/1695 train_time:66034ms step_avg:96.26ms -step:687/1695 train_time:66129ms step_avg:96.26ms -step:688/1695 train_time:66225ms step_avg:96.26ms -step:689/1695 train_time:66320ms step_avg:96.26ms -step:690/1695 train_time:66417ms step_avg:96.26ms -step:691/1695 train_time:66874ms step_avg:96.78ms -step:692/1695 train_time:66944ms step_avg:96.74ms -step:693/1695 train_time:67039ms step_avg:96.74ms -step:694/1695 train_time:67134ms step_avg:96.74ms -step:695/1695 train_time:67229ms step_avg:96.73ms -step:696/1695 train_time:67323ms step_avg:96.73ms -step:697/1695 train_time:67418ms step_avg:96.73ms -step:698/1695 train_time:67513ms step_avg:96.72ms -step:699/1695 train_time:67607ms step_avg:96.72ms -step:700/1695 train_time:67702ms step_avg:96.72ms -step:701/1695 train_time:67802ms step_avg:96.72ms -step:702/1695 train_time:67902ms step_avg:96.73ms -step:703/1695 train_time:68000ms step_avg:96.73ms -step:704/1695 train_time:68096ms step_avg:96.73ms -step:705/1695 train_time:68192ms step_avg:96.73ms -step:706/1695 train_time:68287ms step_avg:96.72ms -step:707/1695 train_time:68382ms step_avg:96.72ms -step:708/1695 train_time:68478ms step_avg:96.72ms -step:709/1695 train_time:68574ms step_avg:96.72ms -step:710/1695 train_time:68670ms step_avg:96.72ms -step:711/1695 train_time:68765ms step_avg:96.72ms -step:712/1695 train_time:68862ms step_avg:96.72ms -step:713/1695 train_time:68959ms step_avg:96.72ms -step:714/1695 train_time:69056ms step_avg:96.72ms -step:715/1695 
train_time:69153ms step_avg:96.72ms -step:716/1695 train_time:69250ms step_avg:96.72ms -step:717/1695 train_time:69344ms step_avg:96.71ms -step:718/1695 train_time:69439ms step_avg:96.71ms -step:719/1695 train_time:69535ms step_avg:96.71ms -step:720/1695 train_time:69631ms step_avg:96.71ms -step:721/1695 train_time:69726ms step_avg:96.71ms -step:722/1695 train_time:69822ms step_avg:96.71ms -step:723/1695 train_time:69918ms step_avg:96.71ms -step:724/1695 train_time:70015ms step_avg:96.71ms -step:725/1695 train_time:70113ms step_avg:96.71ms -step:726/1695 train_time:70210ms step_avg:96.71ms -step:727/1695 train_time:70305ms step_avg:96.71ms -step:728/1695 train_time:70400ms step_avg:96.70ms -step:729/1695 train_time:70496ms step_avg:96.70ms -step:730/1695 train_time:70591ms step_avg:96.70ms -step:731/1695 train_time:70686ms step_avg:96.70ms -step:732/1695 train_time:70782ms step_avg:96.70ms -step:733/1695 train_time:70878ms step_avg:96.70ms -step:734/1695 train_time:70974ms step_avg:96.69ms -step:735/1695 train_time:71069ms step_avg:96.69ms -step:736/1695 train_time:71166ms step_avg:96.69ms -step:737/1695 train_time:71262ms step_avg:96.69ms -step:738/1695 train_time:71358ms step_avg:96.69ms -step:739/1695 train_time:71454ms step_avg:96.69ms -step:740/1695 train_time:71549ms step_avg:96.69ms -step:741/1695 train_time:71644ms step_avg:96.69ms -step:742/1695 train_time:71739ms step_avg:96.68ms -step:743/1695 train_time:71836ms step_avg:96.68ms -step:744/1695 train_time:71931ms step_avg:96.68ms -step:745/1695 train_time:72027ms step_avg:96.68ms -step:746/1695 train_time:72123ms step_avg:96.68ms -step:747/1695 train_time:72220ms step_avg:96.68ms -step:748/1695 train_time:72317ms step_avg:96.68ms -step:749/1695 train_time:72413ms step_avg:96.68ms -step:750/1695 train_time:72510ms step_avg:96.68ms -step:750/1695 val_loss:3.5686 train_time:72604ms step_avg:96.81ms -step:751/1695 train_time:72630ms step_avg:96.71ms -step:752/1695 train_time:72710ms step_avg:96.69ms -step:753/1695 train_time:72807ms step_avg:96.69ms -step:754/1695 train_time:72902ms step_avg:96.69ms -step:755/1695 train_time:72998ms step_avg:96.69ms -step:756/1695 train_time:73092ms step_avg:96.68ms -step:757/1695 train_time:73186ms step_avg:96.68ms -step:758/1695 train_time:73281ms step_avg:96.68ms -step:759/1695 train_time:73376ms step_avg:96.67ms -step:760/1695 train_time:73470ms step_avg:96.67ms -step:761/1695 train_time:73566ms step_avg:96.67ms -step:762/1695 train_time:73665ms step_avg:96.67ms -step:763/1695 train_time:73763ms step_avg:96.67ms -step:764/1695 train_time:73859ms step_avg:96.67ms -step:765/1695 train_time:73955ms step_avg:96.67ms -step:766/1695 train_time:74051ms step_avg:96.67ms -step:767/1695 train_time:74146ms step_avg:96.67ms -step:768/1695 train_time:74241ms step_avg:96.67ms -step:769/1695 train_time:74336ms step_avg:96.67ms -step:770/1695 train_time:74430ms step_avg:96.66ms -step:771/1695 train_time:74525ms step_avg:96.66ms -step:772/1695 train_time:74622ms step_avg:96.66ms -step:773/1695 train_time:74719ms step_avg:96.66ms -step:774/1695 train_time:74817ms step_avg:96.66ms -step:775/1695 train_time:74913ms step_avg:96.66ms -step:776/1695 train_time:75008ms step_avg:96.66ms -step:777/1695 train_time:75104ms step_avg:96.66ms -step:778/1695 train_time:75199ms step_avg:96.66ms -step:779/1695 train_time:75294ms step_avg:96.65ms -step:780/1695 train_time:75389ms step_avg:96.65ms -step:781/1695 train_time:75484ms step_avg:96.65ms -step:782/1695 train_time:75580ms step_avg:96.65ms -step:783/1695 train_time:75677ms 
step_avg:96.65ms -step:784/1695 train_time:75773ms step_avg:96.65ms -step:785/1695 train_time:75870ms step_avg:96.65ms -step:786/1695 train_time:75965ms step_avg:96.65ms -step:787/1695 train_time:76061ms step_avg:96.65ms -step:788/1695 train_time:76156ms step_avg:96.64ms -step:789/1695 train_time:76251ms step_avg:96.64ms -step:790/1695 train_time:76346ms step_avg:96.64ms -step:791/1695 train_time:76441ms step_avg:96.64ms -step:792/1695 train_time:76538ms step_avg:96.64ms -step:793/1695 train_time:76634ms step_avg:96.64ms -step:794/1695 train_time:76730ms step_avg:96.64ms -step:795/1695 train_time:76825ms step_avg:96.64ms -step:796/1695 train_time:76921ms step_avg:96.63ms -step:797/1695 train_time:77017ms step_avg:96.63ms -step:798/1695 train_time:77113ms step_avg:96.63ms -step:799/1695 train_time:77207ms step_avg:96.63ms -step:800/1695 train_time:77303ms step_avg:96.63ms -step:801/1695 train_time:77398ms step_avg:96.63ms -step:802/1695 train_time:77493ms step_avg:96.62ms -step:803/1695 train_time:77588ms step_avg:96.62ms -step:804/1695 train_time:77684ms step_avg:96.62ms -step:805/1695 train_time:77782ms step_avg:96.62ms -step:806/1695 train_time:77879ms step_avg:96.62ms -step:807/1695 train_time:77976ms step_avg:96.62ms -step:808/1695 train_time:78073ms step_avg:96.62ms -step:809/1695 train_time:78168ms step_avg:96.62ms -step:810/1695 train_time:78262ms step_avg:96.62ms -step:811/1695 train_time:78357ms step_avg:96.62ms -step:812/1695 train_time:78453ms step_avg:96.62ms -step:813/1695 train_time:78547ms step_avg:96.61ms -step:814/1695 train_time:78643ms step_avg:96.61ms -step:815/1695 train_time:78739ms step_avg:96.61ms -step:816/1695 train_time:78836ms step_avg:96.61ms -step:817/1695 train_time:78932ms step_avg:96.61ms -step:818/1695 train_time:79028ms step_avg:96.61ms -step:819/1695 train_time:79124ms step_avg:96.61ms -step:820/1695 train_time:79219ms step_avg:96.61ms -step:821/1695 train_time:79315ms step_avg:96.61ms -step:822/1695 train_time:79410ms step_avg:96.61ms -step:823/1695 train_time:79505ms step_avg:96.60ms -step:824/1695 train_time:79600ms step_avg:96.60ms -step:825/1695 train_time:79696ms step_avg:96.60ms -step:826/1695 train_time:79791ms step_avg:96.60ms -step:827/1695 train_time:79887ms step_avg:96.60ms -step:828/1695 train_time:79984ms step_avg:96.60ms -step:829/1695 train_time:80081ms step_avg:96.60ms -step:830/1695 train_time:80178ms step_avg:96.60ms -step:831/1695 train_time:80273ms step_avg:96.60ms -step:832/1695 train_time:80368ms step_avg:96.60ms -step:833/1695 train_time:80463ms step_avg:96.59ms -step:834/1695 train_time:80559ms step_avg:96.59ms -step:835/1695 train_time:80656ms step_avg:96.59ms -step:836/1695 train_time:80752ms step_avg:96.59ms -step:837/1695 train_time:80847ms step_avg:96.59ms -step:838/1695 train_time:80942ms step_avg:96.59ms -step:839/1695 train_time:81038ms step_avg:96.59ms -step:840/1695 train_time:81134ms step_avg:96.59ms -step:841/1695 train_time:81229ms step_avg:96.59ms -step:842/1695 train_time:81324ms step_avg:96.58ms -step:843/1695 train_time:81420ms step_avg:96.58ms -step:844/1695 train_time:81516ms step_avg:96.58ms -step:845/1695 train_time:81612ms step_avg:96.58ms -step:846/1695 train_time:81707ms step_avg:96.58ms -step:847/1695 train_time:81802ms step_avg:96.58ms -step:848/1695 train_time:81899ms step_avg:96.58ms -step:849/1695 train_time:81994ms step_avg:96.58ms -step:850/1695 train_time:82090ms step_avg:96.58ms -step:851/1695 train_time:82186ms step_avg:96.58ms -step:852/1695 train_time:82281ms step_avg:96.57ms -step:853/1695 
train_time:82377ms step_avg:96.57ms -step:854/1695 train_time:82473ms step_avg:96.57ms -step:855/1695 train_time:82569ms step_avg:96.57ms -step:856/1695 train_time:82664ms step_avg:96.57ms -step:857/1695 train_time:82760ms step_avg:96.57ms -step:858/1695 train_time:82855ms step_avg:96.57ms -step:859/1695 train_time:82952ms step_avg:96.57ms -step:860/1695 train_time:83047ms step_avg:96.57ms -step:861/1695 train_time:83143ms step_avg:96.57ms -step:862/1695 train_time:83239ms step_avg:96.56ms -step:863/1695 train_time:83566ms step_avg:96.83ms -step:864/1695 train_time:83759ms step_avg:96.94ms -step:865/1695 train_time:83853ms step_avg:96.94ms -step:866/1695 train_time:83948ms step_avg:96.94ms -step:867/1695 train_time:84042ms step_avg:96.93ms -step:868/1695 train_time:84138ms step_avg:96.93ms -step:869/1695 train_time:84233ms step_avg:96.93ms -step:870/1695 train_time:84327ms step_avg:96.93ms -step:871/1695 train_time:84421ms step_avg:96.92ms -step:872/1695 train_time:84516ms step_avg:96.92ms -step:873/1695 train_time:84616ms step_avg:96.93ms -step:874/1695 train_time:84714ms step_avg:96.93ms -step:875/1695 train_time:84811ms step_avg:96.93ms -step:875/1695 val_loss:3.5270 train_time:84905ms step_avg:97.03ms -step:876/1695 train_time:84930ms step_avg:96.95ms -step:877/1695 train_time:85012ms step_avg:96.94ms -step:878/1695 train_time:85111ms step_avg:96.94ms -step:879/1695 train_time:85209ms step_avg:96.94ms -step:880/1695 train_time:85304ms step_avg:96.94ms -step:881/1695 train_time:85400ms step_avg:96.94ms -step:882/1695 train_time:85494ms step_avg:96.93ms -step:883/1695 train_time:85589ms step_avg:96.93ms -step:884/1695 train_time:85685ms step_avg:96.93ms -step:885/1695 train_time:85780ms step_avg:96.93ms -step:886/1695 train_time:85876ms step_avg:96.93ms -step:887/1695 train_time:85973ms step_avg:96.93ms -step:888/1695 train_time:86071ms step_avg:96.93ms -step:889/1695 train_time:86170ms step_avg:96.93ms -step:890/1695 train_time:86267ms step_avg:96.93ms -step:891/1695 train_time:86364ms step_avg:96.93ms -step:892/1695 train_time:86459ms step_avg:96.93ms -step:893/1695 train_time:86554ms step_avg:96.92ms -step:894/1695 train_time:86649ms step_avg:96.92ms -step:895/1695 train_time:86745ms step_avg:96.92ms -step:896/1695 train_time:86842ms step_avg:96.92ms -step:897/1695 train_time:86938ms step_avg:96.92ms -step:898/1695 train_time:87034ms step_avg:96.92ms -step:899/1695 train_time:87131ms step_avg:96.92ms -step:900/1695 train_time:87229ms step_avg:96.92ms -step:901/1695 train_time:87325ms step_avg:96.92ms -step:902/1695 train_time:87422ms step_avg:96.92ms -step:903/1695 train_time:87517ms step_avg:96.92ms -step:904/1695 train_time:87612ms step_avg:96.92ms -step:905/1695 train_time:87708ms step_avg:96.91ms -step:906/1695 train_time:87804ms step_avg:96.91ms -step:907/1695 train_time:87901ms step_avg:96.91ms -step:908/1695 train_time:87997ms step_avg:96.91ms -step:909/1695 train_time:88092ms step_avg:96.91ms -step:910/1695 train_time:88188ms step_avg:96.91ms -step:911/1695 train_time:88285ms step_avg:96.91ms -step:912/1695 train_time:88381ms step_avg:96.91ms -step:913/1695 train_time:88477ms step_avg:96.91ms -step:914/1695 train_time:88572ms step_avg:96.91ms -step:915/1695 train_time:88667ms step_avg:96.90ms -step:916/1695 train_time:88763ms step_avg:96.90ms -step:917/1695 train_time:88860ms step_avg:96.90ms -step:918/1695 train_time:88955ms step_avg:96.90ms -step:919/1695 train_time:89050ms step_avg:96.90ms -step:920/1695 train_time:89147ms step_avg:96.90ms -step:921/1695 train_time:89245ms 
step_avg:96.90ms -step:922/1695 train_time:89341ms step_avg:96.90ms -step:923/1695 train_time:89436ms step_avg:96.90ms -step:924/1695 train_time:89531ms step_avg:96.90ms -step:925/1695 train_time:89627ms step_avg:96.89ms -step:926/1695 train_time:89724ms step_avg:96.89ms -step:927/1695 train_time:89821ms step_avg:96.89ms -step:928/1695 train_time:89916ms step_avg:96.89ms -step:929/1695 train_time:90011ms step_avg:96.89ms -step:930/1695 train_time:90108ms step_avg:96.89ms -step:931/1695 train_time:90205ms step_avg:96.89ms -step:932/1695 train_time:90302ms step_avg:96.89ms -step:933/1695 train_time:90398ms step_avg:96.89ms -step:934/1695 train_time:90493ms step_avg:96.89ms -step:935/1695 train_time:90589ms step_avg:96.89ms -step:936/1695 train_time:90685ms step_avg:96.89ms -step:937/1695 train_time:90782ms step_avg:96.89ms -step:938/1695 train_time:90877ms step_avg:96.88ms -step:939/1695 train_time:90973ms step_avg:96.88ms -step:940/1695 train_time:91069ms step_avg:96.88ms -step:941/1695 train_time:91167ms step_avg:96.88ms -step:942/1695 train_time:91264ms step_avg:96.88ms -step:943/1695 train_time:91361ms step_avg:96.88ms -step:944/1695 train_time:91456ms step_avg:96.88ms -step:945/1695 train_time:91552ms step_avg:96.88ms -step:946/1695 train_time:91648ms step_avg:96.88ms -step:947/1695 train_time:91744ms step_avg:96.88ms -step:948/1695 train_time:91840ms step_avg:96.88ms -step:949/1695 train_time:91935ms step_avg:96.88ms -step:950/1695 train_time:92031ms step_avg:96.87ms -step:951/1695 train_time:92127ms step_avg:96.87ms -step:952/1695 train_time:92223ms step_avg:96.87ms -step:953/1695 train_time:92319ms step_avg:96.87ms -step:954/1695 train_time:92415ms step_avg:96.87ms -step:955/1695 train_time:92510ms step_avg:96.87ms -step:956/1695 train_time:92606ms step_avg:96.87ms -step:957/1695 train_time:92702ms step_avg:96.87ms -step:958/1695 train_time:92798ms step_avg:96.87ms -step:959/1695 train_time:92894ms step_avg:96.87ms -step:960/1695 train_time:92990ms step_avg:96.86ms -step:961/1695 train_time:93087ms step_avg:96.86ms -step:962/1695 train_time:93183ms step_avg:96.86ms -step:963/1695 train_time:93279ms step_avg:96.86ms -step:964/1695 train_time:93374ms step_avg:96.86ms -step:965/1695 train_time:93471ms step_avg:96.86ms -step:966/1695 train_time:93568ms step_avg:96.86ms -step:967/1695 train_time:93664ms step_avg:96.86ms -step:968/1695 train_time:93760ms step_avg:96.86ms -step:969/1695 train_time:93856ms step_avg:96.86ms -step:970/1695 train_time:93951ms step_avg:96.86ms -step:971/1695 train_time:94048ms step_avg:96.86ms -step:972/1695 train_time:94144ms step_avg:96.86ms -step:973/1695 train_time:94241ms step_avg:96.86ms -step:974/1695 train_time:94336ms step_avg:96.85ms -step:975/1695 train_time:94431ms step_avg:96.85ms -step:976/1695 train_time:94528ms step_avg:96.85ms -step:977/1695 train_time:94625ms step_avg:96.85ms -step:978/1695 train_time:94722ms step_avg:96.85ms -step:979/1695 train_time:94818ms step_avg:96.85ms -step:980/1695 train_time:94913ms step_avg:96.85ms -step:981/1695 train_time:95010ms step_avg:96.85ms -step:982/1695 train_time:95106ms step_avg:96.85ms -step:983/1695 train_time:95202ms step_avg:96.85ms -step:984/1695 train_time:95298ms step_avg:96.85ms -step:985/1695 train_time:95393ms step_avg:96.85ms -step:986/1695 train_time:95489ms step_avg:96.84ms -step:987/1695 train_time:95585ms step_avg:96.84ms -step:988/1695 train_time:95682ms step_avg:96.84ms -step:989/1695 train_time:95777ms step_avg:96.84ms -step:990/1695 train_time:95872ms step_avg:96.84ms -step:991/1695 
train_time:95967ms step_avg:96.84ms -step:992/1695 train_time:96064ms step_avg:96.84ms -step:993/1695 train_time:96160ms step_avg:96.84ms -step:994/1695 train_time:96255ms step_avg:96.84ms -step:995/1695 train_time:96351ms step_avg:96.83ms -step:996/1695 train_time:96446ms step_avg:96.83ms -step:997/1695 train_time:96543ms step_avg:96.83ms -step:998/1695 train_time:96638ms step_avg:96.83ms -step:999/1695 train_time:96734ms step_avg:96.83ms -step:1000/1695 train_time:96830ms step_avg:96.83ms -step:1000/1695 val_loss:3.4844 train_time:96924ms step_avg:96.92ms -step:1001/1695 train_time:96949ms step_avg:96.85ms -step:1002/1695 train_time:97032ms step_avg:96.84ms -step:1003/1695 train_time:97130ms step_avg:96.84ms -step:1004/1695 train_time:97226ms step_avg:96.84ms -step:1005/1695 train_time:97322ms step_avg:96.84ms -step:1006/1695 train_time:97417ms step_avg:96.84ms -step:1007/1695 train_time:97512ms step_avg:96.83ms -step:1008/1695 train_time:97606ms step_avg:96.83ms -step:1009/1695 train_time:97702ms step_avg:96.83ms -step:1010/1695 train_time:97797ms step_avg:96.83ms -step:1011/1695 train_time:97893ms step_avg:96.83ms -step:1012/1695 train_time:97991ms step_avg:96.83ms -step:1013/1695 train_time:98089ms step_avg:96.83ms -step:1014/1695 train_time:98187ms step_avg:96.83ms -step:1015/1695 train_time:98284ms step_avg:96.83ms -step:1016/1695 train_time:98379ms step_avg:96.83ms -step:1017/1695 train_time:98474ms step_avg:96.83ms -step:1018/1695 train_time:98569ms step_avg:96.83ms -step:1019/1695 train_time:98665ms step_avg:96.83ms -step:1020/1695 train_time:98761ms step_avg:96.82ms -step:1021/1695 train_time:98857ms step_avg:96.82ms -step:1022/1695 train_time:98954ms step_avg:96.82ms -step:1023/1695 train_time:99050ms step_avg:96.82ms -step:1024/1695 train_time:99147ms step_avg:96.82ms -step:1025/1695 train_time:99243ms step_avg:96.82ms -step:1026/1695 train_time:99340ms step_avg:96.82ms -step:1027/1695 train_time:99436ms step_avg:96.82ms -step:1028/1695 train_time:99530ms step_avg:96.82ms -step:1029/1695 train_time:99625ms step_avg:96.82ms -step:1030/1695 train_time:99721ms step_avg:96.82ms -step:1031/1695 train_time:99818ms step_avg:96.82ms -step:1032/1695 train_time:99913ms step_avg:96.82ms -step:1033/1695 train_time:100009ms step_avg:96.81ms -step:1034/1695 train_time:100107ms step_avg:96.82ms -step:1035/1695 train_time:100204ms step_avg:96.82ms -step:1036/1695 train_time:100552ms step_avg:97.06ms -step:1037/1695 train_time:100724ms step_avg:97.13ms -step:1038/1695 train_time:100817ms step_avg:97.13ms -step:1039/1695 train_time:100912ms step_avg:97.12ms -step:1040/1695 train_time:101007ms step_avg:97.12ms -step:1041/1695 train_time:101101ms step_avg:97.12ms -step:1042/1695 train_time:101196ms step_avg:97.12ms -step:1043/1695 train_time:101290ms step_avg:97.11ms -step:1044/1695 train_time:101385ms step_avg:97.11ms -step:1045/1695 train_time:101480ms step_avg:97.11ms -step:1046/1695 train_time:101579ms step_avg:97.11ms -step:1047/1695 train_time:101680ms step_avg:97.12ms -step:1048/1695 train_time:101777ms step_avg:97.12ms -step:1049/1695 train_time:101873ms step_avg:97.11ms -step:1050/1695 train_time:101968ms step_avg:97.11ms -step:1051/1695 train_time:102063ms step_avg:97.11ms -step:1052/1695 train_time:102159ms step_avg:97.11ms -step:1053/1695 train_time:102253ms step_avg:97.11ms -step:1054/1695 train_time:102348ms step_avg:97.10ms -step:1055/1695 train_time:102443ms step_avg:97.10ms -step:1056/1695 train_time:102540ms step_avg:97.10ms -step:1057/1695 train_time:102637ms step_avg:97.10ms 
-step:1058/1695 train_time:102733ms step_avg:97.10ms -step:1059/1695 train_time:102829ms step_avg:97.10ms -step:1060/1695 train_time:102925ms step_avg:97.10ms -step:1061/1695 train_time:103021ms step_avg:97.10ms -step:1062/1695 train_time:103117ms step_avg:97.10ms -step:1063/1695 train_time:103212ms step_avg:97.09ms -step:1064/1695 train_time:103307ms step_avg:97.09ms -step:1065/1695 train_time:103402ms step_avg:97.09ms -step:1066/1695 train_time:103499ms step_avg:97.09ms -step:1067/1695 train_time:103595ms step_avg:97.09ms -step:1068/1695 train_time:103691ms step_avg:97.09ms -step:1069/1695 train_time:103789ms step_avg:97.09ms -step:1070/1695 train_time:103886ms step_avg:97.09ms -step:1071/1695 train_time:103983ms step_avg:97.09ms -step:1072/1695 train_time:104080ms step_avg:97.09ms -step:1073/1695 train_time:104176ms step_avg:97.09ms -step:1074/1695 train_time:104270ms step_avg:97.09ms -step:1075/1695 train_time:104366ms step_avg:97.08ms -step:1076/1695 train_time:104461ms step_avg:97.08ms -step:1077/1695 train_time:104558ms step_avg:97.08ms -step:1078/1695 train_time:104654ms step_avg:97.08ms -step:1079/1695 train_time:104750ms step_avg:97.08ms -step:1080/1695 train_time:104847ms step_avg:97.08ms -step:1081/1695 train_time:104946ms step_avg:97.08ms -step:1082/1695 train_time:105043ms step_avg:97.08ms -step:1083/1695 train_time:105140ms step_avg:97.08ms -step:1084/1695 train_time:105235ms step_avg:97.08ms -step:1085/1695 train_time:105330ms step_avg:97.08ms -step:1086/1695 train_time:105426ms step_avg:97.08ms -step:1087/1695 train_time:105523ms step_avg:97.08ms -step:1088/1695 train_time:105619ms step_avg:97.08ms -step:1089/1695 train_time:105715ms step_avg:97.08ms -step:1090/1695 train_time:105811ms step_avg:97.07ms -step:1091/1695 train_time:105907ms step_avg:97.07ms -step:1092/1695 train_time:106003ms step_avg:97.07ms -step:1093/1695 train_time:106099ms step_avg:97.07ms -step:1094/1695 train_time:106194ms step_avg:97.07ms -step:1095/1695 train_time:106289ms step_avg:97.07ms -step:1096/1695 train_time:106386ms step_avg:97.07ms -step:1097/1695 train_time:106482ms step_avg:97.07ms -step:1098/1695 train_time:106578ms step_avg:97.07ms -step:1099/1695 train_time:106673ms step_avg:97.06ms -step:1100/1695 train_time:106769ms step_avg:97.06ms -step:1101/1695 train_time:106866ms step_avg:97.06ms -step:1102/1695 train_time:106963ms step_avg:97.06ms -step:1103/1695 train_time:107060ms step_avg:97.06ms -step:1104/1695 train_time:107156ms step_avg:97.06ms -step:1105/1695 train_time:107251ms step_avg:97.06ms -step:1106/1695 train_time:107347ms step_avg:97.06ms -step:1107/1695 train_time:107443ms step_avg:97.06ms -step:1108/1695 train_time:107540ms step_avg:97.06ms -step:1109/1695 train_time:107636ms step_avg:97.06ms -step:1110/1695 train_time:107730ms step_avg:97.05ms -step:1111/1695 train_time:107826ms step_avg:97.05ms -step:1112/1695 train_time:107922ms step_avg:97.05ms -step:1113/1695 train_time:108019ms step_avg:97.05ms -step:1114/1695 train_time:108115ms step_avg:97.05ms -step:1115/1695 train_time:108211ms step_avg:97.05ms -step:1116/1695 train_time:108306ms step_avg:97.05ms -step:1117/1695 train_time:108403ms step_avg:97.05ms -step:1118/1695 train_time:108499ms step_avg:97.05ms -step:1119/1695 train_time:108594ms step_avg:97.05ms -step:1120/1695 train_time:108690ms step_avg:97.04ms -step:1121/1695 train_time:108786ms step_avg:97.04ms -step:1122/1695 train_time:108883ms step_avg:97.04ms -step:1123/1695 train_time:108979ms step_avg:97.04ms -step:1124/1695 train_time:109075ms step_avg:97.04ms 
-step:1125/1695 train_time:109170ms step_avg:97.04ms -step:1125/1695 val_loss:3.4368 train_time:109264ms step_avg:97.12ms -step:1126/1695 train_time:109288ms step_avg:97.06ms -step:1127/1695 train_time:109371ms step_avg:97.05ms -step:1128/1695 train_time:109469ms step_avg:97.05ms -step:1129/1695 train_time:109566ms step_avg:97.05ms -step:1130/1695 train_time:109662ms step_avg:97.05ms -step:1131/1695 train_time:109757ms step_avg:97.04ms -step:1132/1695 train_time:109852ms step_avg:97.04ms -step:1133/1695 train_time:109950ms step_avg:97.04ms -step:1134/1695 train_time:110047ms step_avg:97.04ms -step:1135/1695 train_time:110144ms step_avg:97.04ms -step:1136/1695 train_time:110243ms step_avg:97.04ms -step:1137/1695 train_time:110343ms step_avg:97.05ms -step:1138/1695 train_time:110442ms step_avg:97.05ms -step:1139/1695 train_time:110539ms step_avg:97.05ms -step:1140/1695 train_time:110636ms step_avg:97.05ms -step:1141/1695 train_time:110733ms step_avg:97.05ms -step:1142/1695 train_time:110829ms step_avg:97.05ms -step:1143/1695 train_time:110927ms step_avg:97.05ms -step:1144/1695 train_time:111024ms step_avg:97.05ms -step:1145/1695 train_time:111121ms step_avg:97.05ms -step:1146/1695 train_time:111220ms step_avg:97.05ms -step:1147/1695 train_time:111319ms step_avg:97.05ms -step:1148/1695 train_time:111417ms step_avg:97.05ms -step:1149/1695 train_time:111515ms step_avg:97.05ms -step:1150/1695 train_time:111613ms step_avg:97.05ms -step:1151/1695 train_time:111710ms step_avg:97.05ms -step:1152/1695 train_time:111807ms step_avg:97.05ms -step:1153/1695 train_time:111903ms step_avg:97.05ms -step:1154/1695 train_time:112000ms step_avg:97.05ms -step:1155/1695 train_time:112096ms step_avg:97.05ms -step:1156/1695 train_time:112194ms step_avg:97.05ms -step:1157/1695 train_time:112292ms step_avg:97.05ms -step:1158/1695 train_time:112392ms step_avg:97.06ms -step:1159/1695 train_time:112491ms step_avg:97.06ms -step:1160/1695 train_time:112592ms step_avg:97.06ms -step:1161/1695 train_time:112691ms step_avg:97.06ms -step:1162/1695 train_time:112789ms step_avg:97.06ms -step:1163/1695 train_time:112885ms step_avg:97.06ms -step:1164/1695 train_time:112983ms step_avg:97.06ms -step:1165/1695 train_time:113080ms step_avg:97.06ms -step:1166/1695 train_time:113177ms step_avg:97.06ms -step:1167/1695 train_time:113274ms step_avg:97.06ms -step:1168/1695 train_time:113372ms step_avg:97.06ms -step:1169/1695 train_time:113471ms step_avg:97.07ms -step:1170/1695 train_time:113571ms step_avg:97.07ms -step:1171/1695 train_time:113670ms step_avg:97.07ms -step:1172/1695 train_time:113769ms step_avg:97.07ms -step:1173/1695 train_time:113866ms step_avg:97.07ms -step:1174/1695 train_time:113964ms step_avg:97.07ms -step:1175/1695 train_time:114063ms step_avg:97.08ms -step:1176/1695 train_time:114161ms step_avg:97.08ms -step:1177/1695 train_time:114259ms step_avg:97.08ms -step:1178/1695 train_time:114356ms step_avg:97.08ms -step:1179/1695 train_time:114453ms step_avg:97.08ms -step:1180/1695 train_time:114551ms step_avg:97.08ms -step:1181/1695 train_time:114649ms step_avg:97.08ms -step:1182/1695 train_time:114746ms step_avg:97.08ms -step:1183/1695 train_time:114844ms step_avg:97.08ms -step:1184/1695 train_time:114942ms step_avg:97.08ms -step:1185/1695 train_time:115039ms step_avg:97.08ms -step:1186/1695 train_time:115136ms step_avg:97.08ms -step:1187/1695 train_time:115233ms step_avg:97.08ms -step:1188/1695 train_time:115331ms step_avg:97.08ms -step:1189/1695 train_time:115429ms step_avg:97.08ms -step:1190/1695 train_time:115527ms 
step_avg:97.08ms -step:1191/1695 train_time:115625ms step_avg:97.08ms -step:1192/1695 train_time:115723ms step_avg:97.08ms -step:1193/1695 train_time:115821ms step_avg:97.08ms -step:1194/1695 train_time:115918ms step_avg:97.08ms -step:1195/1695 train_time:116016ms step_avg:97.08ms -step:1196/1695 train_time:116114ms step_avg:97.09ms -step:1197/1695 train_time:116212ms step_avg:97.09ms -step:1198/1695 train_time:116310ms step_avg:97.09ms -step:1199/1695 train_time:116409ms step_avg:97.09ms -step:1200/1695 train_time:116509ms step_avg:97.09ms -step:1201/1695 train_time:116609ms step_avg:97.09ms -step:1202/1695 train_time:116708ms step_avg:97.09ms -step:1203/1695 train_time:116808ms step_avg:97.10ms -step:1204/1695 train_time:116906ms step_avg:97.10ms -step:1205/1695 train_time:117005ms step_avg:97.10ms -step:1206/1695 train_time:117103ms step_avg:97.10ms -step:1207/1695 train_time:117201ms step_avg:97.10ms -step:1208/1695 train_time:117548ms step_avg:97.31ms -step:1209/1695 train_time:117728ms step_avg:97.38ms -step:1210/1695 train_time:117823ms step_avg:97.37ms -step:1211/1695 train_time:117920ms step_avg:97.37ms -step:1212/1695 train_time:118016ms step_avg:97.37ms -step:1213/1695 train_time:118112ms step_avg:97.37ms -step:1214/1695 train_time:118209ms step_avg:97.37ms -step:1215/1695 train_time:118306ms step_avg:97.37ms -step:1216/1695 train_time:118402ms step_avg:97.37ms -step:1217/1695 train_time:118500ms step_avg:97.37ms -step:1218/1695 train_time:118604ms step_avg:97.38ms -step:1219/1695 train_time:118704ms step_avg:97.38ms -step:1220/1695 train_time:118801ms step_avg:97.38ms -step:1221/1695 train_time:118897ms step_avg:97.38ms -step:1222/1695 train_time:118994ms step_avg:97.38ms -step:1223/1695 train_time:119090ms step_avg:97.38ms -step:1224/1695 train_time:119187ms step_avg:97.38ms -step:1225/1695 train_time:119285ms step_avg:97.38ms -step:1226/1695 train_time:119382ms step_avg:97.38ms -step:1227/1695 train_time:119480ms step_avg:97.38ms -step:1228/1695 train_time:119579ms step_avg:97.38ms -step:1229/1695 train_time:119678ms step_avg:97.38ms -step:1230/1695 train_time:119776ms step_avg:97.38ms -step:1231/1695 train_time:119874ms step_avg:97.38ms -step:1232/1695 train_time:119971ms step_avg:97.38ms -step:1233/1695 train_time:120068ms step_avg:97.38ms -step:1234/1695 train_time:120166ms step_avg:97.38ms -step:1235/1695 train_time:120263ms step_avg:97.38ms -step:1236/1695 train_time:120360ms step_avg:97.38ms -step:1237/1695 train_time:120457ms step_avg:97.38ms -step:1238/1695 train_time:120555ms step_avg:97.38ms -step:1239/1695 train_time:120654ms step_avg:97.38ms -step:1240/1695 train_time:120752ms step_avg:97.38ms -step:1241/1695 train_time:120851ms step_avg:97.38ms -step:1242/1695 train_time:120950ms step_avg:97.38ms -step:1243/1695 train_time:121048ms step_avg:97.38ms -step:1244/1695 train_time:121145ms step_avg:97.38ms -step:1245/1695 train_time:121243ms step_avg:97.38ms -step:1246/1695 train_time:121340ms step_avg:97.38ms -step:1247/1695 train_time:121437ms step_avg:97.38ms -step:1248/1695 train_time:121534ms step_avg:97.38ms -step:1249/1695 train_time:121632ms step_avg:97.38ms -step:1250/1695 train_time:121731ms step_avg:97.38ms -step:1250/1695 val_loss:3.3897 train_time:121827ms step_avg:97.46ms -step:1251/1695 train_time:121854ms step_avg:97.40ms -step:1252/1695 train_time:121931ms step_avg:97.39ms -step:1253/1695 train_time:122027ms step_avg:97.39ms -step:1254/1695 train_time:122123ms step_avg:97.39ms -step:1255/1695 train_time:122220ms step_avg:97.39ms -step:1256/1695 
train_time:122317ms step_avg:97.39ms -step:1257/1695 train_time:122414ms step_avg:97.39ms -step:1258/1695 train_time:122510ms step_avg:97.38ms -step:1259/1695 train_time:122606ms step_avg:97.38ms -step:1260/1695 train_time:122702ms step_avg:97.38ms -step:1261/1695 train_time:122805ms step_avg:97.39ms -step:1262/1695 train_time:122904ms step_avg:97.39ms -step:1263/1695 train_time:123002ms step_avg:97.39ms -step:1264/1695 train_time:123099ms step_avg:97.39ms -step:1265/1695 train_time:123196ms step_avg:97.39ms -step:1266/1695 train_time:123292ms step_avg:97.39ms -step:1267/1695 train_time:123390ms step_avg:97.39ms -step:1268/1695 train_time:123486ms step_avg:97.39ms -step:1269/1695 train_time:123583ms step_avg:97.39ms -step:1270/1695 train_time:123681ms step_avg:97.39ms -step:1271/1695 train_time:123780ms step_avg:97.39ms -step:1272/1695 train_time:123879ms step_avg:97.39ms -step:1273/1695 train_time:123978ms step_avg:97.39ms -step:1274/1695 train_time:124078ms step_avg:97.39ms -step:1275/1695 train_time:124176ms step_avg:97.39ms -step:1276/1695 train_time:124275ms step_avg:97.39ms -step:1277/1695 train_time:124372ms step_avg:97.39ms -step:1278/1695 train_time:124470ms step_avg:97.39ms -step:1279/1695 train_time:124567ms step_avg:97.39ms -step:1280/1695 train_time:124664ms step_avg:97.39ms -step:1281/1695 train_time:124761ms step_avg:97.39ms -step:1282/1695 train_time:124859ms step_avg:97.39ms -step:1283/1695 train_time:124959ms step_avg:97.40ms -step:1284/1695 train_time:125057ms step_avg:97.40ms -step:1285/1695 train_time:125156ms step_avg:97.40ms -step:1286/1695 train_time:125254ms step_avg:97.40ms -step:1287/1695 train_time:125351ms step_avg:97.40ms -step:1288/1695 train_time:125449ms step_avg:97.40ms -step:1289/1695 train_time:125546ms step_avg:97.40ms -step:1290/1695 train_time:125644ms step_avg:97.40ms -step:1291/1695 train_time:125741ms step_avg:97.40ms -step:1292/1695 train_time:125839ms step_avg:97.40ms -step:1293/1695 train_time:125938ms step_avg:97.40ms -step:1294/1695 train_time:126037ms step_avg:97.40ms -step:1295/1695 train_time:126136ms step_avg:97.40ms -step:1296/1695 train_time:126234ms step_avg:97.40ms -step:1297/1695 train_time:126333ms step_avg:97.40ms -step:1298/1695 train_time:126431ms step_avg:97.40ms -step:1299/1695 train_time:126529ms step_avg:97.41ms -step:1300/1695 train_time:126628ms step_avg:97.41ms -step:1301/1695 train_time:126726ms step_avg:97.41ms -step:1302/1695 train_time:126823ms step_avg:97.41ms -step:1303/1695 train_time:126921ms step_avg:97.41ms -step:1304/1695 train_time:127018ms step_avg:97.41ms -step:1305/1695 train_time:127117ms step_avg:97.41ms -step:1306/1695 train_time:127216ms step_avg:97.41ms -step:1307/1695 train_time:127314ms step_avg:97.41ms -step:1308/1695 train_time:127412ms step_avg:97.41ms -step:1309/1695 train_time:127509ms step_avg:97.41ms -step:1310/1695 train_time:127608ms step_avg:97.41ms -step:1311/1695 train_time:127705ms step_avg:97.41ms -step:1312/1695 train_time:127802ms step_avg:97.41ms -step:1313/1695 train_time:127899ms step_avg:97.41ms -step:1314/1695 train_time:127996ms step_avg:97.41ms -step:1315/1695 train_time:128095ms step_avg:97.41ms -step:1316/1695 train_time:128193ms step_avg:97.41ms -step:1317/1695 train_time:128291ms step_avg:97.41ms -step:1318/1695 train_time:128389ms step_avg:97.41ms -step:1319/1695 train_time:128485ms step_avg:97.41ms -step:1320/1695 train_time:128582ms step_avg:97.41ms -step:1321/1695 train_time:128680ms step_avg:97.41ms -step:1322/1695 train_time:128778ms step_avg:97.41ms -step:1323/1695 
train_time:128876ms step_avg:97.41ms -step:1324/1695 train_time:128974ms step_avg:97.41ms -step:1325/1695 train_time:129072ms step_avg:97.41ms -step:1326/1695 train_time:129170ms step_avg:97.41ms -step:1327/1695 train_time:129268ms step_avg:97.41ms -step:1328/1695 train_time:129366ms step_avg:97.41ms -step:1329/1695 train_time:129463ms step_avg:97.41ms -step:1330/1695 train_time:129561ms step_avg:97.41ms -step:1331/1695 train_time:129659ms step_avg:97.41ms -step:1332/1695 train_time:129758ms step_avg:97.42ms -step:1333/1695 train_time:129857ms step_avg:97.42ms -step:1334/1695 train_time:129955ms step_avg:97.42ms -step:1335/1695 train_time:130053ms step_avg:97.42ms -step:1336/1695 train_time:130151ms step_avg:97.42ms -step:1337/1695 train_time:130248ms step_avg:97.42ms -step:1338/1695 train_time:130347ms step_avg:97.42ms -step:1339/1695 train_time:130444ms step_avg:97.42ms -step:1340/1695 train_time:130541ms step_avg:97.42ms -step:1341/1695 train_time:130639ms step_avg:97.42ms -step:1342/1695 train_time:130736ms step_avg:97.42ms -step:1343/1695 train_time:130835ms step_avg:97.42ms -step:1344/1695 train_time:130933ms step_avg:97.42ms -step:1345/1695 train_time:131030ms step_avg:97.42ms -step:1346/1695 train_time:131127ms step_avg:97.42ms -step:1347/1695 train_time:131224ms step_avg:97.42ms -step:1348/1695 train_time:131321ms step_avg:97.42ms -step:1349/1695 train_time:131419ms step_avg:97.42ms -step:1350/1695 train_time:131518ms step_avg:97.42ms -step:1351/1695 train_time:131615ms step_avg:97.42ms -step:1352/1695 train_time:131714ms step_avg:97.42ms -step:1353/1695 train_time:131813ms step_avg:97.42ms -step:1354/1695 train_time:131911ms step_avg:97.42ms -step:1355/1695 train_time:132009ms step_avg:97.42ms -step:1356/1695 train_time:132106ms step_avg:97.42ms -step:1357/1695 train_time:132203ms step_avg:97.42ms -step:1358/1695 train_time:132300ms step_avg:97.42ms -step:1359/1695 train_time:132398ms step_avg:97.42ms -step:1360/1695 train_time:132497ms step_avg:97.42ms -step:1361/1695 train_time:132595ms step_avg:97.42ms -step:1362/1695 train_time:132693ms step_avg:97.43ms -step:1363/1695 train_time:132792ms step_avg:97.43ms -step:1364/1695 train_time:132890ms step_avg:97.43ms -step:1365/1695 train_time:132988ms step_avg:97.43ms -step:1366/1695 train_time:133085ms step_avg:97.43ms -step:1367/1695 train_time:133182ms step_avg:97.43ms -step:1368/1695 train_time:133279ms step_avg:97.43ms -step:1369/1695 train_time:133377ms step_avg:97.43ms -step:1370/1695 train_time:133476ms step_avg:97.43ms -step:1371/1695 train_time:133574ms step_avg:97.43ms -step:1372/1695 train_time:133671ms step_avg:97.43ms -step:1373/1695 train_time:133769ms step_avg:97.43ms -step:1374/1695 train_time:133867ms step_avg:97.43ms -step:1375/1695 train_time:133964ms step_avg:97.43ms -step:1375/1695 val_loss:3.3507 train_time:134060ms step_avg:97.50ms -step:1376/1695 train_time:134085ms step_avg:97.45ms -step:1377/1695 train_time:134167ms step_avg:97.43ms -step:1378/1695 train_time:134266ms step_avg:97.44ms -step:1379/1695 train_time:134364ms step_avg:97.44ms -step:1380/1695 train_time:134461ms step_avg:97.44ms -step:1381/1695 train_time:134815ms step_avg:97.62ms -step:1382/1695 train_time:134984ms step_avg:97.67ms -step:1383/1695 train_time:135080ms step_avg:97.67ms -step:1384/1695 train_time:135176ms step_avg:97.67ms -step:1385/1695 train_time:135272ms step_avg:97.67ms -step:1386/1695 train_time:135369ms step_avg:97.67ms -step:1387/1695 train_time:135465ms step_avg:97.67ms -step:1388/1695 train_time:135562ms step_avg:97.67ms 
-step:1389/1695 train_time:135658ms step_avg:97.67ms -step:1390/1695 train_time:135756ms step_avg:97.67ms -step:1391/1695 train_time:135859ms step_avg:97.67ms -step:1392/1695 train_time:135961ms step_avg:97.67ms -step:1393/1695 train_time:136060ms step_avg:97.67ms -step:1394/1695 train_time:136156ms step_avg:97.67ms -step:1395/1695 train_time:136253ms step_avg:97.67ms -step:1396/1695 train_time:136350ms step_avg:97.67ms -step:1397/1695 train_time:136446ms step_avg:97.67ms -step:1398/1695 train_time:136542ms step_avg:97.67ms -step:1399/1695 train_time:136639ms step_avg:97.67ms -step:1400/1695 train_time:136736ms step_avg:97.67ms -step:1401/1695 train_time:136834ms step_avg:97.67ms -step:1402/1695 train_time:136933ms step_avg:97.67ms -step:1403/1695 train_time:137032ms step_avg:97.67ms -step:1404/1695 train_time:137131ms step_avg:97.67ms -step:1405/1695 train_time:137230ms step_avg:97.67ms -step:1406/1695 train_time:137328ms step_avg:97.67ms -step:1407/1695 train_time:137425ms step_avg:97.67ms -step:1408/1695 train_time:137522ms step_avg:97.67ms -step:1409/1695 train_time:137619ms step_avg:97.67ms -step:1410/1695 train_time:137716ms step_avg:97.67ms -step:1411/1695 train_time:137813ms step_avg:97.67ms -step:1412/1695 train_time:137912ms step_avg:97.67ms -step:1413/1695 train_time:138011ms step_avg:97.67ms -step:1414/1695 train_time:138111ms step_avg:97.67ms -step:1415/1695 train_time:138210ms step_avg:97.67ms -step:1416/1695 train_time:138308ms step_avg:97.67ms -step:1417/1695 train_time:138405ms step_avg:97.67ms -step:1418/1695 train_time:138502ms step_avg:97.67ms -step:1419/1695 train_time:138600ms step_avg:97.67ms -step:1420/1695 train_time:138698ms step_avg:97.67ms -step:1421/1695 train_time:138795ms step_avg:97.67ms -step:1422/1695 train_time:138892ms step_avg:97.67ms -step:1423/1695 train_time:138991ms step_avg:97.67ms -step:1424/1695 train_time:139089ms step_avg:97.67ms -step:1425/1695 train_time:139188ms step_avg:97.68ms -step:1426/1695 train_time:139285ms step_avg:97.68ms -step:1427/1695 train_time:139382ms step_avg:97.68ms -step:1428/1695 train_time:139480ms step_avg:97.68ms -step:1429/1695 train_time:139577ms step_avg:97.67ms -step:1430/1695 train_time:139673ms step_avg:97.67ms -step:1431/1695 train_time:139771ms step_avg:97.67ms -step:1432/1695 train_time:139869ms step_avg:97.67ms -step:1433/1695 train_time:139968ms step_avg:97.67ms -step:1434/1695 train_time:140067ms step_avg:97.68ms -step:1435/1695 train_time:140167ms step_avg:97.68ms -step:1436/1695 train_time:140266ms step_avg:97.68ms -step:1437/1695 train_time:140363ms step_avg:97.68ms -step:1438/1695 train_time:140461ms step_avg:97.68ms -step:1439/1695 train_time:140558ms step_avg:97.68ms -step:1440/1695 train_time:140655ms step_avg:97.68ms -step:1441/1695 train_time:140752ms step_avg:97.68ms -step:1442/1695 train_time:140849ms step_avg:97.68ms -step:1443/1695 train_time:140948ms step_avg:97.68ms -step:1444/1695 train_time:141045ms step_avg:97.68ms -step:1445/1695 train_time:141144ms step_avg:97.68ms -step:1446/1695 train_time:141242ms step_avg:97.68ms -step:1447/1695 train_time:141340ms step_avg:97.68ms -step:1448/1695 train_time:141438ms step_avg:97.68ms -step:1449/1695 train_time:141534ms step_avg:97.68ms -step:1450/1695 train_time:141631ms step_avg:97.68ms -step:1451/1695 train_time:141728ms step_avg:97.68ms -step:1452/1695 train_time:141826ms step_avg:97.68ms -step:1453/1695 train_time:141924ms step_avg:97.68ms -step:1454/1695 train_time:142021ms step_avg:97.68ms -step:1455/1695 train_time:142117ms step_avg:97.68ms 
-step:1456/1695 train_time:142216ms step_avg:97.68ms -step:1457/1695 train_time:142314ms step_avg:97.68ms -step:1458/1695 train_time:142413ms step_avg:97.68ms -step:1459/1695 train_time:142510ms step_avg:97.68ms -step:1460/1695 train_time:142608ms step_avg:97.68ms -step:1461/1695 train_time:142706ms step_avg:97.68ms -step:1462/1695 train_time:142803ms step_avg:97.68ms -step:1463/1695 train_time:142901ms step_avg:97.68ms -step:1464/1695 train_time:142999ms step_avg:97.68ms -step:1465/1695 train_time:143096ms step_avg:97.68ms -step:1466/1695 train_time:143194ms step_avg:97.68ms -step:1467/1695 train_time:143291ms step_avg:97.68ms -step:1468/1695 train_time:143389ms step_avg:97.68ms -step:1469/1695 train_time:143487ms step_avg:97.68ms -step:1470/1695 train_time:143585ms step_avg:97.68ms -step:1471/1695 train_time:143682ms step_avg:97.68ms -step:1472/1695 train_time:143779ms step_avg:97.68ms -step:1473/1695 train_time:143877ms step_avg:97.68ms -step:1474/1695 train_time:143974ms step_avg:97.68ms -step:1475/1695 train_time:144072ms step_avg:97.68ms -step:1476/1695 train_time:144169ms step_avg:97.68ms -step:1477/1695 train_time:144267ms step_avg:97.68ms -step:1478/1695 train_time:144365ms step_avg:97.68ms -step:1479/1695 train_time:144462ms step_avg:97.68ms -step:1480/1695 train_time:144559ms step_avg:97.68ms -step:1481/1695 train_time:144657ms step_avg:97.67ms -step:1482/1695 train_time:144754ms step_avg:97.67ms -step:1483/1695 train_time:144852ms step_avg:97.67ms -step:1484/1695 train_time:144949ms step_avg:97.67ms -step:1485/1695 train_time:145048ms step_avg:97.68ms -step:1486/1695 train_time:145146ms step_avg:97.68ms -step:1487/1695 train_time:145244ms step_avg:97.68ms -step:1488/1695 train_time:145341ms step_avg:97.68ms -step:1489/1695 train_time:145438ms step_avg:97.67ms -step:1490/1695 train_time:145535ms step_avg:97.67ms -step:1491/1695 train_time:145632ms step_avg:97.67ms -step:1492/1695 train_time:145730ms step_avg:97.67ms -step:1493/1695 train_time:145829ms step_avg:97.68ms -step:1494/1695 train_time:145927ms step_avg:97.68ms -step:1495/1695 train_time:146024ms step_avg:97.68ms -step:1496/1695 train_time:146123ms step_avg:97.68ms -step:1497/1695 train_time:146220ms step_avg:97.68ms -step:1498/1695 train_time:146317ms step_avg:97.67ms -step:1499/1695 train_time:146414ms step_avg:97.67ms -step:1500/1695 train_time:146512ms step_avg:97.67ms -step:1500/1695 val_loss:3.3178 train_time:146608ms step_avg:97.74ms -step:1501/1695 train_time:146633ms step_avg:97.69ms -step:1502/1695 train_time:146718ms step_avg:97.68ms -step:1503/1695 train_time:146818ms step_avg:97.68ms -step:1504/1695 train_time:146916ms step_avg:97.68ms -step:1505/1695 train_time:147013ms step_avg:97.68ms -step:1506/1695 train_time:147110ms step_avg:97.68ms -step:1507/1695 train_time:147206ms step_avg:97.68ms -step:1508/1695 train_time:147302ms step_avg:97.68ms -step:1509/1695 train_time:147399ms step_avg:97.68ms -step:1510/1695 train_time:147495ms step_avg:97.68ms -step:1511/1695 train_time:147595ms step_avg:97.68ms -step:1512/1695 train_time:147697ms step_avg:97.68ms -step:1513/1695 train_time:147797ms step_avg:97.68ms -step:1514/1695 train_time:147896ms step_avg:97.69ms -step:1515/1695 train_time:147994ms step_avg:97.69ms -step:1516/1695 train_time:148092ms step_avg:97.69ms -step:1517/1695 train_time:148190ms step_avg:97.69ms -step:1518/1695 train_time:148287ms step_avg:97.69ms -step:1519/1695 train_time:148384ms step_avg:97.69ms -step:1520/1695 train_time:148481ms step_avg:97.68ms -step:1521/1695 train_time:148578ms 
step_avg:97.68ms -step:1522/1695 train_time:148676ms step_avg:97.68ms -step:1523/1695 train_time:148776ms step_avg:97.69ms -step:1524/1695 train_time:148874ms step_avg:97.69ms -step:1525/1695 train_time:148973ms step_avg:97.69ms -step:1526/1695 train_time:149072ms step_avg:97.69ms -step:1527/1695 train_time:149169ms step_avg:97.69ms -step:1528/1695 train_time:149266ms step_avg:97.69ms -step:1529/1695 train_time:149362ms step_avg:97.69ms -step:1530/1695 train_time:149460ms step_avg:97.69ms -step:1531/1695 train_time:149557ms step_avg:97.69ms -step:1532/1695 train_time:149655ms step_avg:97.69ms -step:1533/1695 train_time:149753ms step_avg:97.69ms -step:1534/1695 train_time:149852ms step_avg:97.69ms -step:1535/1695 train_time:149950ms step_avg:97.69ms -step:1536/1695 train_time:150048ms step_avg:97.69ms -step:1537/1695 train_time:150146ms step_avg:97.69ms -step:1538/1695 train_time:150244ms step_avg:97.69ms -step:1539/1695 train_time:150341ms step_avg:97.69ms -step:1540/1695 train_time:150438ms step_avg:97.69ms -step:1541/1695 train_time:150535ms step_avg:97.69ms -step:1542/1695 train_time:150633ms step_avg:97.69ms -step:1543/1695 train_time:150731ms step_avg:97.69ms -step:1544/1695 train_time:150830ms step_avg:97.69ms -step:1545/1695 train_time:150928ms step_avg:97.69ms -step:1546/1695 train_time:151027ms step_avg:97.69ms -step:1547/1695 train_time:151123ms step_avg:97.69ms -step:1548/1695 train_time:151220ms step_avg:97.69ms -step:1549/1695 train_time:151317ms step_avg:97.69ms -step:1550/1695 train_time:151415ms step_avg:97.69ms -step:1551/1695 train_time:151513ms step_avg:97.69ms -step:1552/1695 train_time:151866ms step_avg:97.85ms -step:1553/1695 train_time:152044ms step_avg:97.90ms -step:1554/1695 train_time:152139ms step_avg:97.90ms -step:1555/1695 train_time:152235ms step_avg:97.90ms -step:1556/1695 train_time:152332ms step_avg:97.90ms -step:1557/1695 train_time:152428ms step_avg:97.90ms -step:1558/1695 train_time:152525ms step_avg:97.90ms -step:1559/1695 train_time:152621ms step_avg:97.90ms -step:1560/1695 train_time:152717ms step_avg:97.90ms -step:1561/1695 train_time:152815ms step_avg:97.90ms -step:1562/1695 train_time:152920ms step_avg:97.90ms -step:1563/1695 train_time:153021ms step_avg:97.90ms -step:1564/1695 train_time:153120ms step_avg:97.90ms -step:1565/1695 train_time:153218ms step_avg:97.90ms -step:1566/1695 train_time:153315ms step_avg:97.90ms -step:1567/1695 train_time:153412ms step_avg:97.90ms -step:1568/1695 train_time:153509ms step_avg:97.90ms -step:1569/1695 train_time:153606ms step_avg:97.90ms -step:1570/1695 train_time:153702ms step_avg:97.90ms -step:1571/1695 train_time:153798ms step_avg:97.90ms -step:1572/1695 train_time:153898ms step_avg:97.90ms -step:1573/1695 train_time:153999ms step_avg:97.90ms -step:1574/1695 train_time:154099ms step_avg:97.90ms -step:1575/1695 train_time:154197ms step_avg:97.90ms -step:1576/1695 train_time:154294ms step_avg:97.90ms -step:1577/1695 train_time:154393ms step_avg:97.90ms -step:1578/1695 train_time:154490ms step_avg:97.90ms -step:1579/1695 train_time:154587ms step_avg:97.90ms -step:1580/1695 train_time:154684ms step_avg:97.90ms -step:1581/1695 train_time:154781ms step_avg:97.90ms -step:1582/1695 train_time:154879ms step_avg:97.90ms -step:1583/1695 train_time:154977ms step_avg:97.90ms -step:1584/1695 train_time:155075ms step_avg:97.90ms -step:1585/1695 train_time:155174ms step_avg:97.90ms -step:1586/1695 train_time:155273ms step_avg:97.90ms -step:1587/1695 train_time:155371ms step_avg:97.90ms -step:1588/1695 train_time:155469ms 
step_avg:97.90ms -step:1589/1695 train_time:155566ms step_avg:97.90ms -step:1590/1695 train_time:155663ms step_avg:97.90ms -step:1591/1695 train_time:155760ms step_avg:97.90ms -step:1592/1695 train_time:155858ms step_avg:97.90ms -step:1593/1695 train_time:155956ms step_avg:97.90ms -step:1594/1695 train_time:156054ms step_avg:97.90ms -step:1595/1695 train_time:156152ms step_avg:97.90ms -step:1596/1695 train_time:156250ms step_avg:97.90ms -step:1597/1695 train_time:156350ms step_avg:97.90ms -step:1598/1695 train_time:156448ms step_avg:97.90ms -step:1599/1695 train_time:156546ms step_avg:97.90ms -step:1600/1695 train_time:156644ms step_avg:97.90ms -step:1601/1695 train_time:156741ms step_avg:97.90ms -step:1602/1695 train_time:156838ms step_avg:97.90ms -step:1603/1695 train_time:156936ms step_avg:97.90ms -step:1604/1695 train_time:157034ms step_avg:97.90ms -step:1605/1695 train_time:157133ms step_avg:97.90ms -step:1606/1695 train_time:157234ms step_avg:97.90ms -step:1607/1695 train_time:157333ms step_avg:97.90ms -step:1608/1695 train_time:157431ms step_avg:97.91ms -step:1609/1695 train_time:157529ms step_avg:97.91ms -step:1610/1695 train_time:157627ms step_avg:97.91ms -step:1611/1695 train_time:157726ms step_avg:97.91ms -step:1612/1695 train_time:157824ms step_avg:97.91ms -step:1613/1695 train_time:157921ms step_avg:97.91ms -step:1614/1695 train_time:158017ms step_avg:97.90ms -step:1615/1695 train_time:158114ms step_avg:97.90ms -step:1616/1695 train_time:158212ms step_avg:97.90ms -step:1617/1695 train_time:158312ms step_avg:97.90ms -step:1618/1695 train_time:158412ms step_avg:97.91ms -step:1619/1695 train_time:158510ms step_avg:97.91ms -step:1620/1695 train_time:158609ms step_avg:97.91ms -step:1621/1695 train_time:158708ms step_avg:97.91ms -step:1622/1695 train_time:158806ms step_avg:97.91ms -step:1623/1695 train_time:158905ms step_avg:97.91ms -step:1624/1695 train_time:159001ms step_avg:97.91ms -step:1625/1695 train_time:159097ms step_avg:97.91ms -step:1625/1695 val_loss:3.2907 train_time:159193ms step_avg:97.96ms -step:1626/1695 train_time:159217ms step_avg:97.92ms -step:1627/1695 train_time:159299ms step_avg:97.91ms -step:1628/1695 train_time:159398ms step_avg:97.91ms -step:1629/1695 train_time:159495ms step_avg:97.91ms -step:1630/1695 train_time:159593ms step_avg:97.91ms -step:1631/1695 train_time:159690ms step_avg:97.91ms -step:1632/1695 train_time:159787ms step_avg:97.91ms -step:1633/1695 train_time:159884ms step_avg:97.91ms -step:1634/1695 train_time:159981ms step_avg:97.91ms -step:1635/1695 train_time:160077ms step_avg:97.91ms -step:1636/1695 train_time:160176ms step_avg:97.91ms -step:1637/1695 train_time:160276ms step_avg:97.91ms -step:1638/1695 train_time:160375ms step_avg:97.91ms -step:1639/1695 train_time:160474ms step_avg:97.91ms -step:1640/1695 train_time:160571ms step_avg:97.91ms -step:1641/1695 train_time:160669ms step_avg:97.91ms -step:1642/1695 train_time:160766ms step_avg:97.91ms -step:1643/1695 train_time:160864ms step_avg:97.91ms -step:1644/1695 train_time:160961ms step_avg:97.91ms -step:1645/1695 train_time:161058ms step_avg:97.91ms -step:1646/1695 train_time:161157ms step_avg:97.91ms -step:1647/1695 train_time:161255ms step_avg:97.91ms -step:1648/1695 train_time:161353ms step_avg:97.91ms -step:1649/1695 train_time:161452ms step_avg:97.91ms -step:1650/1695 train_time:161551ms step_avg:97.91ms -step:1651/1695 train_time:161649ms step_avg:97.91ms -step:1652/1695 train_time:161746ms step_avg:97.91ms -step:1653/1695 train_time:161843ms step_avg:97.91ms -step:1654/1695 
train_time:161940ms step_avg:97.91ms -step:1655/1695 train_time:162037ms step_avg:97.91ms -step:1656/1695 train_time:162135ms step_avg:97.91ms -step:1657/1695 train_time:162233ms step_avg:97.91ms -step:1658/1695 train_time:162332ms step_avg:97.91ms -step:1659/1695 train_time:162430ms step_avg:97.91ms -step:1660/1695 train_time:162528ms step_avg:97.91ms -step:1661/1695 train_time:162627ms step_avg:97.91ms -step:1662/1695 train_time:162724ms step_avg:97.91ms -step:1663/1695 train_time:162820ms step_avg:97.91ms -step:1664/1695 train_time:162918ms step_avg:97.91ms -step:1665/1695 train_time:163015ms step_avg:97.91ms -step:1666/1695 train_time:163114ms step_avg:97.91ms -step:1667/1695 train_time:163211ms step_avg:97.91ms -step:1668/1695 train_time:163309ms step_avg:97.91ms -step:1669/1695 train_time:163407ms step_avg:97.91ms -step:1670/1695 train_time:163505ms step_avg:97.91ms -step:1671/1695 train_time:163603ms step_avg:97.91ms -step:1672/1695 train_time:163700ms step_avg:97.91ms -step:1673/1695 train_time:163797ms step_avg:97.91ms -step:1674/1695 train_time:163894ms step_avg:97.91ms -step:1675/1695 train_time:163992ms step_avg:97.91ms -step:1676/1695 train_time:164091ms step_avg:97.91ms -step:1677/1695 train_time:164189ms step_avg:97.91ms -step:1678/1695 train_time:164287ms step_avg:97.91ms -step:1679/1695 train_time:164385ms step_avg:97.91ms -step:1680/1695 train_time:164482ms step_avg:97.91ms -step:1681/1695 train_time:164580ms step_avg:97.91ms -step:1682/1695 train_time:164677ms step_avg:97.91ms -step:1683/1695 train_time:164775ms step_avg:97.91ms -step:1684/1695 train_time:164873ms step_avg:97.91ms -step:1685/1695 train_time:164971ms step_avg:97.91ms -step:1686/1695 train_time:165069ms step_avg:97.91ms -step:1687/1695 train_time:165167ms step_avg:97.91ms -step:1688/1695 train_time:165265ms step_avg:97.91ms -step:1689/1695 train_time:165363ms step_avg:97.91ms -step:1690/1695 train_time:165461ms step_avg:97.91ms -step:1691/1695 train_time:165559ms step_avg:97.91ms -step:1692/1695 train_time:165656ms step_avg:97.91ms -step:1693/1695 train_time:165754ms step_avg:97.91ms -step:1694/1695 train_time:165851ms step_avg:97.91ms -step:1695/1695 train_time:165950ms step_avg:97.91ms -step:1695/1695 val_loss:3.2791 train_time:166045ms step_avg:97.96ms -peak memory allocated: 34073 MiB reserved: 49476 MiB diff --git a/records/082725_FA3/27d1e0d2-df15-41a9-9496-492a21943fb1.txt b/records/082725_FA3/27d1e0d2-df15-41a9-9496-492a21943fb1.txt deleted file mode 100644 index 9652d6c2d..000000000 --- a/records/082725_FA3/27d1e0d2-df15-41a9-9496-492a21943fb1.txt +++ /dev/null @@ -1,2808 +0,0 @@ -import os -import sys -with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time -import copy -import glob -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F -import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np -import triton -import triton.language as tl -from flash_attn_interface import flash_attn_func -import torch._dynamo as dynamo -dynamo.config.recompile_limit = 64 - -# 
----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - 
batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - 
configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. 
- rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) - grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size - for base_i in range(0, len(params), world_size): - if base_i + rank < len(params): - grad = params[base_i + rank].grad - # This gives strange dynamo warnings - reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) - - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -class DistAdam(torch.optim.Optimizer): - def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # DistributedAdam implementation by @vagrawal - - @torch.compile - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - grad_slices = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - for base_i in range(len(params)): - grad = params[base_i].grad - rank_size = grad.shape[0] // world_size - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for base in range(len(params)): - reduce_scatter_futures[idx].wait() - p = params[base] - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - g_slice = grad_slices[idx] - # State init - if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - 
state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by 
a given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context-based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# ----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr
https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - 
f.seek(256 * 4) - nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. 
- tokens, pos = _load_data_shard(next(file_iter)), 0 - finder = EOSBatchFinder(tokens, world_size=world_size) - continue - - bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] - buf = torch.stack(bufs, dim=0) - _inputs = buf[:, :-1] - _targets = buf[:, 1:] - else: - batch_span = num_tokens_global - start_pos_local = pos + rank * (batch_size_local * seq_len) - end_pos_local = start_pos_local + (batch_size_local * seq_len) - - buf = tokens[start_pos_local: end_pos_local + 1] - - _inputs = buf[:-1].view(batch_size_local, seq_len) - _targets = buf[1:].view(batch_size_local, seq_len) - - new_params = yield ( - _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), - _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) - ) - - pos += batch_span - - if new_params is not None: - # makes it possible for generator to receive new (batch_size, seq_len) via .send() - new_batch_size, new_seq_len = new_params - assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" - batch_size = new_batch_size - seq_len = new_seq_len - - -# ----------------------------------------------------------------------------- -# int main - -@dataclass -class Hyperparameters: - # data - train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on - val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on - val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons - train_seq_len: int = 1024 * 2 - train_batch_size: int = 24 * 8 - val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. - # optimization - num_iterations: int = 1695 # number of iterations to run - cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate - # evaluation and logging - run_id: str = str(uuid.uuid4()) - val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end - save_checkpoint: bool = False - # attention masking - bandwidth: int = 128 - ws_schedule: tuple = (3, 7, 11) - -args = Hyperparameters() - -data_path = os.environ.get("DATA_PATH", ".") -args.train_files = os.path.join(data_path, args.train_files) -args.val_files = os.path.join(data_path, args.val_files) - -# torchrun sets these env variables -rank = int(os.environ["RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) -assert 8 % world_size == 0, "world_size must be a divisor of 8" -grad_accum_steps = 8 // world_size -assert torch.cuda.is_available() -device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) -torch.cuda.set_device(device) -dist.init_process_group(backend="nccl", device_id=device) -dist.barrier() -master_process = (rank == 0) # this process will do logging, checkpointing etc. 
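-# note: RANK, WORLD_SIZE and LOCAL_RANK above are injected by the launcher; a typical invocation
-# (illustrative, assuming a single 8-GPU node) would be: torchrun --standalone --nproc_per_node=8 <this script>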
- -# begin logging -logfile = None -if master_process: - run_id = args.run_id - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{run_id}.txt" - print(logfile) -def print0(s, console=False): - if master_process: - with open(logfile, "a") as f: - if console: - print(s) - print(s, file=f) - -# begin by printing this file (the Python code) -print0(code) -print0("="*100) -# log information about the hardware/software environment this is running on -print0(f"Running Python {sys.version}") -print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") -print0(f"Running Triton version {triton.__version__}") - -def nvidia_smi(): - import subprocess # avoid top level import - return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout -print0(nvidia_smi()) -print0("="*100) - -model: nn.Module = GPT( - vocab_size=50257, - num_layers=12, - num_heads=6, - model_dim=768, - max_seq_len=max(args.train_seq_len, args.val_seq_len) -).cuda() -for m in model.modules(): - if isinstance(m, nn.Embedding): - m.bfloat16() -for param in model.parameters(): - dist.broadcast(param.detach(), 0) - -# collect the parameters to optimize -hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] -embed_params = [p for n, p in model.named_parameters() if "embed" in n] -scalar_params = [p for p in model.parameters() if p.ndim < 2] -head_params = [model.lm_head.weight] - -# init the optimizer(s) -# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence -# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 -optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) -optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) -optimizers = [optimizer1, optimizer2] -for opt in optimizers: - for group in opt.param_groups: - group["initial_lr"] = group["lr"] - -# learning rate schedule: stable then decay -def get_lr_and_ws(step: int): - x = step / (1 + args.num_iterations) # progress in training - assert 0 <= x < 1 - lr = 1.0 - if x >= 1 - args.cooldown_frac: - w = (1 - x) / args.cooldown_frac - lr = w * 1.0 + (1 - w) * 0.1 - ws_idx = int(len(args.ws_schedule) * x) - return lr, args.ws_schedule[ws_idx] - -model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) - -######################################## -# Warmup kernels # -######################################## - -# Warmup the training kernels, then re-initialize the state so we aren't cheating -warmup_steps = 60 -initial_state = dict(model=copy.deepcopy(model.state_dict()), - optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state -train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -for step in range(warmup_steps): - inputs, targets = next(train_loader) - ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each one - model(inputs, targets, ws, ws // 2).backward() - for opt in optimizers: - opt.step() - model.zero_grad(set_to_none=True) -model.load_state_dict(initial_state["model"]) -for opt, opt_state in zip(optimizers, initial_state["optimizers"]): - opt.load_state_dict(opt_state) -del train_loader, initial_state - -######################################## -# Training and validation # -######################################## - 
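-# worked example of get_lr_and_ws above (illustrative numbers, not taken from the log): with
-# num_iterations=1695, cooldown_frac=0.45 and ws_schedule=(3, 7, 11), a step at x = 0.8 of training gives
-# w = (1 - 0.8) / 0.45 ~= 0.444, so lr ~= 0.444 * 1.0 + 0.556 * 0.1 ~= 0.50 of the initial lr;
-# ws_idx = int(3 * 0.8) = 2 selects ws = 11, i.e. a long attention window of 11 * 128 = 1408 tokens
-# and a short window of (11 // 2) * 128 = 640 tokens (bandwidth = 128 from Hyperparameters)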
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -training_time_ms = 0 -# start the clock -torch.cuda.synchronize() -t0 = time.perf_counter() -# begin training -train_steps = args.num_iterations -for step in range(train_steps + 1): - last_step = (step == train_steps) - lr, ws = get_lr_and_ws(step) - - # --------------- VALIDATION SECTION ----------------- - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - # stop the clock - torch.cuda.synchronize() - training_time_ms += 1000 * (time.perf_counter() - t0) - model.eval() - assert args.val_tokens % (world_size * args.val_seq_len) == 0 - val_steps = args.val_tokens // (world_size * args.val_seq_len) - val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) - val_loss = 0 - with torch.no_grad(): - for _ in range(val_steps): - inputs, targets = next(val_loader) - val_loss += model(inputs, targets, ws, ws // 2) - val_loss /= val_steps - del val_loader - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) - print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) - model.train() - # start the clock again - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if master_process and args.save_checkpoint: - log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) - os.makedirs(f"logs/{run_id}", exist_ok=True) - torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") - # the last step only has the validation loop, so break to avoid training - break - - # --------------- TRAINING SECTION ----------------- - for _ in range(grad_accum_steps): - inputs, targets = next(train_loader) - model(inputs, targets, ws, ws // 2).backward() - # set optimization hyperparameters - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lr - for group in optimizer2.param_groups: - frac = min(step / 300, 1) # momentum warmup for muon - group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 - # step the optimizers - for opt in optimizers: - opt.step() - # null the gradients - model.zero_grad(set_to_none=True) - # logging - approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) - print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) - -print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) -dist.destroy_process_group() -==================================================================================================== -Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] -Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 -Running Triton version 3.4.0 -Wed Aug 27 03:43:24 2025 -+---------------------------------------------------------------------------------------+ -| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | -|-----------------------------------------+----------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+======================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | -| N/A 30C P0 115W / 700W | 5858MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | -| N/A 32C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | -| N/A 33C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | -| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | -| N/A 30C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | -| N/A 34C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | -| N/A 32C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | -| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ - -+---------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=======================================================================================| -+---------------------------------------------------------------------------------------+ - -==================================================================================================== -step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms -step:1/1695 train_time:511ms step_avg:510.59ms -step:2/1695 train_time:534ms step_avg:266.84ms -step:3/1695 train_time:604ms step_avg:201.32ms -step:4/1695 train_time:696ms step_avg:174.01ms -step:5/1695 train_time:789ms step_avg:157.81ms -step:6/1695 train_time:882ms step_avg:147.03ms -step:7/1695 train_time:975ms step_avg:139.35ms -step:8/1695 train_time:1069ms step_avg:133.63ms -step:9/1695 train_time:1163ms step_avg:129.21ms -step:10/1695 train_time:1256ms step_avg:125.59ms -step:11/1695 train_time:1349ms step_avg:122.66ms -step:12/1695 train_time:1447ms step_avg:120.59ms -step:13/1695 train_time:1547ms step_avg:119.02ms -step:14/1695 train_time:1644ms step_avg:117.41ms -step:15/1695 train_time:1738ms step_avg:115.86ms -step:16/1695 train_time:1831ms step_avg:114.46ms -step:17/1695 train_time:1925ms step_avg:113.25ms -step:18/1695 train_time:2019ms step_avg:112.18ms -step:19/1695 train_time:2113ms step_avg:111.19ms -step:20/1695 train_time:2207ms step_avg:110.34ms -step:21/1695 train_time:2300ms step_avg:109.54ms -step:22/1695 train_time:2395ms step_avg:108.87ms 
-step:23/1695 train_time:2491ms step_avg:108.31ms -step:24/1695 train_time:2587ms step_avg:107.80ms -step:25/1695 train_time:2683ms step_avg:107.33ms -step:26/1695 train_time:2778ms step_avg:106.85ms -step:27/1695 train_time:2872ms step_avg:106.37ms -step:28/1695 train_time:2967ms step_avg:105.96ms -step:29/1695 train_time:3061ms step_avg:105.56ms -step:30/1695 train_time:3155ms step_avg:105.16ms -step:31/1695 train_time:3249ms step_avg:104.81ms -step:32/1695 train_time:3344ms step_avg:104.50ms -step:33/1695 train_time:3438ms step_avg:104.19ms -step:34/1695 train_time:3533ms step_avg:103.92ms -step:35/1695 train_time:3629ms step_avg:103.69ms -step:36/1695 train_time:3726ms step_avg:103.50ms -step:37/1695 train_time:3822ms step_avg:103.29ms -step:38/1695 train_time:3916ms step_avg:103.06ms -step:39/1695 train_time:4010ms step_avg:102.83ms -step:40/1695 train_time:4105ms step_avg:102.62ms -step:41/1695 train_time:4199ms step_avg:102.42ms -step:42/1695 train_time:4293ms step_avg:102.21ms -step:43/1695 train_time:4388ms step_avg:102.06ms -step:44/1695 train_time:4484ms step_avg:101.90ms -step:45/1695 train_time:4579ms step_avg:101.75ms -step:46/1695 train_time:4674ms step_avg:101.61ms -step:47/1695 train_time:4770ms step_avg:101.48ms -step:48/1695 train_time:4865ms step_avg:101.36ms -step:49/1695 train_time:4960ms step_avg:101.22ms -step:50/1695 train_time:5054ms step_avg:101.07ms -step:51/1695 train_time:5149ms step_avg:100.96ms -step:52/1695 train_time:5243ms step_avg:100.83ms -step:53/1695 train_time:5338ms step_avg:100.71ms -step:54/1695 train_time:5432ms step_avg:100.59ms -step:55/1695 train_time:5528ms step_avg:100.51ms -step:56/1695 train_time:5623ms step_avg:100.41ms -step:57/1695 train_time:5718ms step_avg:100.31ms -step:58/1695 train_time:5813ms step_avg:100.22ms -step:59/1695 train_time:5908ms step_avg:100.14ms -step:60/1695 train_time:6003ms step_avg:100.05ms -step:61/1695 train_time:6097ms step_avg:99.95ms -step:62/1695 train_time:6191ms step_avg:99.86ms -step:63/1695 train_time:6286ms step_avg:99.78ms -step:64/1695 train_time:6381ms step_avg:99.70ms -step:65/1695 train_time:6475ms step_avg:99.62ms -step:66/1695 train_time:6570ms step_avg:99.55ms -step:67/1695 train_time:6665ms step_avg:99.48ms -step:68/1695 train_time:6760ms step_avg:99.40ms -step:69/1695 train_time:6854ms step_avg:99.33ms -step:70/1695 train_time:6949ms step_avg:99.26ms -step:71/1695 train_time:7043ms step_avg:99.20ms -step:72/1695 train_time:7138ms step_avg:99.14ms -step:73/1695 train_time:7232ms step_avg:99.07ms -step:74/1695 train_time:7327ms step_avg:99.01ms -step:75/1695 train_time:7422ms step_avg:98.96ms -step:76/1695 train_time:7516ms step_avg:98.89ms -step:77/1695 train_time:7610ms step_avg:98.84ms -step:78/1695 train_time:7706ms step_avg:98.79ms -step:79/1695 train_time:7801ms step_avg:98.75ms -step:80/1695 train_time:7895ms step_avg:98.69ms -step:81/1695 train_time:7990ms step_avg:98.64ms -step:82/1695 train_time:8085ms step_avg:98.59ms -step:83/1695 train_time:8179ms step_avg:98.54ms -step:84/1695 train_time:8273ms step_avg:98.49ms -step:85/1695 train_time:8369ms step_avg:98.46ms -step:86/1695 train_time:8463ms step_avg:98.41ms -step:87/1695 train_time:8557ms step_avg:98.36ms -step:88/1695 train_time:8652ms step_avg:98.31ms -step:89/1695 train_time:8747ms step_avg:98.28ms -step:90/1695 train_time:8842ms step_avg:98.24ms -step:91/1695 train_time:8936ms step_avg:98.20ms -step:92/1695 train_time:9030ms step_avg:98.16ms -step:93/1695 train_time:9126ms step_avg:98.13ms -step:94/1695 train_time:9221ms 
step_avg:98.10ms -step:95/1695 train_time:9315ms step_avg:98.05ms -step:96/1695 train_time:9409ms step_avg:98.01ms -step:97/1695 train_time:9504ms step_avg:97.98ms -step:98/1695 train_time:9597ms step_avg:97.93ms -step:99/1695 train_time:9691ms step_avg:97.89ms -step:100/1695 train_time:9786ms step_avg:97.86ms -step:101/1695 train_time:9881ms step_avg:97.84ms -step:102/1695 train_time:9975ms step_avg:97.79ms -step:103/1695 train_time:10069ms step_avg:97.76ms -step:104/1695 train_time:10164ms step_avg:97.73ms -step:105/1695 train_time:10259ms step_avg:97.70ms -step:106/1695 train_time:10352ms step_avg:97.66ms -step:107/1695 train_time:10448ms step_avg:97.64ms -step:108/1695 train_time:10543ms step_avg:97.62ms -step:109/1695 train_time:10636ms step_avg:97.58ms -step:110/1695 train_time:10730ms step_avg:97.54ms -step:111/1695 train_time:10826ms step_avg:97.53ms -step:112/1695 train_time:10922ms step_avg:97.51ms -step:113/1695 train_time:11016ms step_avg:97.48ms -step:114/1695 train_time:11110ms step_avg:97.45ms -step:115/1695 train_time:11205ms step_avg:97.44ms -step:116/1695 train_time:11300ms step_avg:97.41ms -step:117/1695 train_time:11393ms step_avg:97.37ms -step:118/1695 train_time:11487ms step_avg:97.35ms -step:119/1695 train_time:11582ms step_avg:97.33ms -step:120/1695 train_time:11676ms step_avg:97.30ms -step:121/1695 train_time:11770ms step_avg:97.27ms -step:122/1695 train_time:11864ms step_avg:97.25ms -step:123/1695 train_time:11959ms step_avg:97.22ms -step:124/1695 train_time:12052ms step_avg:97.20ms -step:125/1695 train_time:12148ms step_avg:97.18ms -step:125/1695 val_loss:4.3128 train_time:12241ms step_avg:97.92ms -step:126/1695 train_time:12267ms step_avg:97.36ms -step:127/1695 train_time:12345ms step_avg:97.20ms -step:128/1695 train_time:12446ms step_avg:97.23ms -step:129/1695 train_time:12540ms step_avg:97.21ms -step:130/1695 train_time:12635ms step_avg:97.19ms -step:131/1695 train_time:12728ms step_avg:97.16ms -step:132/1695 train_time:12821ms step_avg:97.13ms -step:133/1695 train_time:12915ms step_avg:97.11ms -step:134/1695 train_time:13008ms step_avg:97.08ms -step:135/1695 train_time:13102ms step_avg:97.05ms -step:136/1695 train_time:13196ms step_avg:97.03ms -step:137/1695 train_time:13291ms step_avg:97.01ms -step:138/1695 train_time:13388ms step_avg:97.01ms -step:139/1695 train_time:13482ms step_avg:96.99ms -step:140/1695 train_time:13577ms step_avg:96.98ms -step:141/1695 train_time:13672ms step_avg:96.96ms -step:142/1695 train_time:13766ms step_avg:96.94ms -step:143/1695 train_time:13859ms step_avg:96.92ms -step:144/1695 train_time:13954ms step_avg:96.90ms -step:145/1695 train_time:14047ms step_avg:96.88ms -step:146/1695 train_time:14140ms step_avg:96.85ms -step:147/1695 train_time:14235ms step_avg:96.84ms -step:148/1695 train_time:14331ms step_avg:96.83ms -step:149/1695 train_time:14426ms step_avg:96.82ms -step:150/1695 train_time:14520ms step_avg:96.80ms -step:151/1695 train_time:14614ms step_avg:96.78ms -step:152/1695 train_time:14709ms step_avg:96.77ms -step:153/1695 train_time:14802ms step_avg:96.75ms -step:154/1695 train_time:14896ms step_avg:96.73ms -step:155/1695 train_time:14991ms step_avg:96.72ms -step:156/1695 train_time:15084ms step_avg:96.69ms -step:157/1695 train_time:15178ms step_avg:96.68ms -step:158/1695 train_time:15272ms step_avg:96.66ms -step:159/1695 train_time:15366ms step_avg:96.64ms -step:160/1695 train_time:15461ms step_avg:96.63ms -step:161/1695 train_time:15556ms step_avg:96.62ms -step:162/1695 train_time:15651ms step_avg:96.61ms -step:163/1695 
train_time:15745ms step_avg:96.59ms -step:164/1695 train_time:15839ms step_avg:96.58ms -step:165/1695 train_time:15934ms step_avg:96.57ms -step:166/1695 train_time:16029ms step_avg:96.56ms -step:167/1695 train_time:16123ms step_avg:96.55ms -step:168/1695 train_time:16217ms step_avg:96.53ms -step:169/1695 train_time:16312ms step_avg:96.52ms -step:170/1695 train_time:16406ms step_avg:96.51ms -step:171/1695 train_time:16501ms step_avg:96.50ms -step:172/1695 train_time:16596ms step_avg:96.49ms -step:173/1695 train_time:16939ms step_avg:97.91ms -step:174/1695 train_time:17041ms step_avg:97.94ms -step:175/1695 train_time:17135ms step_avg:97.91ms -step:176/1695 train_time:17228ms step_avg:97.88ms -step:177/1695 train_time:17321ms step_avg:97.86ms -step:178/1695 train_time:17414ms step_avg:97.83ms -step:179/1695 train_time:17509ms step_avg:97.82ms -step:180/1695 train_time:17602ms step_avg:97.79ms -step:181/1695 train_time:17695ms step_avg:97.77ms -step:182/1695 train_time:17788ms step_avg:97.74ms -step:183/1695 train_time:17884ms step_avg:97.73ms -step:184/1695 train_time:17983ms step_avg:97.73ms -step:185/1695 train_time:18078ms step_avg:97.72ms -step:186/1695 train_time:18174ms step_avg:97.71ms -step:187/1695 train_time:18268ms step_avg:97.69ms -step:188/1695 train_time:18361ms step_avg:97.67ms -step:189/1695 train_time:18455ms step_avg:97.65ms -step:190/1695 train_time:18548ms step_avg:97.62ms -step:191/1695 train_time:18641ms step_avg:97.60ms -step:192/1695 train_time:18735ms step_avg:97.58ms -step:193/1695 train_time:18830ms step_avg:97.56ms -step:194/1695 train_time:18924ms step_avg:97.55ms -step:195/1695 train_time:19018ms step_avg:97.53ms -step:196/1695 train_time:19114ms step_avg:97.52ms -step:197/1695 train_time:19209ms step_avg:97.51ms -step:198/1695 train_time:19302ms step_avg:97.49ms -step:199/1695 train_time:19397ms step_avg:97.47ms -step:200/1695 train_time:19490ms step_avg:97.45ms -step:201/1695 train_time:19583ms step_avg:97.43ms -step:202/1695 train_time:19677ms step_avg:97.41ms -step:203/1695 train_time:19771ms step_avg:97.40ms -step:204/1695 train_time:19866ms step_avg:97.38ms -step:205/1695 train_time:19960ms step_avg:97.37ms -step:206/1695 train_time:20056ms step_avg:97.36ms -step:207/1695 train_time:20151ms step_avg:97.35ms -step:208/1695 train_time:20244ms step_avg:97.33ms -step:209/1695 train_time:20338ms step_avg:97.31ms -step:210/1695 train_time:20433ms step_avg:97.30ms -step:211/1695 train_time:20527ms step_avg:97.29ms -step:212/1695 train_time:20621ms step_avg:97.27ms -step:213/1695 train_time:20715ms step_avg:97.26ms -step:214/1695 train_time:20810ms step_avg:97.24ms -step:215/1695 train_time:20904ms step_avg:97.23ms -step:216/1695 train_time:20998ms step_avg:97.21ms -step:217/1695 train_time:21094ms step_avg:97.21ms -step:218/1695 train_time:21188ms step_avg:97.19ms -step:219/1695 train_time:21281ms step_avg:97.17ms -step:220/1695 train_time:21376ms step_avg:97.16ms -step:221/1695 train_time:21470ms step_avg:97.15ms -step:222/1695 train_time:21563ms step_avg:97.13ms -step:223/1695 train_time:21657ms step_avg:97.12ms -step:224/1695 train_time:21751ms step_avg:97.10ms -step:225/1695 train_time:21845ms step_avg:97.09ms -step:226/1695 train_time:21939ms step_avg:97.07ms -step:227/1695 train_time:22034ms step_avg:97.06ms -step:228/1695 train_time:22128ms step_avg:97.05ms -step:229/1695 train_time:22221ms step_avg:97.03ms -step:230/1695 train_time:22315ms step_avg:97.02ms -step:231/1695 train_time:22410ms step_avg:97.01ms -step:232/1695 train_time:22505ms step_avg:97.00ms 
-step:233/1695 train_time:22599ms step_avg:96.99ms -step:234/1695 train_time:22693ms step_avg:96.98ms -step:235/1695 train_time:22787ms step_avg:96.97ms -step:236/1695 train_time:22881ms step_avg:96.95ms -step:237/1695 train_time:22975ms step_avg:96.94ms -step:238/1695 train_time:23071ms step_avg:96.94ms -step:239/1695 train_time:23166ms step_avg:96.93ms -step:240/1695 train_time:23259ms step_avg:96.91ms -step:241/1695 train_time:23354ms step_avg:96.90ms -step:242/1695 train_time:23448ms step_avg:96.89ms -step:243/1695 train_time:23541ms step_avg:96.88ms -step:244/1695 train_time:23637ms step_avg:96.87ms -step:245/1695 train_time:23731ms step_avg:96.86ms -step:246/1695 train_time:23825ms step_avg:96.85ms -step:247/1695 train_time:23919ms step_avg:96.84ms -step:248/1695 train_time:24014ms step_avg:96.83ms -step:249/1695 train_time:24109ms step_avg:96.83ms -step:250/1695 train_time:24204ms step_avg:96.81ms -step:250/1695 val_loss:3.9758 train_time:24295ms step_avg:97.18ms -step:251/1695 train_time:24320ms step_avg:96.89ms -step:252/1695 train_time:24399ms step_avg:96.82ms -step:253/1695 train_time:24500ms step_avg:96.84ms -step:254/1695 train_time:24595ms step_avg:96.83ms -step:255/1695 train_time:24689ms step_avg:96.82ms -step:256/1695 train_time:24782ms step_avg:96.81ms -step:257/1695 train_time:24876ms step_avg:96.79ms -step:258/1695 train_time:24969ms step_avg:96.78ms -step:259/1695 train_time:25062ms step_avg:96.76ms -step:260/1695 train_time:25155ms step_avg:96.75ms -step:261/1695 train_time:25248ms step_avg:96.74ms -step:262/1695 train_time:25344ms step_avg:96.73ms -step:263/1695 train_time:25441ms step_avg:96.73ms -step:264/1695 train_time:25538ms step_avg:96.73ms -step:265/1695 train_time:25632ms step_avg:96.73ms -step:266/1695 train_time:25726ms step_avg:96.71ms -step:267/1695 train_time:25820ms step_avg:96.70ms -step:268/1695 train_time:25914ms step_avg:96.69ms -step:269/1695 train_time:26007ms step_avg:96.68ms -step:270/1695 train_time:26100ms step_avg:96.67ms -step:271/1695 train_time:26193ms step_avg:96.65ms -step:272/1695 train_time:26287ms step_avg:96.64ms -step:273/1695 train_time:26382ms step_avg:96.64ms -step:274/1695 train_time:26478ms step_avg:96.63ms -step:275/1695 train_time:26574ms step_avg:96.63ms -step:276/1695 train_time:26668ms step_avg:96.62ms -step:277/1695 train_time:26762ms step_avg:96.61ms -step:278/1695 train_time:26856ms step_avg:96.61ms -step:279/1695 train_time:26950ms step_avg:96.60ms -step:280/1695 train_time:27044ms step_avg:96.59ms -step:281/1695 train_time:27137ms step_avg:96.57ms -step:282/1695 train_time:27231ms step_avg:96.56ms -step:283/1695 train_time:27324ms step_avg:96.55ms -step:284/1695 train_time:27419ms step_avg:96.54ms -step:285/1695 train_time:27513ms step_avg:96.54ms -step:286/1695 train_time:27608ms step_avg:96.53ms -step:287/1695 train_time:27702ms step_avg:96.52ms -step:288/1695 train_time:27796ms step_avg:96.51ms -step:289/1695 train_time:27891ms step_avg:96.51ms -step:290/1695 train_time:27984ms step_avg:96.50ms -step:291/1695 train_time:28078ms step_avg:96.49ms -step:292/1695 train_time:28171ms step_avg:96.48ms -step:293/1695 train_time:28265ms step_avg:96.47ms -step:294/1695 train_time:28359ms step_avg:96.46ms -step:295/1695 train_time:28453ms step_avg:96.45ms -step:296/1695 train_time:28547ms step_avg:96.44ms -step:297/1695 train_time:28641ms step_avg:96.43ms -step:298/1695 train_time:28735ms step_avg:96.43ms -step:299/1695 train_time:28831ms step_avg:96.42ms -step:300/1695 train_time:28925ms step_avg:96.42ms -step:301/1695 
train_time:29019ms step_avg:96.41ms -step:302/1695 train_time:29114ms step_avg:96.40ms -step:303/1695 train_time:29208ms step_avg:96.39ms -step:304/1695 train_time:29301ms step_avg:96.39ms -step:305/1695 train_time:29396ms step_avg:96.38ms -step:306/1695 train_time:29490ms step_avg:96.37ms -step:307/1695 train_time:29584ms step_avg:96.37ms -step:308/1695 train_time:29678ms step_avg:96.36ms -step:309/1695 train_time:29772ms step_avg:96.35ms -step:310/1695 train_time:29866ms step_avg:96.34ms -step:311/1695 train_time:29960ms step_avg:96.33ms -step:312/1695 train_time:30055ms step_avg:96.33ms -step:313/1695 train_time:30149ms step_avg:96.32ms -step:314/1695 train_time:30242ms step_avg:96.31ms -step:315/1695 train_time:30337ms step_avg:96.31ms -step:316/1695 train_time:30431ms step_avg:96.30ms -step:317/1695 train_time:30525ms step_avg:96.29ms -step:318/1695 train_time:30619ms step_avg:96.29ms -step:319/1695 train_time:30713ms step_avg:96.28ms -step:320/1695 train_time:30807ms step_avg:96.27ms -step:321/1695 train_time:30900ms step_avg:96.26ms -step:322/1695 train_time:30995ms step_avg:96.26ms -step:323/1695 train_time:31089ms step_avg:96.25ms -step:324/1695 train_time:31182ms step_avg:96.24ms -step:325/1695 train_time:31277ms step_avg:96.24ms -step:326/1695 train_time:31371ms step_avg:96.23ms -step:327/1695 train_time:31465ms step_avg:96.22ms -step:328/1695 train_time:31559ms step_avg:96.22ms -step:329/1695 train_time:31654ms step_avg:96.21ms -step:330/1695 train_time:31750ms step_avg:96.21ms -step:331/1695 train_time:31843ms step_avg:96.20ms -step:332/1695 train_time:31938ms step_avg:96.20ms -step:333/1695 train_time:32032ms step_avg:96.19ms -step:334/1695 train_time:32126ms step_avg:96.19ms -step:335/1695 train_time:32220ms step_avg:96.18ms -step:336/1695 train_time:32315ms step_avg:96.18ms -step:337/1695 train_time:32410ms step_avg:96.17ms -step:338/1695 train_time:32503ms step_avg:96.16ms -step:339/1695 train_time:32597ms step_avg:96.16ms -step:340/1695 train_time:32692ms step_avg:96.15ms -step:341/1695 train_time:32785ms step_avg:96.14ms -step:342/1695 train_time:32879ms step_avg:96.14ms -step:343/1695 train_time:32974ms step_avg:96.13ms -step:344/1695 train_time:33069ms step_avg:96.13ms -step:345/1695 train_time:33399ms step_avg:96.81ms -step:346/1695 train_time:33523ms step_avg:96.89ms -step:347/1695 train_time:33615ms step_avg:96.87ms -step:348/1695 train_time:33709ms step_avg:96.87ms -step:349/1695 train_time:33802ms step_avg:96.85ms -step:350/1695 train_time:33895ms step_avg:96.84ms -step:351/1695 train_time:33988ms step_avg:96.83ms -step:352/1695 train_time:34081ms step_avg:96.82ms -step:353/1695 train_time:34174ms step_avg:96.81ms -step:354/1695 train_time:34267ms step_avg:96.80ms -step:355/1695 train_time:34364ms step_avg:96.80ms -step:356/1695 train_time:34462ms step_avg:96.80ms -step:357/1695 train_time:34558ms step_avg:96.80ms -step:358/1695 train_time:34653ms step_avg:96.80ms -step:359/1695 train_time:34747ms step_avg:96.79ms -step:360/1695 train_time:34840ms step_avg:96.78ms -step:361/1695 train_time:34934ms step_avg:96.77ms -step:362/1695 train_time:35026ms step_avg:96.76ms -step:363/1695 train_time:35119ms step_avg:96.75ms -step:364/1695 train_time:35213ms step_avg:96.74ms -step:365/1695 train_time:35306ms step_avg:96.73ms -step:366/1695 train_time:35402ms step_avg:96.73ms -step:367/1695 train_time:35497ms step_avg:96.72ms -step:368/1695 train_time:35591ms step_avg:96.72ms -step:369/1695 train_time:35686ms step_avg:96.71ms -step:370/1695 train_time:35780ms step_avg:96.70ms 
-step:371/1695 train_time:35874ms step_avg:96.69ms -step:372/1695 train_time:35967ms step_avg:96.68ms -step:373/1695 train_time:36060ms step_avg:96.68ms -step:374/1695 train_time:36154ms step_avg:96.67ms -step:375/1695 train_time:36248ms step_avg:96.66ms -step:375/1695 val_loss:3.8203 train_time:36339ms step_avg:96.90ms -step:376/1695 train_time:36364ms step_avg:96.71ms -step:377/1695 train_time:36442ms step_avg:96.66ms -step:378/1695 train_time:36539ms step_avg:96.66ms -step:379/1695 train_time:36633ms step_avg:96.66ms -step:380/1695 train_time:36726ms step_avg:96.65ms -step:381/1695 train_time:36820ms step_avg:96.64ms -step:382/1695 train_time:36912ms step_avg:96.63ms -step:383/1695 train_time:37005ms step_avg:96.62ms -step:384/1695 train_time:37098ms step_avg:96.61ms -step:385/1695 train_time:37190ms step_avg:96.60ms -step:386/1695 train_time:37284ms step_avg:96.59ms -step:387/1695 train_time:37379ms step_avg:96.59ms -step:388/1695 train_time:37475ms step_avg:96.59ms -step:389/1695 train_time:37570ms step_avg:96.58ms -step:390/1695 train_time:37665ms step_avg:96.58ms -step:391/1695 train_time:37759ms step_avg:96.57ms -step:392/1695 train_time:37852ms step_avg:96.56ms -step:393/1695 train_time:37946ms step_avg:96.55ms -step:394/1695 train_time:38039ms step_avg:96.55ms -step:395/1695 train_time:38131ms step_avg:96.53ms -step:396/1695 train_time:38225ms step_avg:96.53ms -step:397/1695 train_time:38319ms step_avg:96.52ms -step:398/1695 train_time:38413ms step_avg:96.52ms -step:399/1695 train_time:38508ms step_avg:96.51ms -step:400/1695 train_time:38604ms step_avg:96.51ms -step:401/1695 train_time:38699ms step_avg:96.51ms -step:402/1695 train_time:38792ms step_avg:96.50ms -step:403/1695 train_time:38886ms step_avg:96.49ms -step:404/1695 train_time:38980ms step_avg:96.49ms -step:405/1695 train_time:39073ms step_avg:96.48ms -step:406/1695 train_time:39166ms step_avg:96.47ms -step:407/1695 train_time:39260ms step_avg:96.46ms -step:408/1695 train_time:39354ms step_avg:96.46ms -step:409/1695 train_time:39448ms step_avg:96.45ms -step:410/1695 train_time:39543ms step_avg:96.45ms -step:411/1695 train_time:39638ms step_avg:96.44ms -step:412/1695 train_time:39732ms step_avg:96.44ms -step:413/1695 train_time:39825ms step_avg:96.43ms -step:414/1695 train_time:39919ms step_avg:96.42ms -step:415/1695 train_time:40012ms step_avg:96.41ms -step:416/1695 train_time:40105ms step_avg:96.41ms -step:417/1695 train_time:40199ms step_avg:96.40ms -step:418/1695 train_time:40292ms step_avg:96.39ms -step:419/1695 train_time:40386ms step_avg:96.39ms -step:420/1695 train_time:40481ms step_avg:96.38ms -step:421/1695 train_time:40575ms step_avg:96.38ms -step:422/1695 train_time:40669ms step_avg:96.37ms -step:423/1695 train_time:40764ms step_avg:96.37ms -step:424/1695 train_time:40858ms step_avg:96.36ms -step:425/1695 train_time:40952ms step_avg:96.36ms -step:426/1695 train_time:41046ms step_avg:96.35ms -step:427/1695 train_time:41140ms step_avg:96.35ms -step:428/1695 train_time:41233ms step_avg:96.34ms -step:429/1695 train_time:41327ms step_avg:96.33ms -step:430/1695 train_time:41420ms step_avg:96.33ms -step:431/1695 train_time:41514ms step_avg:96.32ms -step:432/1695 train_time:41608ms step_avg:96.32ms -step:433/1695 train_time:41702ms step_avg:96.31ms -step:434/1695 train_time:41797ms step_avg:96.31ms -step:435/1695 train_time:41890ms step_avg:96.30ms -step:436/1695 train_time:41985ms step_avg:96.30ms -step:437/1695 train_time:42079ms step_avg:96.29ms -step:438/1695 train_time:42173ms step_avg:96.29ms -step:439/1695 
train_time:42267ms step_avg:96.28ms -step:440/1695 train_time:42361ms step_avg:96.28ms -step:441/1695 train_time:42455ms step_avg:96.27ms -step:442/1695 train_time:42549ms step_avg:96.26ms -step:443/1695 train_time:42643ms step_avg:96.26ms -step:444/1695 train_time:42737ms step_avg:96.25ms -step:445/1695 train_time:42830ms step_avg:96.25ms -step:446/1695 train_time:42924ms step_avg:96.24ms -step:447/1695 train_time:43019ms step_avg:96.24ms -step:448/1695 train_time:43114ms step_avg:96.24ms -step:449/1695 train_time:43208ms step_avg:96.23ms -step:450/1695 train_time:43302ms step_avg:96.23ms -step:451/1695 train_time:43397ms step_avg:96.22ms -step:452/1695 train_time:43490ms step_avg:96.22ms -step:453/1695 train_time:43584ms step_avg:96.21ms -step:454/1695 train_time:43679ms step_avg:96.21ms -step:455/1695 train_time:43773ms step_avg:96.20ms -step:456/1695 train_time:43866ms step_avg:96.20ms -step:457/1695 train_time:43960ms step_avg:96.19ms -step:458/1695 train_time:44055ms step_avg:96.19ms -step:459/1695 train_time:44149ms step_avg:96.18ms -step:460/1695 train_time:44243ms step_avg:96.18ms -step:461/1695 train_time:44338ms step_avg:96.18ms -step:462/1695 train_time:44432ms step_avg:96.17ms -step:463/1695 train_time:44526ms step_avg:96.17ms -step:464/1695 train_time:44621ms step_avg:96.16ms -step:465/1695 train_time:44714ms step_avg:96.16ms -step:466/1695 train_time:44808ms step_avg:96.15ms -step:467/1695 train_time:44902ms step_avg:96.15ms -step:468/1695 train_time:44996ms step_avg:96.14ms -step:469/1695 train_time:45089ms step_avg:96.14ms -step:470/1695 train_time:45184ms step_avg:96.14ms -step:471/1695 train_time:45278ms step_avg:96.13ms -step:472/1695 train_time:45372ms step_avg:96.13ms -step:473/1695 train_time:45466ms step_avg:96.12ms -step:474/1695 train_time:45560ms step_avg:96.12ms -step:475/1695 train_time:45653ms step_avg:96.11ms -step:476/1695 train_time:45748ms step_avg:96.11ms -step:477/1695 train_time:45842ms step_avg:96.10ms -step:478/1695 train_time:45935ms step_avg:96.10ms -step:479/1695 train_time:46029ms step_avg:96.09ms -step:480/1695 train_time:46123ms step_avg:96.09ms -step:481/1695 train_time:46217ms step_avg:96.08ms -step:482/1695 train_time:46311ms step_avg:96.08ms -step:483/1695 train_time:46405ms step_avg:96.08ms -step:484/1695 train_time:46499ms step_avg:96.07ms -step:485/1695 train_time:46593ms step_avg:96.07ms -step:486/1695 train_time:46688ms step_avg:96.07ms -step:487/1695 train_time:46782ms step_avg:96.06ms -step:488/1695 train_time:46876ms step_avg:96.06ms -step:489/1695 train_time:46969ms step_avg:96.05ms -step:490/1695 train_time:47064ms step_avg:96.05ms -step:491/1695 train_time:47158ms step_avg:96.05ms -step:492/1695 train_time:47252ms step_avg:96.04ms -step:493/1695 train_time:47346ms step_avg:96.04ms -step:494/1695 train_time:47442ms step_avg:96.04ms -step:495/1695 train_time:47536ms step_avg:96.03ms -step:496/1695 train_time:47629ms step_avg:96.03ms -step:497/1695 train_time:47724ms step_avg:96.02ms -step:498/1695 train_time:47820ms step_avg:96.02ms -step:499/1695 train_time:47914ms step_avg:96.02ms -step:500/1695 train_time:48008ms step_avg:96.02ms -step:500/1695 val_loss:3.7161 train_time:48100ms step_avg:96.20ms -step:501/1695 train_time:48124ms step_avg:96.06ms -step:502/1695 train_time:48204ms step_avg:96.02ms -step:503/1695 train_time:48302ms step_avg:96.03ms -step:504/1695 train_time:48397ms step_avg:96.03ms -step:505/1695 train_time:48491ms step_avg:96.02ms -step:506/1695 train_time:48584ms step_avg:96.02ms -step:507/1695 train_time:48678ms 
step_avg:96.01ms -step:508/1695 train_time:48771ms step_avg:96.01ms -step:509/1695 train_time:48864ms step_avg:96.00ms -step:510/1695 train_time:48957ms step_avg:95.99ms -step:511/1695 train_time:49050ms step_avg:95.99ms -step:512/1695 train_time:49146ms step_avg:95.99ms -step:513/1695 train_time:49242ms step_avg:95.99ms -step:514/1695 train_time:49337ms step_avg:95.99ms -step:515/1695 train_time:49432ms step_avg:95.98ms -step:516/1695 train_time:49525ms step_avg:95.98ms -step:517/1695 train_time:49619ms step_avg:95.97ms -step:518/1695 train_time:49713ms step_avg:95.97ms -step:519/1695 train_time:50082ms step_avg:96.50ms -step:520/1695 train_time:50228ms step_avg:96.59ms -step:521/1695 train_time:50320ms step_avg:96.58ms -step:522/1695 train_time:50412ms step_avg:96.58ms -step:523/1695 train_time:50505ms step_avg:96.57ms -step:524/1695 train_time:50598ms step_avg:96.56ms -step:525/1695 train_time:50691ms step_avg:96.55ms -step:526/1695 train_time:50784ms step_avg:96.55ms -step:527/1695 train_time:50878ms step_avg:96.54ms -step:528/1695 train_time:50971ms step_avg:96.54ms -step:529/1695 train_time:51069ms step_avg:96.54ms -step:530/1695 train_time:51167ms step_avg:96.54ms -step:531/1695 train_time:51264ms step_avg:96.54ms -step:532/1695 train_time:51358ms step_avg:96.54ms -step:533/1695 train_time:51452ms step_avg:96.53ms -step:534/1695 train_time:51545ms step_avg:96.53ms -step:535/1695 train_time:51638ms step_avg:96.52ms -step:536/1695 train_time:51732ms step_avg:96.51ms -step:537/1695 train_time:51824ms step_avg:96.51ms -step:538/1695 train_time:51918ms step_avg:96.50ms -step:539/1695 train_time:52014ms step_avg:96.50ms -step:540/1695 train_time:52110ms step_avg:96.50ms -step:541/1695 train_time:52204ms step_avg:96.50ms -step:542/1695 train_time:52299ms step_avg:96.49ms -step:543/1695 train_time:52393ms step_avg:96.49ms -step:544/1695 train_time:52486ms step_avg:96.48ms -step:545/1695 train_time:52580ms step_avg:96.48ms -step:546/1695 train_time:52674ms step_avg:96.47ms -step:547/1695 train_time:52767ms step_avg:96.47ms -step:548/1695 train_time:52860ms step_avg:96.46ms -step:549/1695 train_time:52954ms step_avg:96.46ms -step:550/1695 train_time:53049ms step_avg:96.45ms -step:551/1695 train_time:53143ms step_avg:96.45ms -step:552/1695 train_time:53238ms step_avg:96.45ms -step:553/1695 train_time:53331ms step_avg:96.44ms -step:554/1695 train_time:53425ms step_avg:96.43ms -step:555/1695 train_time:53519ms step_avg:96.43ms -step:556/1695 train_time:53614ms step_avg:96.43ms -step:557/1695 train_time:53708ms step_avg:96.42ms -step:558/1695 train_time:53801ms step_avg:96.42ms -step:559/1695 train_time:53895ms step_avg:96.41ms -step:560/1695 train_time:53989ms step_avg:96.41ms -step:561/1695 train_time:54083ms step_avg:96.40ms -step:562/1695 train_time:54178ms step_avg:96.40ms -step:563/1695 train_time:54273ms step_avg:96.40ms -step:564/1695 train_time:54367ms step_avg:96.40ms -step:565/1695 train_time:54461ms step_avg:96.39ms -step:566/1695 train_time:54555ms step_avg:96.39ms -step:567/1695 train_time:54650ms step_avg:96.38ms -step:568/1695 train_time:54746ms step_avg:96.38ms -step:569/1695 train_time:54841ms step_avg:96.38ms -step:570/1695 train_time:54938ms step_avg:96.38ms -step:571/1695 train_time:55035ms step_avg:96.38ms -step:572/1695 train_time:55131ms step_avg:96.38ms -step:573/1695 train_time:55228ms step_avg:96.38ms -step:574/1695 train_time:55323ms step_avg:96.38ms -step:575/1695 train_time:55420ms step_avg:96.38ms -step:576/1695 train_time:55517ms step_avg:96.38ms -step:577/1695 
train_time:55614ms step_avg:96.38ms -step:578/1695 train_time:55711ms step_avg:96.39ms -step:579/1695 train_time:55806ms step_avg:96.38ms -step:580/1695 train_time:55903ms step_avg:96.38ms -step:581/1695 train_time:55999ms step_avg:96.38ms -step:582/1695 train_time:56097ms step_avg:96.39ms -step:583/1695 train_time:56194ms step_avg:96.39ms -step:584/1695 train_time:56290ms step_avg:96.39ms -step:585/1695 train_time:56386ms step_avg:96.39ms -step:586/1695 train_time:56482ms step_avg:96.39ms -step:587/1695 train_time:56578ms step_avg:96.39ms -step:588/1695 train_time:56675ms step_avg:96.39ms -step:589/1695 train_time:56772ms step_avg:96.39ms -step:590/1695 train_time:56868ms step_avg:96.39ms -step:591/1695 train_time:56964ms step_avg:96.39ms -step:592/1695 train_time:57061ms step_avg:96.39ms -step:593/1695 train_time:57157ms step_avg:96.39ms -step:594/1695 train_time:57254ms step_avg:96.39ms -step:595/1695 train_time:57350ms step_avg:96.39ms -step:596/1695 train_time:57446ms step_avg:96.39ms -step:597/1695 train_time:57542ms step_avg:96.39ms -step:598/1695 train_time:57638ms step_avg:96.38ms -step:599/1695 train_time:57735ms step_avg:96.39ms -step:600/1695 train_time:57831ms step_avg:96.39ms -step:601/1695 train_time:57926ms step_avg:96.38ms -step:602/1695 train_time:58022ms step_avg:96.38ms -step:603/1695 train_time:58118ms step_avg:96.38ms -step:604/1695 train_time:58215ms step_avg:96.38ms -step:605/1695 train_time:58312ms step_avg:96.38ms -step:606/1695 train_time:58408ms step_avg:96.38ms -step:607/1695 train_time:58504ms step_avg:96.38ms -step:608/1695 train_time:58600ms step_avg:96.38ms -step:609/1695 train_time:58696ms step_avg:96.38ms -step:610/1695 train_time:58794ms step_avg:96.38ms -step:611/1695 train_time:58891ms step_avg:96.38ms -step:612/1695 train_time:58987ms step_avg:96.38ms -step:613/1695 train_time:59083ms step_avg:96.38ms -step:614/1695 train_time:59180ms step_avg:96.38ms -step:615/1695 train_time:59276ms step_avg:96.38ms -step:616/1695 train_time:59373ms step_avg:96.39ms -step:617/1695 train_time:59470ms step_avg:96.39ms -step:618/1695 train_time:59565ms step_avg:96.38ms -step:619/1695 train_time:59661ms step_avg:96.38ms -step:620/1695 train_time:59758ms step_avg:96.38ms -step:621/1695 train_time:59855ms step_avg:96.39ms -step:622/1695 train_time:59952ms step_avg:96.39ms -step:623/1695 train_time:60048ms step_avg:96.39ms -step:624/1695 train_time:60143ms step_avg:96.38ms -step:625/1695 train_time:60239ms step_avg:96.38ms -step:625/1695 val_loss:3.6203 train_time:60334ms step_avg:96.53ms -step:626/1695 train_time:60358ms step_avg:96.42ms -step:627/1695 train_time:60442ms step_avg:96.40ms -step:628/1695 train_time:60540ms step_avg:96.40ms -step:629/1695 train_time:60637ms step_avg:96.40ms -step:630/1695 train_time:60732ms step_avg:96.40ms -step:631/1695 train_time:60827ms step_avg:96.40ms -step:632/1695 train_time:60921ms step_avg:96.39ms -step:633/1695 train_time:61017ms step_avg:96.39ms -step:634/1695 train_time:61112ms step_avg:96.39ms -step:635/1695 train_time:61208ms step_avg:96.39ms -step:636/1695 train_time:61305ms step_avg:96.39ms -step:637/1695 train_time:61405ms step_avg:96.40ms -step:638/1695 train_time:61502ms step_avg:96.40ms -step:639/1695 train_time:61599ms step_avg:96.40ms -step:640/1695 train_time:61696ms step_avg:96.40ms -step:641/1695 train_time:61793ms step_avg:96.40ms -step:642/1695 train_time:61887ms step_avg:96.40ms -step:643/1695 train_time:61983ms step_avg:96.40ms -step:644/1695 train_time:62078ms step_avg:96.39ms -step:645/1695 train_time:62175ms 
step_avg:96.40ms -step:646/1695 train_time:62272ms step_avg:96.40ms -step:647/1695 train_time:62368ms step_avg:96.40ms -step:648/1695 train_time:62465ms step_avg:96.40ms -step:649/1695 train_time:62562ms step_avg:96.40ms -step:650/1695 train_time:62660ms step_avg:96.40ms -step:651/1695 train_time:62756ms step_avg:96.40ms -step:652/1695 train_time:62852ms step_avg:96.40ms -step:653/1695 train_time:62947ms step_avg:96.40ms -step:654/1695 train_time:63041ms step_avg:96.39ms -step:655/1695 train_time:63139ms step_avg:96.39ms -step:656/1695 train_time:63237ms step_avg:96.40ms -step:657/1695 train_time:63335ms step_avg:96.40ms -step:658/1695 train_time:63431ms step_avg:96.40ms -step:659/1695 train_time:63527ms step_avg:96.40ms -step:660/1695 train_time:63623ms step_avg:96.40ms -step:661/1695 train_time:63719ms step_avg:96.40ms -step:662/1695 train_time:63815ms step_avg:96.40ms -step:663/1695 train_time:63911ms step_avg:96.40ms -step:664/1695 train_time:64007ms step_avg:96.40ms -step:665/1695 train_time:64103ms step_avg:96.40ms -step:666/1695 train_time:64199ms step_avg:96.39ms -step:667/1695 train_time:64295ms step_avg:96.39ms -step:668/1695 train_time:64391ms step_avg:96.39ms -step:669/1695 train_time:64487ms step_avg:96.39ms -step:670/1695 train_time:64583ms step_avg:96.39ms -step:671/1695 train_time:64679ms step_avg:96.39ms -step:672/1695 train_time:64775ms step_avg:96.39ms -step:673/1695 train_time:64871ms step_avg:96.39ms -step:674/1695 train_time:64967ms step_avg:96.39ms -step:675/1695 train_time:65062ms step_avg:96.39ms -step:676/1695 train_time:65158ms step_avg:96.39ms -step:677/1695 train_time:65255ms step_avg:96.39ms -step:678/1695 train_time:65351ms step_avg:96.39ms -step:679/1695 train_time:65447ms step_avg:96.39ms -step:680/1695 train_time:65542ms step_avg:96.39ms -step:681/1695 train_time:65639ms step_avg:96.39ms -step:682/1695 train_time:65735ms step_avg:96.39ms -step:683/1695 train_time:65832ms step_avg:96.39ms -step:684/1695 train_time:65929ms step_avg:96.39ms -step:685/1695 train_time:66025ms step_avg:96.39ms -step:686/1695 train_time:66120ms step_avg:96.38ms -step:687/1695 train_time:66216ms step_avg:96.38ms -step:688/1695 train_time:66313ms step_avg:96.39ms -step:689/1695 train_time:66410ms step_avg:96.39ms -step:690/1695 train_time:66505ms step_avg:96.38ms -step:691/1695 train_time:66945ms step_avg:96.88ms -step:692/1695 train_time:67026ms step_avg:96.86ms -step:693/1695 train_time:67121ms step_avg:96.86ms -step:694/1695 train_time:67216ms step_avg:96.85ms -step:695/1695 train_time:67311ms step_avg:96.85ms -step:696/1695 train_time:67406ms step_avg:96.85ms -step:697/1695 train_time:67501ms step_avg:96.84ms -step:698/1695 train_time:67597ms step_avg:96.84ms -step:699/1695 train_time:67692ms step_avg:96.84ms -step:700/1695 train_time:67787ms step_avg:96.84ms -step:701/1695 train_time:67886ms step_avg:96.84ms -step:702/1695 train_time:67985ms step_avg:96.84ms -step:703/1695 train_time:68082ms step_avg:96.84ms -step:704/1695 train_time:68178ms step_avg:96.84ms -step:705/1695 train_time:68274ms step_avg:96.84ms -step:706/1695 train_time:68370ms step_avg:96.84ms -step:707/1695 train_time:68465ms step_avg:96.84ms -step:708/1695 train_time:68561ms step_avg:96.84ms -step:709/1695 train_time:68656ms step_avg:96.84ms -step:710/1695 train_time:68752ms step_avg:96.83ms -step:711/1695 train_time:68850ms step_avg:96.84ms -step:712/1695 train_time:68946ms step_avg:96.83ms -step:713/1695 train_time:69043ms step_avg:96.83ms -step:714/1695 train_time:69139ms step_avg:96.83ms -step:715/1695 
train_time:69235ms step_avg:96.83ms -step:716/1695 train_time:69331ms step_avg:96.83ms -step:717/1695 train_time:69426ms step_avg:96.83ms -step:718/1695 train_time:69522ms step_avg:96.83ms -step:719/1695 train_time:69618ms step_avg:96.83ms -step:720/1695 train_time:69714ms step_avg:96.82ms -step:721/1695 train_time:69811ms step_avg:96.83ms -step:722/1695 train_time:69908ms step_avg:96.83ms -step:723/1695 train_time:70004ms step_avg:96.82ms -step:724/1695 train_time:70101ms step_avg:96.82ms -step:725/1695 train_time:70198ms step_avg:96.83ms -step:726/1695 train_time:70296ms step_avg:96.83ms -step:727/1695 train_time:70392ms step_avg:96.83ms -step:728/1695 train_time:70487ms step_avg:96.82ms -step:729/1695 train_time:70583ms step_avg:96.82ms -step:730/1695 train_time:70679ms step_avg:96.82ms -step:731/1695 train_time:70776ms step_avg:96.82ms -step:732/1695 train_time:70874ms step_avg:96.82ms -step:733/1695 train_time:70970ms step_avg:96.82ms -step:734/1695 train_time:71066ms step_avg:96.82ms -step:735/1695 train_time:71162ms step_avg:96.82ms -step:736/1695 train_time:71258ms step_avg:96.82ms -step:737/1695 train_time:71355ms step_avg:96.82ms -step:738/1695 train_time:71451ms step_avg:96.82ms -step:739/1695 train_time:71546ms step_avg:96.81ms -step:740/1695 train_time:71641ms step_avg:96.81ms -step:741/1695 train_time:71739ms step_avg:96.81ms -step:742/1695 train_time:71837ms step_avg:96.81ms -step:743/1695 train_time:71934ms step_avg:96.82ms -step:744/1695 train_time:72030ms step_avg:96.81ms -step:745/1695 train_time:72125ms step_avg:96.81ms -step:746/1695 train_time:72221ms step_avg:96.81ms -step:747/1695 train_time:72318ms step_avg:96.81ms -step:748/1695 train_time:72414ms step_avg:96.81ms -step:749/1695 train_time:72510ms step_avg:96.81ms -step:750/1695 train_time:72605ms step_avg:96.81ms -step:750/1695 val_loss:3.5663 train_time:72700ms step_avg:96.93ms -step:751/1695 train_time:72724ms step_avg:96.84ms -step:752/1695 train_time:72807ms step_avg:96.82ms -step:753/1695 train_time:72904ms step_avg:96.82ms -step:754/1695 train_time:73002ms step_avg:96.82ms -step:755/1695 train_time:73098ms step_avg:96.82ms -step:756/1695 train_time:73192ms step_avg:96.82ms -step:757/1695 train_time:73287ms step_avg:96.81ms -step:758/1695 train_time:73381ms step_avg:96.81ms -step:759/1695 train_time:73476ms step_avg:96.81ms -step:760/1695 train_time:73571ms step_avg:96.80ms -step:761/1695 train_time:73668ms step_avg:96.80ms -step:762/1695 train_time:73766ms step_avg:96.81ms -step:763/1695 train_time:73864ms step_avg:96.81ms -step:764/1695 train_time:73962ms step_avg:96.81ms -step:765/1695 train_time:74059ms step_avg:96.81ms -step:766/1695 train_time:74154ms step_avg:96.81ms -step:767/1695 train_time:74249ms step_avg:96.80ms -step:768/1695 train_time:74344ms step_avg:96.80ms -step:769/1695 train_time:74439ms step_avg:96.80ms -step:770/1695 train_time:74535ms step_avg:96.80ms -step:771/1695 train_time:74630ms step_avg:96.80ms -step:772/1695 train_time:74726ms step_avg:96.80ms -step:773/1695 train_time:74824ms step_avg:96.80ms -step:774/1695 train_time:74921ms step_avg:96.80ms -step:775/1695 train_time:75018ms step_avg:96.80ms -step:776/1695 train_time:75114ms step_avg:96.80ms -step:777/1695 train_time:75209ms step_avg:96.79ms -step:778/1695 train_time:75304ms step_avg:96.79ms -step:779/1695 train_time:75400ms step_avg:96.79ms -step:780/1695 train_time:75496ms step_avg:96.79ms -step:781/1695 train_time:75592ms step_avg:96.79ms -step:782/1695 train_time:75687ms step_avg:96.79ms -step:783/1695 train_time:75784ms 
step_avg:96.79ms -step:784/1695 train_time:75880ms step_avg:96.79ms -step:785/1695 train_time:75977ms step_avg:96.79ms -step:786/1695 train_time:76073ms step_avg:96.79ms -step:787/1695 train_time:76168ms step_avg:96.78ms -step:788/1695 train_time:76264ms step_avg:96.78ms -step:789/1695 train_time:76359ms step_avg:96.78ms -step:790/1695 train_time:76455ms step_avg:96.78ms -step:791/1695 train_time:76549ms step_avg:96.78ms -step:792/1695 train_time:76645ms step_avg:96.77ms -step:793/1695 train_time:76742ms step_avg:96.77ms -step:794/1695 train_time:76839ms step_avg:96.77ms -step:795/1695 train_time:76935ms step_avg:96.77ms -step:796/1695 train_time:77031ms step_avg:96.77ms -step:797/1695 train_time:77126ms step_avg:96.77ms -step:798/1695 train_time:77224ms step_avg:96.77ms -step:799/1695 train_time:77320ms step_avg:96.77ms -step:800/1695 train_time:77416ms step_avg:96.77ms -step:801/1695 train_time:77511ms step_avg:96.77ms -step:802/1695 train_time:77607ms step_avg:96.77ms -step:803/1695 train_time:77702ms step_avg:96.76ms -step:804/1695 train_time:77797ms step_avg:96.76ms -step:805/1695 train_time:77893ms step_avg:96.76ms -step:806/1695 train_time:77989ms step_avg:96.76ms -step:807/1695 train_time:78085ms step_avg:96.76ms -step:808/1695 train_time:78181ms step_avg:96.76ms -step:809/1695 train_time:78276ms step_avg:96.76ms -step:810/1695 train_time:78371ms step_avg:96.75ms -step:811/1695 train_time:78467ms step_avg:96.75ms -step:812/1695 train_time:78562ms step_avg:96.75ms -step:813/1695 train_time:78657ms step_avg:96.75ms -step:814/1695 train_time:78752ms step_avg:96.75ms -step:815/1695 train_time:78848ms step_avg:96.75ms -step:816/1695 train_time:78945ms step_avg:96.75ms -step:817/1695 train_time:79042ms step_avg:96.75ms -step:818/1695 train_time:79138ms step_avg:96.75ms -step:819/1695 train_time:79234ms step_avg:96.74ms -step:820/1695 train_time:79329ms step_avg:96.74ms -step:821/1695 train_time:79425ms step_avg:96.74ms -step:822/1695 train_time:79520ms step_avg:96.74ms -step:823/1695 train_time:79616ms step_avg:96.74ms -step:824/1695 train_time:79711ms step_avg:96.74ms -step:825/1695 train_time:79807ms step_avg:96.74ms -step:826/1695 train_time:79903ms step_avg:96.73ms -step:827/1695 train_time:79998ms step_avg:96.73ms -step:828/1695 train_time:80094ms step_avg:96.73ms -step:829/1695 train_time:80190ms step_avg:96.73ms -step:830/1695 train_time:80286ms step_avg:96.73ms -step:831/1695 train_time:80382ms step_avg:96.73ms -step:832/1695 train_time:80477ms step_avg:96.73ms -step:833/1695 train_time:80572ms step_avg:96.73ms -step:834/1695 train_time:80668ms step_avg:96.72ms -step:835/1695 train_time:80763ms step_avg:96.72ms -step:836/1695 train_time:80859ms step_avg:96.72ms -step:837/1695 train_time:80955ms step_avg:96.72ms -step:838/1695 train_time:81050ms step_avg:96.72ms -step:839/1695 train_time:81146ms step_avg:96.72ms -step:840/1695 train_time:81243ms step_avg:96.72ms -step:841/1695 train_time:81339ms step_avg:96.72ms -step:842/1695 train_time:81435ms step_avg:96.72ms -step:843/1695 train_time:81530ms step_avg:96.71ms -step:844/1695 train_time:81625ms step_avg:96.71ms -step:845/1695 train_time:81722ms step_avg:96.71ms -step:846/1695 train_time:81818ms step_avg:96.71ms -step:847/1695 train_time:81915ms step_avg:96.71ms -step:848/1695 train_time:82011ms step_avg:96.71ms -step:849/1695 train_time:82106ms step_avg:96.71ms -step:850/1695 train_time:82202ms step_avg:96.71ms -step:851/1695 train_time:82298ms step_avg:96.71ms -step:852/1695 train_time:82393ms step_avg:96.70ms -step:853/1695 
train_time:82488ms step_avg:96.70ms
-[per-step train_time lines for steps 854-1694 elided; step_avg rises smoothly from 96.70ms to 98.02ms, with brief spikes after the periodic checkpoints at steps 863, 1036, 1208, 1381, and 1552; validation checkpoints retained below]
-step:875/1695 val_loss:3.5235 train_time:85011ms step_avg:97.16ms
-step:1000/1695 val_loss:3.4841 train_time:97040ms step_avg:97.04ms
-step:1125/1695 val_loss:3.4352 train_time:109375ms step_avg:97.22ms
-step:1250/1695 val_loss:3.3872 train_time:121956ms step_avg:97.56ms
-step:1375/1695 val_loss:3.3494 train_time:134197ms step_avg:97.60ms
-step:1500/1695 val_loss:3.3158 train_time:146775ms step_avg:97.85ms
-step:1625/1695 val_loss:3.2885 train_time:159356ms step_avg:98.07ms
-step:1695/1695 val_loss:3.2769 train_time:166247ms step_avg:98.08ms
-peak memory allocated: 34505 MiB reserved: 49576 MiB
diff --git a/records/082725_FA3/7a492532-c19b-40dd-958d-fec55aa4d3fd.txt b/records/082725_FA3/7a492532-c19b-40dd-958d-fec55aa4d3fd.txt
deleted file mode 100644
index 7e21a501e..000000000
--- a/records/082725_FA3/7a492532-c19b-40dd-958d-fec55aa4d3fd.txt
+++ /dev/null
@@ -1,2808 +0,0 @@
-import os
-import sys
-with open(sys.argv[0]) as f:
-    code = f.read() # read the code of this file ASAP, for logging
-import uuid
-import time
-import copy
-import glob
-from dataclasses import dataclass
-from functools import lru_cache
-from pathlib import Path
-
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-import torch
-torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
-from torch import Tensor, nn
-import torch.nn.functional as F
-import torch.distributed as dist
-#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
-import numpy as np
-import triton
-import triton.language as tl
-from flash_attn_interface import flash_attn_func
-import torch._dynamo as dynamo
-dynamo.config.recompile_limit = 64
-
-# -----------------------------------------------------------------------------
-# Custom operators: FP8 matmul by @YouJiacheng
-
-@torch.library.custom_op("nanogpt::mm", mutates_args=())
-def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
-    @torch.compile
-    def impl(x: Tensor, w: Tensor):
-        assert x.is_contiguous() and w.is_contiguous()
-        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
-        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
-        out = torch._scaled_mm(
-            x_f8,
-            w_f8.T,
-            out_dtype=torch.bfloat16,
-            scale_a=x.new_tensor(x_s, dtype=torch.float32),
-            scale_b=x.new_tensor(w_s, dtype=torch.float32),
-            use_fast_accum=True,
-        )
-        return out, x_f8, w_f8
-
-    return impl(x, w)
-
-@mm_op.register_fake
-def _(x: Tensor, w: Tensor, *_):
-    assert x.ndim == w.ndim == 2
-    assert x.shape[1] == w.shape[1]
-    assert x.device == w.device
-    assert x.is_contiguous() and w.is_contiguous()
-    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
-
-@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
-def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
-    @torch.compile
-    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
-        assert grad.is_contiguous()
-        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
-        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
-        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
-        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
-        grad_x = torch._scaled_mm(
-            grad_f8,
-            w_f8.T.contiguous().T,
-            out_dtype=torch.bfloat16,
-            scale_a=grad_inv_s,
-            scale_b=w_inv_s,
-            use_fast_accum=False,
-        )
-        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
-        grad_w = torch._scaled_mm(
-            x_f8.T.contiguous(),
-            grad_f8.T.contiguous().T,
-            out_dtype=torch.float32,
-            scale_a=x_inv_s,
-            scale_b=grad_inv_s,
-            use_fast_accum=False,
-        ).T
-        return grad_x, grad_w
-
-    return impl(g, x_f8, w_f8)
-
-@mm_backward_op.register_fake
-def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
-    return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32)
-
-def backward(ctx, grad_out: Tensor, *_):
-    x_f8, w_f8 = ctx.saved_tensors
-    x_s, w_s, grad_s = ctx.scales
-    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
-        grad_out, x_f8, w_f8, x_s, w_s, grad_s
-    )
-    return grad_x, grad_w, None, None, None
-
-def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
-    *_, x_s, w_s, grad_s = inputs
-    _, x_f8, w_f8 = output
-    ctx.save_for_backward(x_f8, w_f8)
-    ctx.scales = x_s, w_s, grad_s
-    ctx.set_materialize_grads(False)
-
-mm_op.register_autograd(backward, setup_context=setup_context)
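The two custom ops above wrap torch._scaled_mm, which only runs on FP8-capable GPUs. As a minimal sketch of what the forward op computes (not part of the record; assumes only PyTorch >= 2.1 for the float8 dtypes, and emulates the matmul in bfloat16 on CPU), the scale/cast/rescale round trip looks like this; mm_fp8_emulated is a hypothetical helper:

import torch

def mm_fp8_emulated(x: torch.Tensor, w: torch.Tensor, x_s: float = 1.0, w_s: float = 1.0):
    # quantize: divide by the scale so values fit the narrow e4m3 range, then cast
    x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
    w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
    # torch._scaled_mm multiplies scale_a * scale_b back in; emulate that here
    return (x_f8.to(torch.bfloat16) * x_s) @ (w_f8.to(torch.bfloat16) * w_s).T

x, w = torch.randn(8, 16), torch.randn(32, 16)
print((mm_fp8_emulated(x, w) - x @ w.T).abs().max())  # small: pure quantization error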
-
-# -----------------------------------------------------------------------------
-# Triton kernel for symmetric matrix multiplication by @byronxu99
-
-def _get_autotune_configs():
-    return [
-        triton.Config(
-            {
-                "BLOCK_SIZE_M": bm,
-                "BLOCK_SIZE_N": bn,
-                "BLOCK_SIZE_K": bk,
-                "GROUP_SIZE_M": 8,
-                "LOWER_UPPER": 1,
-            },
-            num_stages=stages,
-            num_warps=warps,
-        )
-        for bm in [64, 128]
-        for bn in [64, 128, 256]
-        for bk in [64, 128]
-        for stages, warps in [(3, 4), (3, 8), (4, 4)]
-        if bm // bn <= 2 and bn // bm <= 2
-    ]
-
-@triton.jit
-def _pid_to_block(
-    pid,
-    M,
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    GROUP_SIZE_M: tl.constexpr,
-):
-    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
-    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)
-
-    # Map PID to a single matrix in batch
-    batch_idx = pid // (num_pid_m * num_pid_n)
-    pid = pid % (num_pid_m * num_pid_n)
-
-    # Map PID to 2D grid of blocks
-    pid_m = pid // num_pid_n
-    pid_n = pid % num_pid_n
-    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
-
-    m_idx = pid_m * BLOCK_SIZE_M
-    n_idx = pid_n * BLOCK_SIZE_N
-    return batch_idx, m_idx, n_idx
-
-@triton.autotune(
-    configs=_get_autotune_configs(),
-    key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
-)
-@triton.jit
-def ns_line_1_kernel(
-    A_ptr, C_ptr,
-    M, K,
-    a_stride_b, a_stride_r, a_stride_c,
-    c_stride_b, c_stride_r, c_stride_c,
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
-    GROUP_SIZE_M: tl.constexpr,
-    LOWER_UPPER: tl.constexpr,
-):
-    pid = tl.program_id(axis=0)
-    batch_idx, m_idx, n_idx = _pid_to_block(
-        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
-    )
-
-    # Skip blocks that don't need to be computed
-    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
-    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
-    if skip_block_below_diag or skip_block_above_diag:
-        return
-
-    # Index into one matrix of batch
-    A_ptr += batch_idx * a_stride_b
-    C_ptr += batch_idx * c_stride_b
-
-    # Create pointer arrays for A and A.T
-    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
-    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
-    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
-
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
-    # Accumulate over blocks of K
-    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
-        at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
-        accumulator = tl.dot(a, at, accumulator)
-        a_ptrs += BLOCK_SIZE_K * a_stride_c
-        at_ptrs += BLOCK_SIZE_K * a_stride_c
-
-    out_dtype = C_ptr.dtype.element_ty
-    output = accumulator.to(out_dtype)
-
-    # Store block of C
-    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
-    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
-    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
-    tl.store(c_ptrs, output, mask=c_mask)
-
-    # Store block of C mirrored across the diagonal
-    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
-    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
-    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
-
-def ns_line_1(A: torch.Tensor, out: torch.Tensor):
-    """
-    Launch Triton kernel to compute C = A @ A.T
-    """
-    assert A.ndim == 2 or A.ndim == 3
-    M, K = A.shape[-2:]
-    assert out.size(-2) == M, "Output matrix has incorrect shape"
-    assert out.size(-1) == M, "Output matrix has incorrect shape"
-
-    batch_size = A.size(0) if A.ndim == 3 else 1
-    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
-    output_batch_stride = out.stride(0) if out.ndim == 3 else 0
-
-    grid = lambda meta: (
-        batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
-    )
-    ns_line_1_kernel[grid](
-        A_ptr=A,
-        C_ptr=out,
-        M=M,
-        K=K,
-        a_stride_b=input_batch_stride,
-        a_stride_r=A.stride(-2),
-        a_stride_c=A.stride(-1),
-        c_stride_b=output_batch_stride,
-        c_stride_r=out.stride(-2),
-        c_stride_c=out.stride(-1),
-    )
-    return out
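A quick way to sanity-check this launcher, assuming a CUDA device with Triton available and the definitions above in scope (this check is not part of the record):

X = torch.randn(12, 1024, 768, device="cuda", dtype=torch.bfloat16)
C = torch.empty(12, 1024, 1024, device="cuda", dtype=torch.bfloat16)
ns_line_1(X, out=C)  # computes one triangle per block and mirrors it across the diagonal
print((C - X @ X.mT).abs().max())  # ~0 up to bfloat16 rounding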
-
-@triton.autotune(
-    configs=_get_autotune_configs(),
-    key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
-)
-@triton.jit
-def ns_line_2_kernel(
-    A_ptr, C_ptr,
-    M,
-    a_stride_b, a_stride_r, a_stride_c,
-    c_stride_b, c_stride_r, c_stride_c,
-    alpha, beta,
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
-    GROUP_SIZE_M: tl.constexpr,
-    LOWER_UPPER: tl.constexpr,
-):
-    # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A
-    # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels
-    pid = tl.program_id(axis=0)
-    batch_idx, m_idx, n_idx = _pid_to_block(
-        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
-    )
-
-    # Skip blocks that don't need to be computed
-    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
-    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
-    if skip_block_below_diag or skip_block_above_diag:
-        return
-
-    # Index into one matrix of batch
-    A_ptr += batch_idx * a_stride_b
-    C_ptr += batch_idx * c_stride_b
-
-    # Create pointer arrays for A and A.T
-    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
-    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
-    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
-
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
-    # Accumulate over blocks of K
-    for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)):
-        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0)
-        at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0)
-        accumulator = tl.dot(a, at, accumulator)
-        a_ptrs += BLOCK_SIZE_K * a_stride_c
-        at_ptrs += BLOCK_SIZE_K * a_stride_c
-
-    # Load block of A to add (corresponds to the current block of C)
-    offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M)
-    offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N)
-    a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c)
-    a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M)
-    a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32)
-
-    # Apply alpha and beta
-    accumulator *= alpha
-    accumulator += a_add * beta
-
-    out_dtype = C_ptr.dtype.element_ty
-    output = accumulator.to(out_dtype)
-
-    # Store block of C
-    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
-    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
-    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
-    tl.store(c_ptrs, output, mask=c_mask)
-
-    # Store block of C mirrored across the diagonal
-    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
-    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
-    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
-
-def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor):
-    """
-    Launch Triton kernel to compute C = alpha * A @ A.T + beta * A
-    """
-    assert A.ndim == 2 or A.ndim == 3
-    M, K = A.shape[-2:]
-    assert M == K, "Input matrix must be square"
-    assert out.size(-2) == M
-    assert out.size(-1) == M
-
-    batch_size = A.size(0) if A.ndim == 3 else 1
-    input_batch_stride = A.stride(0) if A.ndim == 3 else 0
-    output_batch_stride = out.stride(0) if out.ndim == 3 else 0
-
-    grid = lambda meta: (
-        batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
-    )
-    ns_line_2_kernel[grid](
-        A_ptr=A,
-        C_ptr=out,
-        M=M,
-        a_stride_b=input_batch_stride,
-        a_stride_r=A.stride(-2),
-        a_stride_c=A.stride(-1),
-        c_stride_b=output_batch_stride,
-        c_stride_r=out.stride(-2),
-        c_stride_c=out.stride(-1),
-        alpha=alpha,
-        beta=beta,
-    )
-    return out
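Note that ns_line_2 mirrors each computed block across the diagonal, so it is only correct when A itself is symmetric; inside the Newton-Schulz loop A is always X @ X.mT, which guarantees that. A sketch of a check under the same CUDA/Triton assumptions as above (not part of the record):

A = torch.randn(4, 512, 512, device="cuda", dtype=torch.bfloat16)
A = A + A.mT  # symmetrize: the mirrored store assumes a symmetric input
B = torch.empty_like(A)
ns_line_2(A, alpha=2.0315, beta=-4.7750, out=B)
print((B - (2.0315 * (A @ A.mT) + -4.7750 * A)).abs().max())  # ~0 up to bf16 rounding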
meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. 
-
-# -----------------------------------------------------------------------------
-# Muon optimizer
-
-class Muon(torch.optim.Optimizer):
-    """
-    Muon - MomentUm Orthogonalized by Newton-schulz
-
-    https://kellerjordan.github.io/posts/muon/
-
-    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
-    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
-    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
-    the advantage that it can be stably run in bfloat16 on the GPU.
-
-    Warning: This optimizer should not be used for the embedding layer, the final fully connected layer,
-    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
-    """
-    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95):
-        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
-        params = list(params)
-        sizes = {p.shape for p in params}
-        # create one buffer per unique parameter-size
-        param_groups = []
-        for size in sizes:
-            group_params = [p for p in params if p.shape == size]
-            param_groups.append(dict(params=group_params))
-        super().__init__(param_groups, defaults)
-
-    @torch.no_grad()
-    def step(self):
-        # Efficient systems-wise implementation of step developed by @YouJiacheng,
-        # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad,
-        # @ryanyang0, and @vagrawal.
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
-        reduce_scatter_futures: list[torch.Future] = []
-        all_gather_futures: list[torch.Future] = []
-        for group in self.param_groups:
-            params: list[Tensor] = group["params"]
-            grad = torch.empty_like(params[-1])
-            grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size
-            for base_i in range(0, len(params), world_size):
-                if base_i + rank < len(params):
-                    grad = params[base_i + rank].grad
-                # This gives strange dynamo warnings
-                reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future())
-
-        idx = 0
-        for group in self.param_groups:
-            params: list[Tensor] = group["params"]
-            params_pad = params + [torch.empty_like(params[-1])] * world_size
-            momentum = group["momentum"]
-            for base_i in range(0, len(params), world_size):
-                reduce_scatter_futures[idx].wait()
-                if base_i + rank < len(params):
-                    p = params[base_i + rank]
-                    grad = p.grad
-                    eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0)
-                    eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0)
-                    state = self.state[p]
-                    if len(state) == 0:
-                        state["momentum_buffer"] = torch.zeros_like(grad)
-                    momentum_buffer = state["momentum_buffer"]
-                    p.mul_(1 - eff_weight_decay)
-                    momentum_buffer.lerp_(grad, 1 - momentum)
-                    grad = grad.lerp_(momentum_buffer, momentum)
-                    v = newton_schulz_triton(grad)
-                    p.add_(other=v, alpha=-eff_lr)
-                idx += 1
-                all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future())
-        torch.futures.collect_all(all_gather_futures).wait()
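The sharding scheme in Muon.step can be read off in isolation: within each same-shape group, parameters are walked in strides of world_size, and rank r owns slot base_i + r of each stride, so every update is orthogonalized on exactly one GPU and then all-gathered back to all ranks. A toy trace with hypothetical names (24 same-shape matrices across 8 ranks, matching the MLP shape comment further down):

world_size, params = 8, [f"p{i}" for i in range(24)]
for rank in range(world_size):
    owned = [params[base_i + rank]
             for base_i in range(0, len(params), world_size)
             if base_i + rank < len(params)]
    print(rank, owned)  # rank 0 -> ['p0', 'p8', 'p16'], rank 7 -> ['p7', 'p15', 'p23']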
state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# ----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - 
f.seek(256 * 4) - nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. 
- tokens, pos = _load_data_shard(next(file_iter)), 0 - finder = EOSBatchFinder(tokens, world_size=world_size) - continue - - bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] - buf = torch.stack(bufs, dim=0) - _inputs = buf[:, :-1] - _targets = buf[:, 1:] - else: - batch_span = num_tokens_global - start_pos_local = pos + rank * (batch_size_local * seq_len) - end_pos_local = start_pos_local + (batch_size_local * seq_len) - - buf = tokens[start_pos_local: end_pos_local + 1] - - _inputs = buf[:-1].view(batch_size_local, seq_len) - _targets = buf[1:].view(batch_size_local, seq_len) - - new_params = yield ( - _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), - _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) - ) - - pos += batch_span - - if new_params is not None: - # makes it possible for the generator to receive new (batch_size, seq_len) via .send() - new_batch_size, new_seq_len = new_params - assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" - batch_size = new_batch_size - seq_len = new_seq_len - - -# ----------------------------------------------------------------------------- -# int main - -@dataclass -class Hyperparameters: - # data - train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on - val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on - val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons - train_seq_len: int = 1024 * 2 - train_batch_size: int = 24 * 8 - val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. - # optimization - num_iterations: int = 1695 # number of iterations to run - cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate - # evaluation and logging - run_id: str = str(uuid.uuid4()) - val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end - save_checkpoint: bool = False - # attention masking - bandwidth: int = 128 - ws_schedule: tuple = (3, 7, 11) - -args = Hyperparameters() - -data_path = os.environ.get("DATA_PATH", ".") -args.train_files = os.path.join(data_path, args.train_files) -args.val_files = os.path.join(data_path, args.val_files) - -# torchrun sets these env variables -rank = int(os.environ["RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) -assert 8 % world_size == 0, "world_size must be a divisor of 8" -grad_accum_steps = 8 // world_size -assert torch.cuda.is_available() -device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) -torch.cuda.set_device(device) -dist.init_process_group(backend="nccl", device_id=device) -dist.barrier() -master_process = (rank == 0) # this process will do logging, checkpointing etc. 
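-# A minimal usage sketch of distributed_data_generator above (illustrative shapes only,
-# not part of the run):
-#   loader = distributed_data_generator(args.train_files, batch_size=192, seq_len=2048)
-#   inputs, targets = next(loader)             # advance one batch; sequences start at document
-#                                              # boundaries because align_to_bos defaults to True
-#   inputs, targets = loader.send((96, 4096))  # resize mid-stream; the returned batch and all
-#                                              # later ones use the new (batch_size, seq_len)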
- -# begin logging -logfile = None -if master_process: - run_id = args.run_id - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{run_id}.txt" - print(logfile) -def print0(s, console=False): - if master_process: - with open(logfile, "a") as f: - if console: - print(s) - print(s, file=f) - -# begin by printing this file (the Python code) -print0(code) -print0("="*100) -# log information about the hardware/software environment this is running on -print0(f"Running Python {sys.version}") -print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") -print0(f"Running Triton version {triton.__version__}") - -def nvidia_smi(): - import subprocess # avoid top level import - return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout -print0(nvidia_smi()) -print0("="*100) - -model: nn.Module = GPT( - vocab_size=50257, - num_layers=12, - num_heads=6, - model_dim=768, - max_seq_len=max(args.train_seq_len, args.val_seq_len) -).cuda() -for m in model.modules(): - if isinstance(m, nn.Embedding): - m.bfloat16() -for param in model.parameters(): - dist.broadcast(param.detach(), 0) - -# collect the parameters to optimize -hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] -embed_params = [p for n, p in model.named_parameters() if "embed" in n] -scalar_params = [p for p in model.parameters() if p.ndim < 2] -head_params = [model.lm_head.weight] - -# init the optimizer(s) -# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence -# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 -optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) -optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) -optimizers = [optimizer1, optimizer2] -for opt in optimizers: - for group in opt.param_groups: - group["initial_lr"] = group["lr"] - -# learning rate schedule: stable then decay -def get_lr_and_ws(step: int): - x = step / (1 + args.num_iterations) # progress in training - assert 0 <= x < 1 - lr = 1.0 - if x >= 1 - args.cooldown_frac: - w = (1 - x) / args.cooldown_frac - lr = w * 1.0 + (1 - w) * 0.1 - ws_idx = int(len(args.ws_schedule) * x) - return lr, args.ws_schedule[ws_idx] - -model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) - -######################################## -# Warmup kernels # -######################################## - -# Warmup the training kernels, then re-initialize the state so we aren't cheating -warmup_steps = 60 -initial_state = dict(model=copy.deepcopy(model.state_dict()), - optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state -train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -for step in range(warmup_steps): - inputs, targets = next(train_loader) - ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each - model(inputs, targets, ws, ws // 2).backward() - for opt in optimizers: - opt.step() - model.zero_grad(set_to_none=True) -model.load_state_dict(initial_state["model"]) -for opt, opt_state in zip(optimizers, initial_state["optimizers"]): - opt.load_state_dict(opt_state) -del train_loader, initial_state - -######################################## -# Training and validation # -######################################## - 
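-# Worked example of get_lr_and_ws above (num_iterations=1695, cooldown_frac=0.45,
-# ws_schedule=(3, 7, 11)): the lr multiplier stays at 1.0 until x = step/1696 reaches
-# 0.55 (around step 933), then decays linearly toward 0.1; the window-size index
-# int(3 * x) selects ws=3 for the first third of training, 7 for the middle third,
-# and 11 for the final third.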
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -training_time_ms = 0 -# start the clock -torch.cuda.synchronize() -t0 = time.perf_counter() -# begin training -train_steps = args.num_iterations -for step in range(train_steps + 1): - last_step = (step == train_steps) - lr, ws = get_lr_and_ws(step) - - # --------------- VALIDATION SECTION ----------------- - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - # stop the clock - torch.cuda.synchronize() - training_time_ms += 1000 * (time.perf_counter() - t0) - model.eval() - assert args.val_tokens % (world_size * args.val_seq_len) == 0 - val_steps = args.val_tokens // (world_size * args.val_seq_len) - val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) - val_loss = 0 - with torch.no_grad(): - for _ in range(val_steps): - inputs, targets = next(val_loader) - val_loss += model(inputs, targets, ws, ws // 2) - val_loss /= val_steps - del val_loader - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) - print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) - model.train() - # start the clock again - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if master_process and args.save_checkpoint: - log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) - os.makedirs(f"logs/{run_id}", exist_ok=True) - torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") - # the last step only has the validation loop, so break to avoid training - break - - # --------------- TRAINING SECTION ----------------- - for _ in range(grad_accum_steps): - inputs, targets = next(train_loader) - model(inputs, targets, ws, ws // 2).backward() - # set optimization hyperparameters - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lr - for group in optimizer2.param_groups: - frac = min(step / 300, 1) # momentum warmup for muon - group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 - # step the optimizers - for opt in optimizers: - opt.step() - # null the gradients - model.zero_grad(set_to_none=True) - # logging - approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) - print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) - -print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) -dist.destroy_process_group() -==================================================================================================== -Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] -Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 -Running Triton version 3.4.0 -Wed Aug 27 03:53:12 2025 -+---------------------------------------------------------------------------------------+ -| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | -|-----------------------------------------+----------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+======================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | -| N/A 30C P0 114W / 700W | 5858MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | -| N/A 32C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | -| N/A 33C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | -| N/A 30C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | -| N/A 30C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | -| N/A 34C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | -| N/A 32C P0 111W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | -| N/A 31C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ - -+---------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=======================================================================================| -+---------------------------------------------------------------------------------------+ - -==================================================================================================== -step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms -step:1/1695 train_time:508ms step_avg:507.99ms -step:2/1695 train_time:531ms step_avg:265.69ms -step:3/1695 train_time:603ms step_avg:200.91ms -step:4/1695 train_time:695ms step_avg:173.68ms -step:5/1695 train_time:787ms step_avg:157.47ms -step:6/1695 train_time:881ms step_avg:146.76ms -step:7/1695 train_time:974ms step_avg:139.13ms -step:8/1695 train_time:1067ms step_avg:133.43ms -step:9/1695 train_time:1161ms step_avg:128.95ms -step:10/1695 train_time:1253ms step_avg:125.34ms -step:11/1695 train_time:1347ms step_avg:122.48ms -step:12/1695 train_time:1445ms step_avg:120.43ms -step:13/1695 train_time:1543ms step_avg:118.70ms -step:14/1695 train_time:1638ms step_avg:117.04ms -step:15/1695 train_time:1732ms step_avg:115.49ms -step:16/1695 train_time:1827ms step_avg:114.19ms -step:17/1695 train_time:1921ms step_avg:113.01ms -step:18/1695 train_time:2015ms step_avg:111.93ms -step:19/1695 train_time:2108ms step_avg:110.95ms -step:20/1695 train_time:2202ms step_avg:110.08ms -step:21/1695 train_time:2295ms step_avg:109.29ms -step:22/1695 train_time:2390ms step_avg:108.63ms 
-step:23/1695 train_time:2486ms step_avg:108.07ms -step:24/1695 train_time:2583ms step_avg:107.61ms -step:25/1695 train_time:2678ms step_avg:107.10ms -step:26/1695 train_time:2771ms step_avg:106.59ms -step:27/1695 train_time:2867ms step_avg:106.18ms -step:28/1695 train_time:2962ms step_avg:105.78ms -step:29/1695 train_time:3055ms step_avg:105.34ms -step:30/1695 train_time:3149ms step_avg:104.96ms -step:31/1695 train_time:3243ms step_avg:104.60ms -step:32/1695 train_time:3336ms step_avg:104.26ms -step:33/1695 train_time:3430ms step_avg:103.94ms -step:34/1695 train_time:3526ms step_avg:103.70ms -step:35/1695 train_time:3622ms step_avg:103.48ms -step:36/1695 train_time:3717ms step_avg:103.24ms -step:37/1695 train_time:3811ms step_avg:102.99ms -step:38/1695 train_time:3906ms step_avg:102.79ms -step:39/1695 train_time:4001ms step_avg:102.58ms -step:40/1695 train_time:4094ms step_avg:102.36ms -step:41/1695 train_time:4188ms step_avg:102.14ms -step:42/1695 train_time:4282ms step_avg:101.96ms -step:43/1695 train_time:4376ms step_avg:101.76ms -step:44/1695 train_time:4470ms step_avg:101.58ms -step:45/1695 train_time:4565ms step_avg:101.45ms -step:46/1695 train_time:4661ms step_avg:101.32ms -step:47/1695 train_time:4754ms step_avg:101.15ms -step:48/1695 train_time:4849ms step_avg:101.01ms -step:49/1695 train_time:4944ms step_avg:100.89ms -step:50/1695 train_time:5039ms step_avg:100.77ms -step:51/1695 train_time:5132ms step_avg:100.63ms -step:52/1695 train_time:5227ms step_avg:100.51ms -step:53/1695 train_time:5322ms step_avg:100.41ms -step:54/1695 train_time:5416ms step_avg:100.29ms -step:55/1695 train_time:5510ms step_avg:100.18ms -step:56/1695 train_time:5606ms step_avg:100.10ms -step:57/1695 train_time:5701ms step_avg:100.02ms -step:58/1695 train_time:5795ms step_avg:99.91ms -step:59/1695 train_time:5889ms step_avg:99.81ms -step:60/1695 train_time:5984ms step_avg:99.73ms -step:61/1695 train_time:6077ms step_avg:99.63ms -step:62/1695 train_time:6171ms step_avg:99.54ms -step:63/1695 train_time:6267ms step_avg:99.47ms -step:64/1695 train_time:6362ms step_avg:99.40ms -step:65/1695 train_time:6457ms step_avg:99.33ms -step:66/1695 train_time:6551ms step_avg:99.26ms -step:67/1695 train_time:6647ms step_avg:99.21ms -step:68/1695 train_time:6743ms step_avg:99.15ms -step:69/1695 train_time:6837ms step_avg:99.08ms -step:70/1695 train_time:6930ms step_avg:99.00ms -step:71/1695 train_time:7025ms step_avg:98.95ms -step:72/1695 train_time:7119ms step_avg:98.88ms -step:73/1695 train_time:7214ms step_avg:98.82ms -step:74/1695 train_time:7308ms step_avg:98.75ms -step:75/1695 train_time:7403ms step_avg:98.70ms -step:76/1695 train_time:7497ms step_avg:98.64ms -step:77/1695 train_time:7591ms step_avg:98.59ms -step:78/1695 train_time:7685ms step_avg:98.53ms -step:79/1695 train_time:7781ms step_avg:98.50ms -step:80/1695 train_time:7876ms step_avg:98.45ms -step:81/1695 train_time:7969ms step_avg:98.39ms -step:82/1695 train_time:8065ms step_avg:98.35ms -step:83/1695 train_time:8160ms step_avg:98.31ms -step:84/1695 train_time:8253ms step_avg:98.25ms -step:85/1695 train_time:8347ms step_avg:98.21ms -step:86/1695 train_time:8442ms step_avg:98.16ms -step:87/1695 train_time:8536ms step_avg:98.11ms -step:88/1695 train_time:8630ms step_avg:98.07ms -step:89/1695 train_time:8725ms step_avg:98.03ms -step:90/1695 train_time:8819ms step_avg:97.99ms -step:91/1695 train_time:8913ms step_avg:97.95ms -step:92/1695 train_time:9008ms step_avg:97.91ms -step:93/1695 train_time:9102ms step_avg:97.87ms -step:94/1695 train_time:9196ms 
step_avg:97.83ms -step:95/1695 train_time:9290ms step_avg:97.79ms -step:96/1695 train_time:9385ms step_avg:97.76ms -step:97/1695 train_time:9480ms step_avg:97.73ms -step:98/1695 train_time:9574ms step_avg:97.69ms -step:99/1695 train_time:9669ms step_avg:97.66ms -step:100/1695 train_time:9764ms step_avg:97.64ms -step:101/1695 train_time:9858ms step_avg:97.61ms -step:102/1695 train_time:9952ms step_avg:97.57ms -step:103/1695 train_time:10047ms step_avg:97.54ms -step:104/1695 train_time:10141ms step_avg:97.51ms -step:105/1695 train_time:10235ms step_avg:97.47ms -step:106/1695 train_time:10329ms step_avg:97.44ms -step:107/1695 train_time:10424ms step_avg:97.42ms -step:108/1695 train_time:10519ms step_avg:97.40ms -step:109/1695 train_time:10613ms step_avg:97.37ms -step:110/1695 train_time:10708ms step_avg:97.34ms -step:111/1695 train_time:10802ms step_avg:97.32ms -step:112/1695 train_time:10896ms step_avg:97.29ms -step:113/1695 train_time:10990ms step_avg:97.25ms -step:114/1695 train_time:11084ms step_avg:97.22ms -step:115/1695 train_time:11178ms step_avg:97.20ms -step:116/1695 train_time:11271ms step_avg:97.17ms -step:117/1695 train_time:11366ms step_avg:97.14ms -step:118/1695 train_time:11460ms step_avg:97.12ms -step:119/1695 train_time:11555ms step_avg:97.10ms -step:120/1695 train_time:11649ms step_avg:97.07ms -step:121/1695 train_time:11744ms step_avg:97.05ms -step:122/1695 train_time:11839ms step_avg:97.04ms -step:123/1695 train_time:11933ms step_avg:97.01ms -step:124/1695 train_time:12028ms step_avg:97.00ms -step:125/1695 train_time:12122ms step_avg:96.98ms -step:125/1695 val_loss:4.3129 train_time:12214ms step_avg:97.71ms -step:126/1695 train_time:12238ms step_avg:97.13ms -step:127/1695 train_time:12320ms step_avg:97.01ms -step:128/1695 train_time:12421ms step_avg:97.04ms -step:129/1695 train_time:12516ms step_avg:97.02ms -step:130/1695 train_time:12609ms step_avg:97.00ms -step:131/1695 train_time:12702ms step_avg:96.96ms -step:132/1695 train_time:12795ms step_avg:96.94ms -step:133/1695 train_time:12889ms step_avg:96.91ms -step:134/1695 train_time:12982ms step_avg:96.88ms -step:135/1695 train_time:13075ms step_avg:96.85ms -step:136/1695 train_time:13168ms step_avg:96.83ms -step:137/1695 train_time:13264ms step_avg:96.82ms -step:138/1695 train_time:13359ms step_avg:96.81ms -step:139/1695 train_time:13455ms step_avg:96.80ms -step:140/1695 train_time:13550ms step_avg:96.79ms -step:141/1695 train_time:13645ms step_avg:96.77ms -step:142/1695 train_time:13739ms step_avg:96.75ms -step:143/1695 train_time:13832ms step_avg:96.73ms -step:144/1695 train_time:13926ms step_avg:96.71ms -step:145/1695 train_time:14019ms step_avg:96.68ms -step:146/1695 train_time:14112ms step_avg:96.66ms -step:147/1695 train_time:14206ms step_avg:96.64ms -step:148/1695 train_time:14302ms step_avg:96.63ms -step:149/1695 train_time:14396ms step_avg:96.62ms -step:150/1695 train_time:14492ms step_avg:96.61ms -step:151/1695 train_time:14586ms step_avg:96.60ms -step:152/1695 train_time:14681ms step_avg:96.59ms -step:153/1695 train_time:14775ms step_avg:96.57ms -step:154/1695 train_time:14869ms step_avg:96.55ms -step:155/1695 train_time:14963ms step_avg:96.53ms -step:156/1695 train_time:15056ms step_avg:96.51ms -step:157/1695 train_time:15149ms step_avg:96.49ms -step:158/1695 train_time:15245ms step_avg:96.49ms -step:159/1695 train_time:15340ms step_avg:96.48ms -step:160/1695 train_time:15434ms step_avg:96.46ms -step:161/1695 train_time:15529ms step_avg:96.45ms -step:162/1695 train_time:15624ms step_avg:96.45ms -step:163/1695 
train_time:15719ms step_avg:96.44ms -step:164/1695 train_time:15813ms step_avg:96.42ms -step:165/1695 train_time:15907ms step_avg:96.40ms -step:166/1695 train_time:16001ms step_avg:96.39ms -step:167/1695 train_time:16095ms step_avg:96.38ms -step:168/1695 train_time:16189ms step_avg:96.36ms -step:169/1695 train_time:16283ms step_avg:96.35ms -step:170/1695 train_time:16378ms step_avg:96.34ms -step:171/1695 train_time:16471ms step_avg:96.32ms -step:172/1695 train_time:16567ms step_avg:96.32ms -step:173/1695 train_time:16951ms step_avg:97.98ms -step:174/1695 train_time:17020ms step_avg:97.82ms -step:175/1695 train_time:17112ms step_avg:97.78ms -step:176/1695 train_time:17206ms step_avg:97.76ms -step:177/1695 train_time:17299ms step_avg:97.73ms -step:178/1695 train_time:17391ms step_avg:97.70ms -step:179/1695 train_time:17484ms step_avg:97.68ms -step:180/1695 train_time:17578ms step_avg:97.65ms -step:181/1695 train_time:17671ms step_avg:97.63ms -step:182/1695 train_time:17764ms step_avg:97.61ms -step:183/1695 train_time:17859ms step_avg:97.59ms -step:184/1695 train_time:17956ms step_avg:97.59ms -step:185/1695 train_time:18052ms step_avg:97.58ms -step:186/1695 train_time:18148ms step_avg:97.57ms -step:187/1695 train_time:18243ms step_avg:97.56ms -step:188/1695 train_time:18336ms step_avg:97.53ms -step:189/1695 train_time:18429ms step_avg:97.51ms -step:190/1695 train_time:18522ms step_avg:97.49ms -step:191/1695 train_time:18615ms step_avg:97.46ms -step:192/1695 train_time:18709ms step_avg:97.44ms -step:193/1695 train_time:18803ms step_avg:97.42ms -step:194/1695 train_time:18898ms step_avg:97.41ms -step:195/1695 train_time:18993ms step_avg:97.40ms -step:196/1695 train_time:19088ms step_avg:97.39ms -step:197/1695 train_time:19184ms step_avg:97.38ms -step:198/1695 train_time:19278ms step_avg:97.37ms -step:199/1695 train_time:19372ms step_avg:97.35ms -step:200/1695 train_time:19466ms step_avg:97.33ms -step:201/1695 train_time:19560ms step_avg:97.31ms -step:202/1695 train_time:19652ms step_avg:97.29ms -step:203/1695 train_time:19747ms step_avg:97.28ms -step:204/1695 train_time:19841ms step_avg:97.26ms -step:205/1695 train_time:19935ms step_avg:97.24ms -step:206/1695 train_time:20030ms step_avg:97.23ms -step:207/1695 train_time:20124ms step_avg:97.22ms -step:208/1695 train_time:20219ms step_avg:97.21ms -step:209/1695 train_time:20312ms step_avg:97.19ms -step:210/1695 train_time:20406ms step_avg:97.17ms -step:211/1695 train_time:20500ms step_avg:97.16ms -step:212/1695 train_time:20593ms step_avg:97.14ms -step:213/1695 train_time:20686ms step_avg:97.12ms -step:214/1695 train_time:20781ms step_avg:97.11ms -step:215/1695 train_time:20874ms step_avg:97.09ms -step:216/1695 train_time:20969ms step_avg:97.08ms -step:217/1695 train_time:21063ms step_avg:97.07ms -step:218/1695 train_time:21158ms step_avg:97.05ms -step:219/1695 train_time:21252ms step_avg:97.04ms -step:220/1695 train_time:21348ms step_avg:97.04ms -step:221/1695 train_time:21442ms step_avg:97.02ms -step:222/1695 train_time:21535ms step_avg:97.01ms -step:223/1695 train_time:21629ms step_avg:96.99ms -step:224/1695 train_time:21723ms step_avg:96.98ms -step:225/1695 train_time:21817ms step_avg:96.97ms -step:226/1695 train_time:21911ms step_avg:96.95ms -step:227/1695 train_time:22005ms step_avg:96.94ms -step:228/1695 train_time:22099ms step_avg:96.93ms -step:229/1695 train_time:22193ms step_avg:96.91ms -step:230/1695 train_time:22287ms step_avg:96.90ms -step:231/1695 train_time:22382ms step_avg:96.89ms -step:232/1695 train_time:22476ms step_avg:96.88ms 
-step:233/1695 train_time:22570ms step_avg:96.87ms -step:234/1695 train_time:22663ms step_avg:96.85ms -step:235/1695 train_time:22756ms step_avg:96.83ms -step:236/1695 train_time:22850ms step_avg:96.82ms -step:237/1695 train_time:22945ms step_avg:96.81ms -step:238/1695 train_time:23040ms step_avg:96.81ms -step:239/1695 train_time:23134ms step_avg:96.79ms -step:240/1695 train_time:23228ms step_avg:96.78ms -step:241/1695 train_time:23323ms step_avg:96.78ms -step:242/1695 train_time:23418ms step_avg:96.77ms -step:243/1695 train_time:23512ms step_avg:96.76ms -step:244/1695 train_time:23606ms step_avg:96.75ms -step:245/1695 train_time:23699ms step_avg:96.73ms -step:246/1695 train_time:23792ms step_avg:96.72ms -step:247/1695 train_time:23886ms step_avg:96.71ms -step:248/1695 train_time:23981ms step_avg:96.70ms -step:249/1695 train_time:24075ms step_avg:96.69ms -step:250/1695 train_time:24169ms step_avg:96.68ms -step:250/1695 val_loss:3.9787 train_time:24262ms step_avg:97.05ms -step:251/1695 train_time:24286ms step_avg:96.76ms -step:252/1695 train_time:24364ms step_avg:96.68ms -step:253/1695 train_time:24461ms step_avg:96.68ms -step:254/1695 train_time:24555ms step_avg:96.67ms -step:255/1695 train_time:24648ms step_avg:96.66ms -step:256/1695 train_time:24742ms step_avg:96.65ms -step:257/1695 train_time:24835ms step_avg:96.63ms -step:258/1695 train_time:24928ms step_avg:96.62ms -step:259/1695 train_time:25021ms step_avg:96.61ms -step:260/1695 train_time:25114ms step_avg:96.59ms -step:261/1695 train_time:25208ms step_avg:96.58ms -step:262/1695 train_time:25304ms step_avg:96.58ms -step:263/1695 train_time:25401ms step_avg:96.58ms -step:264/1695 train_time:25495ms step_avg:96.57ms -step:265/1695 train_time:25590ms step_avg:96.57ms -step:266/1695 train_time:25684ms step_avg:96.56ms -step:267/1695 train_time:25777ms step_avg:96.54ms -step:268/1695 train_time:25870ms step_avg:96.53ms -step:269/1695 train_time:25964ms step_avg:96.52ms -step:270/1695 train_time:26058ms step_avg:96.51ms -step:271/1695 train_time:26151ms step_avg:96.50ms -step:272/1695 train_time:26245ms step_avg:96.49ms -step:273/1695 train_time:26341ms step_avg:96.49ms -step:274/1695 train_time:26436ms step_avg:96.48ms -step:275/1695 train_time:26531ms step_avg:96.48ms -step:276/1695 train_time:26626ms step_avg:96.47ms -step:277/1695 train_time:26720ms step_avg:96.46ms -step:278/1695 train_time:26813ms step_avg:96.45ms -step:279/1695 train_time:26906ms step_avg:96.44ms -step:280/1695 train_time:26999ms step_avg:96.43ms -step:281/1695 train_time:27092ms step_avg:96.41ms -step:282/1695 train_time:27187ms step_avg:96.41ms -step:283/1695 train_time:27282ms step_avg:96.40ms -step:284/1695 train_time:27376ms step_avg:96.40ms -step:285/1695 train_time:27471ms step_avg:96.39ms -step:286/1695 train_time:27566ms step_avg:96.39ms -step:287/1695 train_time:27661ms step_avg:96.38ms -step:288/1695 train_time:27754ms step_avg:96.37ms -step:289/1695 train_time:27849ms step_avg:96.36ms -step:290/1695 train_time:27943ms step_avg:96.35ms -step:291/1695 train_time:28037ms step_avg:96.35ms -step:292/1695 train_time:28130ms step_avg:96.33ms -step:293/1695 train_time:28224ms step_avg:96.33ms -step:294/1695 train_time:28318ms step_avg:96.32ms -step:295/1695 train_time:28412ms step_avg:96.31ms -step:296/1695 train_time:28507ms step_avg:96.31ms -step:297/1695 train_time:28603ms step_avg:96.30ms -step:298/1695 train_time:28697ms step_avg:96.30ms -step:299/1695 train_time:28791ms step_avg:96.29ms -step:300/1695 train_time:28885ms step_avg:96.28ms -step:301/1695 
train_time:28979ms step_avg:96.27ms -step:302/1695 train_time:29072ms step_avg:96.27ms -step:303/1695 train_time:29166ms step_avg:96.26ms -step:304/1695 train_time:29261ms step_avg:96.25ms -step:305/1695 train_time:29355ms step_avg:96.25ms -step:306/1695 train_time:29449ms step_avg:96.24ms -step:307/1695 train_time:29544ms step_avg:96.23ms -step:308/1695 train_time:29638ms step_avg:96.23ms -step:309/1695 train_time:29732ms step_avg:96.22ms -step:310/1695 train_time:29827ms step_avg:96.21ms -step:311/1695 train_time:29921ms step_avg:96.21ms -step:312/1695 train_time:30014ms step_avg:96.20ms -step:313/1695 train_time:30108ms step_avg:96.19ms -step:314/1695 train_time:30203ms step_avg:96.19ms -step:315/1695 train_time:30297ms step_avg:96.18ms -step:316/1695 train_time:30391ms step_avg:96.17ms -step:317/1695 train_time:30486ms step_avg:96.17ms -step:318/1695 train_time:30580ms step_avg:96.16ms -step:319/1695 train_time:30674ms step_avg:96.16ms -step:320/1695 train_time:30769ms step_avg:96.15ms -step:321/1695 train_time:30864ms step_avg:96.15ms -step:322/1695 train_time:30957ms step_avg:96.14ms -step:323/1695 train_time:31050ms step_avg:96.13ms -step:324/1695 train_time:31145ms step_avg:96.13ms -step:325/1695 train_time:31239ms step_avg:96.12ms -step:326/1695 train_time:31333ms step_avg:96.11ms -step:327/1695 train_time:31428ms step_avg:96.11ms -step:328/1695 train_time:31523ms step_avg:96.11ms -step:329/1695 train_time:31616ms step_avg:96.10ms -step:330/1695 train_time:31710ms step_avg:96.09ms -step:331/1695 train_time:31805ms step_avg:96.09ms -step:332/1695 train_time:31899ms step_avg:96.08ms -step:333/1695 train_time:31993ms step_avg:96.07ms -step:334/1695 train_time:32087ms step_avg:96.07ms -step:335/1695 train_time:32180ms step_avg:96.06ms -step:336/1695 train_time:32274ms step_avg:96.05ms -step:337/1695 train_time:32368ms step_avg:96.05ms -step:338/1695 train_time:32468ms step_avg:96.06ms -step:339/1695 train_time:32561ms step_avg:96.05ms -step:340/1695 train_time:32655ms step_avg:96.04ms -step:341/1695 train_time:32748ms step_avg:96.04ms -step:342/1695 train_time:32839ms step_avg:96.02ms -step:343/1695 train_time:32933ms step_avg:96.01ms -step:344/1695 train_time:33027ms step_avg:96.01ms -step:345/1695 train_time:33355ms step_avg:96.68ms -step:346/1695 train_time:33456ms step_avg:96.69ms -step:347/1695 train_time:33548ms step_avg:96.68ms -step:348/1695 train_time:33642ms step_avg:96.67ms -step:349/1695 train_time:33735ms step_avg:96.66ms -step:350/1695 train_time:33828ms step_avg:96.65ms -step:351/1695 train_time:33921ms step_avg:96.64ms -step:352/1695 train_time:34014ms step_avg:96.63ms -step:353/1695 train_time:34106ms step_avg:96.62ms -step:354/1695 train_time:34200ms step_avg:96.61ms -step:355/1695 train_time:34296ms step_avg:96.61ms -step:356/1695 train_time:34393ms step_avg:96.61ms -step:357/1695 train_time:34490ms step_avg:96.61ms -step:358/1695 train_time:34584ms step_avg:96.60ms -step:359/1695 train_time:34678ms step_avg:96.60ms -step:360/1695 train_time:34771ms step_avg:96.59ms -step:361/1695 train_time:34865ms step_avg:96.58ms -step:362/1695 train_time:34959ms step_avg:96.57ms -step:363/1695 train_time:35051ms step_avg:96.56ms -step:364/1695 train_time:35145ms step_avg:96.55ms -step:365/1695 train_time:35238ms step_avg:96.54ms -step:366/1695 train_time:35332ms step_avg:96.54ms -step:367/1695 train_time:35428ms step_avg:96.53ms -step:368/1695 train_time:35523ms step_avg:96.53ms -step:369/1695 train_time:35616ms step_avg:96.52ms -step:370/1695 train_time:35710ms step_avg:96.51ms 
-step:371/1695 train_time:35803ms step_avg:96.51ms -step:372/1695 train_time:35897ms step_avg:96.50ms -step:373/1695 train_time:35990ms step_avg:96.49ms -step:374/1695 train_time:36083ms step_avg:96.48ms -step:375/1695 train_time:36176ms step_avg:96.47ms -step:375/1695 val_loss:3.8232 train_time:36268ms step_avg:96.71ms -step:376/1695 train_time:36292ms step_avg:96.52ms -step:377/1695 train_time:36372ms step_avg:96.48ms -step:378/1695 train_time:36470ms step_avg:96.48ms -step:379/1695 train_time:36565ms step_avg:96.48ms -step:380/1695 train_time:36659ms step_avg:96.47ms -step:381/1695 train_time:36752ms step_avg:96.46ms -step:382/1695 train_time:36845ms step_avg:96.45ms -step:383/1695 train_time:36938ms step_avg:96.44ms -step:384/1695 train_time:37031ms step_avg:96.43ms -step:385/1695 train_time:37124ms step_avg:96.43ms -step:386/1695 train_time:37218ms step_avg:96.42ms -step:387/1695 train_time:37313ms step_avg:96.42ms -step:388/1695 train_time:37410ms step_avg:96.42ms -step:389/1695 train_time:37505ms step_avg:96.41ms -step:390/1695 train_time:37599ms step_avg:96.41ms -step:391/1695 train_time:37693ms step_avg:96.40ms -step:392/1695 train_time:37787ms step_avg:96.40ms -step:393/1695 train_time:37880ms step_avg:96.39ms -step:394/1695 train_time:37973ms step_avg:96.38ms -step:395/1695 train_time:38067ms step_avg:96.37ms -step:396/1695 train_time:38160ms step_avg:96.36ms -step:397/1695 train_time:38254ms step_avg:96.36ms -step:398/1695 train_time:38350ms step_avg:96.36ms -step:399/1695 train_time:38445ms step_avg:96.35ms -step:400/1695 train_time:38540ms step_avg:96.35ms -step:401/1695 train_time:38633ms step_avg:96.34ms -step:402/1695 train_time:38727ms step_avg:96.34ms -step:403/1695 train_time:38822ms step_avg:96.33ms -step:404/1695 train_time:38915ms step_avg:96.32ms -step:405/1695 train_time:39008ms step_avg:96.32ms -step:406/1695 train_time:39101ms step_avg:96.31ms -step:407/1695 train_time:39194ms step_avg:96.30ms -step:408/1695 train_time:39289ms step_avg:96.30ms -step:409/1695 train_time:39384ms step_avg:96.29ms -step:410/1695 train_time:39478ms step_avg:96.29ms -step:411/1695 train_time:39572ms step_avg:96.28ms -step:412/1695 train_time:39667ms step_avg:96.28ms -step:413/1695 train_time:39762ms step_avg:96.28ms -step:414/1695 train_time:39856ms step_avg:96.27ms -step:415/1695 train_time:39949ms step_avg:96.26ms -step:416/1695 train_time:40043ms step_avg:96.26ms -step:417/1695 train_time:40137ms step_avg:96.25ms -step:418/1695 train_time:40230ms step_avg:96.24ms -step:419/1695 train_time:40324ms step_avg:96.24ms -step:420/1695 train_time:40419ms step_avg:96.24ms -step:421/1695 train_time:40513ms step_avg:96.23ms -step:422/1695 train_time:40607ms step_avg:96.22ms -step:423/1695 train_time:40702ms step_avg:96.22ms -step:424/1695 train_time:40795ms step_avg:96.21ms -step:425/1695 train_time:40889ms step_avg:96.21ms -step:426/1695 train_time:40984ms step_avg:96.21ms -step:427/1695 train_time:41077ms step_avg:96.20ms -step:428/1695 train_time:41170ms step_avg:96.19ms -step:429/1695 train_time:41264ms step_avg:96.19ms -step:430/1695 train_time:41359ms step_avg:96.18ms -step:431/1695 train_time:41453ms step_avg:96.18ms -step:432/1695 train_time:41547ms step_avg:96.17ms -step:433/1695 train_time:41641ms step_avg:96.17ms -step:434/1695 train_time:41734ms step_avg:96.16ms -step:435/1695 train_time:41828ms step_avg:96.16ms -step:436/1695 train_time:41922ms step_avg:96.15ms -step:437/1695 train_time:42016ms step_avg:96.15ms -step:438/1695 train_time:42110ms step_avg:96.14ms -step:439/1695 
train_time:42203ms step_avg:96.13ms -step:440/1695 train_time:42297ms step_avg:96.13ms -step:441/1695 train_time:42392ms step_avg:96.13ms -step:442/1695 train_time:42487ms step_avg:96.12ms -step:443/1695 train_time:42583ms step_avg:96.12ms -step:444/1695 train_time:42676ms step_avg:96.12ms -step:445/1695 train_time:42770ms step_avg:96.11ms -step:446/1695 train_time:42864ms step_avg:96.11ms -step:447/1695 train_time:42957ms step_avg:96.10ms -step:448/1695 train_time:43050ms step_avg:96.09ms -step:449/1695 train_time:43144ms step_avg:96.09ms -step:450/1695 train_time:43237ms step_avg:96.08ms -step:451/1695 train_time:43331ms step_avg:96.08ms -step:452/1695 train_time:43426ms step_avg:96.07ms -step:453/1695 train_time:43519ms step_avg:96.07ms -step:454/1695 train_time:43613ms step_avg:96.06ms -step:455/1695 train_time:43707ms step_avg:96.06ms -step:456/1695 train_time:43801ms step_avg:96.06ms -step:457/1695 train_time:43895ms step_avg:96.05ms -step:458/1695 train_time:43990ms step_avg:96.05ms -step:459/1695 train_time:44084ms step_avg:96.04ms -step:460/1695 train_time:44178ms step_avg:96.04ms -step:461/1695 train_time:44271ms step_avg:96.03ms -step:462/1695 train_time:44365ms step_avg:96.03ms -step:463/1695 train_time:44459ms step_avg:96.02ms -step:464/1695 train_time:44553ms step_avg:96.02ms -step:465/1695 train_time:44647ms step_avg:96.02ms -step:466/1695 train_time:44742ms step_avg:96.01ms -step:467/1695 train_time:44836ms step_avg:96.01ms -step:468/1695 train_time:44930ms step_avg:96.00ms -step:469/1695 train_time:45025ms step_avg:96.00ms -step:470/1695 train_time:45119ms step_avg:96.00ms -step:471/1695 train_time:45212ms step_avg:95.99ms -step:472/1695 train_time:45306ms step_avg:95.99ms -step:473/1695 train_time:45401ms step_avg:95.98ms -step:474/1695 train_time:45494ms step_avg:95.98ms -step:475/1695 train_time:45589ms step_avg:95.98ms -step:476/1695 train_time:45683ms step_avg:95.97ms -step:477/1695 train_time:45777ms step_avg:95.97ms -step:478/1695 train_time:45870ms step_avg:95.96ms -step:479/1695 train_time:45964ms step_avg:95.96ms -step:480/1695 train_time:46059ms step_avg:95.96ms -step:481/1695 train_time:46153ms step_avg:95.95ms -step:482/1695 train_time:46247ms step_avg:95.95ms -step:483/1695 train_time:46341ms step_avg:95.94ms -step:484/1695 train_time:46435ms step_avg:95.94ms -step:485/1695 train_time:46529ms step_avg:95.94ms -step:486/1695 train_time:46623ms step_avg:95.93ms -step:487/1695 train_time:46718ms step_avg:95.93ms -step:488/1695 train_time:46811ms step_avg:95.92ms -step:489/1695 train_time:46905ms step_avg:95.92ms -step:490/1695 train_time:46998ms step_avg:95.91ms -step:491/1695 train_time:47092ms step_avg:95.91ms -step:492/1695 train_time:47186ms step_avg:95.91ms -step:493/1695 train_time:47280ms step_avg:95.90ms -step:494/1695 train_time:47374ms step_avg:95.90ms -step:495/1695 train_time:47468ms step_avg:95.89ms -step:496/1695 train_time:47562ms step_avg:95.89ms -step:497/1695 train_time:47657ms step_avg:95.89ms -step:498/1695 train_time:47751ms step_avg:95.88ms -step:499/1695 train_time:47845ms step_avg:95.88ms -step:500/1695 train_time:47938ms step_avg:95.88ms -step:500/1695 val_loss:3.7202 train_time:48030ms step_avg:96.06ms -step:501/1695 train_time:48054ms step_avg:95.92ms -step:502/1695 train_time:48133ms step_avg:95.88ms -step:503/1695 train_time:48232ms step_avg:95.89ms -step:504/1695 train_time:48327ms step_avg:95.89ms -step:505/1695 train_time:48419ms step_avg:95.88ms -step:506/1695 train_time:48513ms step_avg:95.88ms -step:507/1695 train_time:48607ms 
step_avg:95.87ms -step:508/1695 train_time:48699ms step_avg:95.86ms -step:509/1695 train_time:48792ms step_avg:95.86ms -step:510/1695 train_time:48885ms step_avg:95.85ms -step:511/1695 train_time:48979ms step_avg:95.85ms -step:512/1695 train_time:49076ms step_avg:95.85ms -step:513/1695 train_time:49173ms step_avg:95.85ms -step:514/1695 train_time:49268ms step_avg:95.85ms -step:515/1695 train_time:49363ms step_avg:95.85ms -step:516/1695 train_time:49456ms step_avg:95.84ms -step:517/1695 train_time:49549ms step_avg:95.84ms -step:518/1695 train_time:49643ms step_avg:95.84ms -step:519/1695 train_time:49968ms step_avg:96.28ms -step:520/1695 train_time:50168ms step_avg:96.48ms -step:521/1695 train_time:50261ms step_avg:96.47ms -step:522/1695 train_time:50353ms step_avg:96.46ms -step:523/1695 train_time:50446ms step_avg:96.46ms -step:524/1695 train_time:50539ms step_avg:96.45ms -step:525/1695 train_time:50632ms step_avg:96.44ms -step:526/1695 train_time:50725ms step_avg:96.43ms -step:527/1695 train_time:50817ms step_avg:96.43ms -step:528/1695 train_time:50910ms step_avg:96.42ms -step:529/1695 train_time:51008ms step_avg:96.42ms -step:530/1695 train_time:51106ms step_avg:96.43ms -step:531/1695 train_time:51202ms step_avg:96.43ms -step:532/1695 train_time:51296ms step_avg:96.42ms -step:533/1695 train_time:51389ms step_avg:96.41ms -step:534/1695 train_time:51482ms step_avg:96.41ms -step:535/1695 train_time:51575ms step_avg:96.40ms -step:536/1695 train_time:51668ms step_avg:96.40ms -step:537/1695 train_time:51761ms step_avg:96.39ms -step:538/1695 train_time:51854ms step_avg:96.38ms -step:539/1695 train_time:51949ms step_avg:96.38ms -step:540/1695 train_time:52044ms step_avg:96.38ms -step:541/1695 train_time:52139ms step_avg:96.38ms -step:542/1695 train_time:52234ms step_avg:96.37ms -step:543/1695 train_time:52328ms step_avg:96.37ms -step:544/1695 train_time:52421ms step_avg:96.36ms -step:545/1695 train_time:52515ms step_avg:96.36ms -step:546/1695 train_time:52609ms step_avg:96.35ms -step:547/1695 train_time:52702ms step_avg:96.35ms -step:548/1695 train_time:52795ms step_avg:96.34ms -step:549/1695 train_time:52889ms step_avg:96.34ms -step:550/1695 train_time:52983ms step_avg:96.33ms -step:551/1695 train_time:53077ms step_avg:96.33ms -step:552/1695 train_time:53172ms step_avg:96.33ms -step:553/1695 train_time:53267ms step_avg:96.32ms -step:554/1695 train_time:53361ms step_avg:96.32ms -step:555/1695 train_time:53454ms step_avg:96.31ms -step:556/1695 train_time:53548ms step_avg:96.31ms -step:557/1695 train_time:53640ms step_avg:96.30ms -step:558/1695 train_time:53734ms step_avg:96.30ms -step:559/1695 train_time:53827ms step_avg:96.29ms -step:560/1695 train_time:53920ms step_avg:96.29ms -step:561/1695 train_time:54015ms step_avg:96.28ms -step:562/1695 train_time:54111ms step_avg:96.28ms -step:563/1695 train_time:54207ms step_avg:96.28ms -step:564/1695 train_time:54301ms step_avg:96.28ms -step:565/1695 train_time:54395ms step_avg:96.27ms -step:566/1695 train_time:54488ms step_avg:96.27ms -step:567/1695 train_time:54583ms step_avg:96.27ms -step:568/1695 train_time:54678ms step_avg:96.26ms -step:569/1695 train_time:54774ms step_avg:96.26ms -step:570/1695 train_time:54869ms step_avg:96.26ms -step:571/1695 train_time:54965ms step_avg:96.26ms -step:572/1695 train_time:55061ms step_avg:96.26ms -step:573/1695 train_time:55157ms step_avg:96.26ms -step:574/1695 train_time:55253ms step_avg:96.26ms -step:575/1695 train_time:55349ms step_avg:96.26ms -step:576/1695 train_time:55445ms step_avg:96.26ms -step:577/1695 
train_time:55540ms step_avg:96.26ms -step:578/1695 train_time:55636ms step_avg:96.26ms -step:579/1695 train_time:55732ms step_avg:96.26ms -step:580/1695 train_time:55828ms step_avg:96.25ms -step:581/1695 train_time:55923ms step_avg:96.25ms -step:582/1695 train_time:56018ms step_avg:96.25ms -step:583/1695 train_time:56116ms step_avg:96.25ms -step:584/1695 train_time:56213ms step_avg:96.25ms -step:585/1695 train_time:56310ms step_avg:96.26ms -step:586/1695 train_time:56407ms step_avg:96.26ms -step:587/1695 train_time:56502ms step_avg:96.26ms -step:588/1695 train_time:56597ms step_avg:96.25ms -step:589/1695 train_time:56693ms step_avg:96.25ms -step:590/1695 train_time:56789ms step_avg:96.25ms -step:591/1695 train_time:56885ms step_avg:96.25ms -step:592/1695 train_time:56981ms step_avg:96.25ms -step:593/1695 train_time:57077ms step_avg:96.25ms -step:594/1695 train_time:57173ms step_avg:96.25ms -step:595/1695 train_time:57270ms step_avg:96.25ms -step:596/1695 train_time:57365ms step_avg:96.25ms -step:597/1695 train_time:57461ms step_avg:96.25ms -step:598/1695 train_time:57556ms step_avg:96.25ms -step:599/1695 train_time:57653ms step_avg:96.25ms -step:600/1695 train_time:57749ms step_avg:96.25ms -step:601/1695 train_time:57845ms step_avg:96.25ms -step:602/1695 train_time:57941ms step_avg:96.25ms -step:603/1695 train_time:58037ms step_avg:96.25ms -step:604/1695 train_time:58134ms step_avg:96.25ms -step:605/1695 train_time:58230ms step_avg:96.25ms -step:606/1695 train_time:58326ms step_avg:96.25ms -step:607/1695 train_time:58421ms step_avg:96.25ms -step:608/1695 train_time:58518ms step_avg:96.25ms -step:609/1695 train_time:58614ms step_avg:96.25ms -step:610/1695 train_time:58711ms step_avg:96.25ms -step:611/1695 train_time:58808ms step_avg:96.25ms -step:612/1695 train_time:58904ms step_avg:96.25ms -step:613/1695 train_time:59000ms step_avg:96.25ms -step:614/1695 train_time:59095ms step_avg:96.25ms -step:615/1695 train_time:59192ms step_avg:96.25ms -step:616/1695 train_time:59289ms step_avg:96.25ms -step:617/1695 train_time:59385ms step_avg:96.25ms -step:618/1695 train_time:59481ms step_avg:96.25ms -step:619/1695 train_time:59576ms step_avg:96.25ms -step:620/1695 train_time:59673ms step_avg:96.25ms -step:621/1695 train_time:59769ms step_avg:96.25ms -step:622/1695 train_time:59865ms step_avg:96.25ms -step:623/1695 train_time:59960ms step_avg:96.24ms -step:624/1695 train_time:60056ms step_avg:96.24ms -step:625/1695 train_time:60152ms step_avg:96.24ms -step:625/1695 val_loss:3.6216 train_time:60246ms step_avg:96.39ms -step:626/1695 train_time:60272ms step_avg:96.28ms -step:627/1695 train_time:60351ms step_avg:96.25ms -step:628/1695 train_time:60448ms step_avg:96.25ms -step:629/1695 train_time:60544ms step_avg:96.25ms -step:630/1695 train_time:60638ms step_avg:96.25ms -step:631/1695 train_time:60733ms step_avg:96.25ms -step:632/1695 train_time:60828ms step_avg:96.25ms -step:633/1695 train_time:60923ms step_avg:96.25ms -step:634/1695 train_time:61018ms step_avg:96.24ms -step:635/1695 train_time:61112ms step_avg:96.24ms -step:636/1695 train_time:61210ms step_avg:96.24ms -step:637/1695 train_time:61308ms step_avg:96.25ms -step:638/1695 train_time:61406ms step_avg:96.25ms -step:639/1695 train_time:61503ms step_avg:96.25ms -step:640/1695 train_time:61598ms step_avg:96.25ms -step:641/1695 train_time:61693ms step_avg:96.24ms -step:642/1695 train_time:61788ms step_avg:96.24ms -step:643/1695 train_time:61883ms step_avg:96.24ms -step:644/1695 train_time:61977ms step_avg:96.24ms -step:645/1695 train_time:62072ms 
step_avg:96.24ms -step:646/1695 train_time:62168ms step_avg:96.23ms -step:647/1695 train_time:62264ms step_avg:96.24ms -step:648/1695 train_time:62361ms step_avg:96.24ms -step:649/1695 train_time:62458ms step_avg:96.24ms -step:650/1695 train_time:62554ms step_avg:96.24ms -step:651/1695 train_time:62651ms step_avg:96.24ms -step:652/1695 train_time:62747ms step_avg:96.24ms -step:653/1695 train_time:62843ms step_avg:96.24ms -step:654/1695 train_time:62937ms step_avg:96.23ms -step:655/1695 train_time:63034ms step_avg:96.23ms -step:656/1695 train_time:63130ms step_avg:96.24ms -step:657/1695 train_time:63228ms step_avg:96.24ms -step:658/1695 train_time:63324ms step_avg:96.24ms -step:659/1695 train_time:63420ms step_avg:96.24ms -step:660/1695 train_time:63515ms step_avg:96.24ms -step:661/1695 train_time:63612ms step_avg:96.24ms -step:662/1695 train_time:63707ms step_avg:96.23ms -step:663/1695 train_time:63803ms step_avg:96.23ms -step:664/1695 train_time:63897ms step_avg:96.23ms -step:665/1695 train_time:63993ms step_avg:96.23ms -step:666/1695 train_time:64089ms step_avg:96.23ms -step:667/1695 train_time:64186ms step_avg:96.23ms -step:668/1695 train_time:64282ms step_avg:96.23ms -step:669/1695 train_time:64377ms step_avg:96.23ms -step:670/1695 train_time:64473ms step_avg:96.23ms -step:671/1695 train_time:64570ms step_avg:96.23ms -step:672/1695 train_time:64666ms step_avg:96.23ms -step:673/1695 train_time:64762ms step_avg:96.23ms -step:674/1695 train_time:64857ms step_avg:96.23ms -step:675/1695 train_time:64954ms step_avg:96.23ms -step:676/1695 train_time:65048ms step_avg:96.23ms -step:677/1695 train_time:65144ms step_avg:96.22ms -step:678/1695 train_time:65238ms step_avg:96.22ms -step:679/1695 train_time:65335ms step_avg:96.22ms -step:680/1695 train_time:65431ms step_avg:96.22ms -step:681/1695 train_time:65528ms step_avg:96.22ms -step:682/1695 train_time:65624ms step_avg:96.22ms -step:683/1695 train_time:65720ms step_avg:96.22ms -step:684/1695 train_time:65815ms step_avg:96.22ms -step:685/1695 train_time:65911ms step_avg:96.22ms -step:686/1695 train_time:66007ms step_avg:96.22ms -step:687/1695 train_time:66103ms step_avg:96.22ms -step:688/1695 train_time:66199ms step_avg:96.22ms -step:689/1695 train_time:66294ms step_avg:96.22ms -step:690/1695 train_time:66390ms step_avg:96.22ms -step:691/1695 train_time:66847ms step_avg:96.74ms -step:692/1695 train_time:66917ms step_avg:96.70ms -step:693/1695 train_time:67011ms step_avg:96.70ms -step:694/1695 train_time:67106ms step_avg:96.69ms -step:695/1695 train_time:67201ms step_avg:96.69ms -step:696/1695 train_time:67296ms step_avg:96.69ms -step:697/1695 train_time:67392ms step_avg:96.69ms -step:698/1695 train_time:67487ms step_avg:96.69ms -step:699/1695 train_time:67581ms step_avg:96.68ms -step:700/1695 train_time:67676ms step_avg:96.68ms -step:701/1695 train_time:67776ms step_avg:96.69ms -step:702/1695 train_time:67879ms step_avg:96.69ms -step:703/1695 train_time:67976ms step_avg:96.69ms -step:704/1695 train_time:68073ms step_avg:96.69ms -step:705/1695 train_time:68169ms step_avg:96.69ms -step:706/1695 train_time:68266ms step_avg:96.69ms -step:707/1695 train_time:68361ms step_avg:96.69ms -step:708/1695 train_time:68455ms step_avg:96.69ms -step:709/1695 train_time:68550ms step_avg:96.69ms -step:710/1695 train_time:68645ms step_avg:96.68ms -step:711/1695 train_time:68741ms step_avg:96.68ms -step:712/1695 train_time:68838ms step_avg:96.68ms -step:713/1695 train_time:68936ms step_avg:96.68ms -step:714/1695 train_time:69034ms step_avg:96.69ms -step:715/1695 
train_time:69130ms step_avg:96.69ms -step:716/1695 train_time:69226ms step_avg:96.68ms -step:717/1695 train_time:69321ms step_avg:96.68ms -step:718/1695 train_time:69416ms step_avg:96.68ms -step:719/1695 train_time:69512ms step_avg:96.68ms -step:720/1695 train_time:69607ms step_avg:96.68ms -step:721/1695 train_time:69702ms step_avg:96.67ms -step:722/1695 train_time:69799ms step_avg:96.67ms -step:723/1695 train_time:69895ms step_avg:96.67ms -step:724/1695 train_time:69992ms step_avg:96.67ms -step:725/1695 train_time:70088ms step_avg:96.67ms -step:726/1695 train_time:70185ms step_avg:96.67ms -step:727/1695 train_time:70280ms step_avg:96.67ms -step:728/1695 train_time:70375ms step_avg:96.67ms -step:729/1695 train_time:70471ms step_avg:96.67ms -step:730/1695 train_time:70566ms step_avg:96.67ms -step:731/1695 train_time:70661ms step_avg:96.66ms -step:732/1695 train_time:70757ms step_avg:96.66ms -step:733/1695 train_time:70853ms step_avg:96.66ms -step:734/1695 train_time:70951ms step_avg:96.66ms -step:735/1695 train_time:71049ms step_avg:96.66ms -step:736/1695 train_time:71145ms step_avg:96.66ms -step:737/1695 train_time:71240ms step_avg:96.66ms -step:738/1695 train_time:71336ms step_avg:96.66ms -step:739/1695 train_time:71432ms step_avg:96.66ms -step:740/1695 train_time:71528ms step_avg:96.66ms -step:741/1695 train_time:71623ms step_avg:96.66ms -step:742/1695 train_time:71719ms step_avg:96.66ms -step:743/1695 train_time:71814ms step_avg:96.65ms -step:744/1695 train_time:71910ms step_avg:96.65ms -step:745/1695 train_time:72007ms step_avg:96.65ms -step:746/1695 train_time:72103ms step_avg:96.65ms -step:747/1695 train_time:72199ms step_avg:96.65ms -step:748/1695 train_time:72295ms step_avg:96.65ms -step:749/1695 train_time:72391ms step_avg:96.65ms -step:750/1695 train_time:72488ms step_avg:96.65ms -step:750/1695 val_loss:3.5671 train_time:72581ms step_avg:96.77ms -step:751/1695 train_time:72608ms step_avg:96.68ms -step:752/1695 train_time:72687ms step_avg:96.66ms -step:753/1695 train_time:72784ms step_avg:96.66ms -step:754/1695 train_time:72880ms step_avg:96.66ms -step:755/1695 train_time:72976ms step_avg:96.66ms -step:756/1695 train_time:73072ms step_avg:96.66ms -step:757/1695 train_time:73167ms step_avg:96.65ms -step:758/1695 train_time:73262ms step_avg:96.65ms -step:759/1695 train_time:73357ms step_avg:96.65ms -step:760/1695 train_time:73452ms step_avg:96.65ms -step:761/1695 train_time:73549ms step_avg:96.65ms -step:762/1695 train_time:73646ms step_avg:96.65ms -step:763/1695 train_time:73744ms step_avg:96.65ms -step:764/1695 train_time:73841ms step_avg:96.65ms -step:765/1695 train_time:73938ms step_avg:96.65ms -step:766/1695 train_time:74034ms step_avg:96.65ms -step:767/1695 train_time:74129ms step_avg:96.65ms -step:768/1695 train_time:74224ms step_avg:96.65ms -step:769/1695 train_time:74319ms step_avg:96.64ms -step:770/1695 train_time:74414ms step_avg:96.64ms -step:771/1695 train_time:74509ms step_avg:96.64ms -step:772/1695 train_time:74606ms step_avg:96.64ms -step:773/1695 train_time:74703ms step_avg:96.64ms -step:774/1695 train_time:74800ms step_avg:96.64ms -step:775/1695 train_time:74897ms step_avg:96.64ms -step:776/1695 train_time:74994ms step_avg:96.64ms -step:777/1695 train_time:75090ms step_avg:96.64ms -step:778/1695 train_time:75185ms step_avg:96.64ms -step:779/1695 train_time:75280ms step_avg:96.64ms -step:780/1695 train_time:75375ms step_avg:96.63ms -step:781/1695 train_time:75472ms step_avg:96.63ms -step:782/1695 train_time:75567ms step_avg:96.63ms -step:783/1695 train_time:75663ms 
step_avg:96.63ms -step:784/1695 train_time:75760ms step_avg:96.63ms -step:785/1695 train_time:75857ms step_avg:96.63ms -step:786/1695 train_time:75953ms step_avg:96.63ms -step:787/1695 train_time:76049ms step_avg:96.63ms -step:788/1695 train_time:76144ms step_avg:96.63ms -step:789/1695 train_time:76239ms step_avg:96.63ms -step:790/1695 train_time:76335ms step_avg:96.63ms -step:791/1695 train_time:76433ms step_avg:96.63ms -step:792/1695 train_time:76529ms step_avg:96.63ms -step:793/1695 train_time:76624ms step_avg:96.63ms -step:794/1695 train_time:76721ms step_avg:96.63ms -step:795/1695 train_time:76818ms step_avg:96.63ms -step:796/1695 train_time:76916ms step_avg:96.63ms -step:797/1695 train_time:77012ms step_avg:96.63ms -step:798/1695 train_time:77108ms step_avg:96.63ms -step:799/1695 train_time:77202ms step_avg:96.62ms -step:800/1695 train_time:77298ms step_avg:96.62ms -step:801/1695 train_time:77394ms step_avg:96.62ms -step:802/1695 train_time:77490ms step_avg:96.62ms -step:803/1695 train_time:77585ms step_avg:96.62ms -step:804/1695 train_time:77681ms step_avg:96.62ms -step:805/1695 train_time:77778ms step_avg:96.62ms -step:806/1695 train_time:77874ms step_avg:96.62ms -step:807/1695 train_time:77971ms step_avg:96.62ms -step:808/1695 train_time:78067ms step_avg:96.62ms -step:809/1695 train_time:78163ms step_avg:96.62ms -step:810/1695 train_time:78258ms step_avg:96.62ms -step:811/1695 train_time:78355ms step_avg:96.61ms -step:812/1695 train_time:78450ms step_avg:96.61ms -step:813/1695 train_time:78545ms step_avg:96.61ms -step:814/1695 train_time:78641ms step_avg:96.61ms -step:815/1695 train_time:78738ms step_avg:96.61ms -step:816/1695 train_time:78835ms step_avg:96.61ms -step:817/1695 train_time:78932ms step_avg:96.61ms -step:818/1695 train_time:79028ms step_avg:96.61ms -step:819/1695 train_time:79123ms step_avg:96.61ms -step:820/1695 train_time:79219ms step_avg:96.61ms -step:821/1695 train_time:79316ms step_avg:96.61ms -step:822/1695 train_time:79412ms step_avg:96.61ms -step:823/1695 train_time:79508ms step_avg:96.61ms -step:824/1695 train_time:79604ms step_avg:96.61ms -step:825/1695 train_time:79700ms step_avg:96.61ms -step:826/1695 train_time:79796ms step_avg:96.60ms -step:827/1695 train_time:79892ms step_avg:96.61ms -step:828/1695 train_time:79989ms step_avg:96.61ms -step:829/1695 train_time:80084ms step_avg:96.60ms -step:830/1695 train_time:80180ms step_avg:96.60ms -step:831/1695 train_time:80276ms step_avg:96.60ms -step:832/1695 train_time:80373ms step_avg:96.60ms -step:833/1695 train_time:80469ms step_avg:96.60ms -step:834/1695 train_time:80565ms step_avg:96.60ms -step:835/1695 train_time:80660ms step_avg:96.60ms -step:836/1695 train_time:80756ms step_avg:96.60ms -step:837/1695 train_time:80853ms step_avg:96.60ms -step:838/1695 train_time:80949ms step_avg:96.60ms -step:839/1695 train_time:81045ms step_avg:96.60ms -step:840/1695 train_time:81141ms step_avg:96.60ms -step:841/1695 train_time:81237ms step_avg:96.60ms -step:842/1695 train_time:81333ms step_avg:96.59ms -step:843/1695 train_time:81429ms step_avg:96.59ms -step:844/1695 train_time:81524ms step_avg:96.59ms -step:845/1695 train_time:81619ms step_avg:96.59ms -step:846/1695 train_time:81716ms step_avg:96.59ms -step:847/1695 train_time:81812ms step_avg:96.59ms -step:848/1695 train_time:81908ms step_avg:96.59ms -step:849/1695 train_time:82003ms step_avg:96.59ms -step:850/1695 train_time:82099ms step_avg:96.59ms -step:851/1695 train_time:82196ms step_avg:96.59ms -step:852/1695 train_time:82292ms step_avg:96.59ms -step:853/1695 
train_time:82388ms step_avg:96.59ms -step:854/1695 train_time:82484ms step_avg:96.59ms -step:855/1695 train_time:82579ms step_avg:96.58ms -step:856/1695 train_time:82676ms step_avg:96.58ms -step:857/1695 train_time:82773ms step_avg:96.58ms -step:858/1695 train_time:82869ms step_avg:96.58ms -step:859/1695 train_time:82964ms step_avg:96.58ms -step:860/1695 train_time:83060ms step_avg:96.58ms -step:861/1695 train_time:83156ms step_avg:96.58ms -step:862/1695 train_time:83252ms step_avg:96.58ms -step:863/1695 train_time:83584ms step_avg:96.85ms -step:864/1695 train_time:83778ms step_avg:96.96ms -step:865/1695 train_time:83872ms step_avg:96.96ms -step:866/1695 train_time:83966ms step_avg:96.96ms -step:867/1695 train_time:84061ms step_avg:96.96ms -step:868/1695 train_time:84156ms step_avg:96.95ms -step:869/1695 train_time:84250ms step_avg:96.95ms -step:870/1695 train_time:84345ms step_avg:96.95ms -step:871/1695 train_time:84440ms step_avg:96.95ms -step:872/1695 train_time:84535ms step_avg:96.94ms -step:873/1695 train_time:84637ms step_avg:96.95ms -step:874/1695 train_time:84737ms step_avg:96.95ms -step:875/1695 train_time:84837ms step_avg:96.96ms -step:875/1695 val_loss:3.5244 train_time:84933ms step_avg:97.07ms -step:876/1695 train_time:84957ms step_avg:96.98ms -step:877/1695 train_time:85038ms step_avg:96.96ms -step:878/1695 train_time:85138ms step_avg:96.97ms -step:879/1695 train_time:85235ms step_avg:96.97ms -step:880/1695 train_time:85330ms step_avg:96.97ms -step:881/1695 train_time:85425ms step_avg:96.96ms -step:882/1695 train_time:85520ms step_avg:96.96ms -step:883/1695 train_time:85614ms step_avg:96.96ms -step:884/1695 train_time:85709ms step_avg:96.96ms -step:885/1695 train_time:85804ms step_avg:96.95ms -step:886/1695 train_time:85901ms step_avg:96.95ms -step:887/1695 train_time:86000ms step_avg:96.96ms -step:888/1695 train_time:86100ms step_avg:96.96ms -step:889/1695 train_time:86198ms step_avg:96.96ms -step:890/1695 train_time:86295ms step_avg:96.96ms -step:891/1695 train_time:86391ms step_avg:96.96ms -step:892/1695 train_time:86486ms step_avg:96.96ms -step:893/1695 train_time:86581ms step_avg:96.96ms -step:894/1695 train_time:86677ms step_avg:96.95ms -step:895/1695 train_time:86772ms step_avg:96.95ms -step:896/1695 train_time:86868ms step_avg:96.95ms -step:897/1695 train_time:86964ms step_avg:96.95ms -step:898/1695 train_time:87062ms step_avg:96.95ms -step:899/1695 train_time:87159ms step_avg:96.95ms -step:900/1695 train_time:87256ms step_avg:96.95ms -step:901/1695 train_time:87351ms step_avg:96.95ms -step:902/1695 train_time:87447ms step_avg:96.95ms -step:903/1695 train_time:87542ms step_avg:96.95ms -step:904/1695 train_time:87638ms step_avg:96.95ms -step:905/1695 train_time:87734ms step_avg:96.94ms -step:906/1695 train_time:87829ms step_avg:96.94ms -step:907/1695 train_time:87925ms step_avg:96.94ms -step:908/1695 train_time:88021ms step_avg:96.94ms -step:909/1695 train_time:88118ms step_avg:96.94ms -step:910/1695 train_time:88216ms step_avg:96.94ms -step:911/1695 train_time:88312ms step_avg:96.94ms -step:912/1695 train_time:88407ms step_avg:96.94ms -step:913/1695 train_time:88503ms step_avg:96.94ms -step:914/1695 train_time:88599ms step_avg:96.94ms -step:915/1695 train_time:88695ms step_avg:96.93ms -step:916/1695 train_time:88791ms step_avg:96.93ms -step:917/1695 train_time:88887ms step_avg:96.93ms -step:918/1695 train_time:88983ms step_avg:96.93ms -step:919/1695 train_time:89080ms step_avg:96.93ms -step:920/1695 train_time:89177ms step_avg:96.93ms -step:921/1695 train_time:89274ms 
step_avg:96.93ms -step:922/1695 train_time:89370ms step_avg:96.93ms -step:923/1695 train_time:89467ms step_avg:96.93ms -step:924/1695 train_time:89562ms step_avg:96.93ms -step:925/1695 train_time:89658ms step_avg:96.93ms -step:926/1695 train_time:89752ms step_avg:96.92ms -step:927/1695 train_time:89848ms step_avg:96.92ms -step:928/1695 train_time:89944ms step_avg:96.92ms -step:929/1695 train_time:90041ms step_avg:96.92ms -step:930/1695 train_time:90138ms step_avg:96.92ms -step:931/1695 train_time:90235ms step_avg:96.92ms -step:932/1695 train_time:90332ms step_avg:96.92ms -step:933/1695 train_time:90428ms step_avg:96.92ms -step:934/1695 train_time:90523ms step_avg:96.92ms -step:935/1695 train_time:90620ms step_avg:96.92ms -step:936/1695 train_time:90717ms step_avg:96.92ms -step:937/1695 train_time:90813ms step_avg:96.92ms -step:938/1695 train_time:90908ms step_avg:96.92ms -step:939/1695 train_time:91003ms step_avg:96.92ms -step:940/1695 train_time:91099ms step_avg:96.91ms -step:941/1695 train_time:91195ms step_avg:96.91ms -step:942/1695 train_time:91291ms step_avg:96.91ms -step:943/1695 train_time:91386ms step_avg:96.91ms -step:944/1695 train_time:91482ms step_avg:96.91ms -step:945/1695 train_time:91578ms step_avg:96.91ms -step:946/1695 train_time:91675ms step_avg:96.91ms -step:947/1695 train_time:91771ms step_avg:96.91ms -step:948/1695 train_time:91866ms step_avg:96.91ms -step:949/1695 train_time:91962ms step_avg:96.90ms -step:950/1695 train_time:92057ms step_avg:96.90ms -step:951/1695 train_time:92153ms step_avg:96.90ms -step:952/1695 train_time:92249ms step_avg:96.90ms -step:953/1695 train_time:92345ms step_avg:96.90ms -step:954/1695 train_time:92441ms step_avg:96.90ms -step:955/1695 train_time:92538ms step_avg:96.90ms -step:956/1695 train_time:92636ms step_avg:96.90ms -step:957/1695 train_time:92732ms step_avg:96.90ms -step:958/1695 train_time:92827ms step_avg:96.90ms -step:959/1695 train_time:92923ms step_avg:96.90ms -step:960/1695 train_time:93018ms step_avg:96.89ms -step:961/1695 train_time:93114ms step_avg:96.89ms -step:962/1695 train_time:93210ms step_avg:96.89ms -step:963/1695 train_time:93305ms step_avg:96.89ms -step:964/1695 train_time:93401ms step_avg:96.89ms -step:965/1695 train_time:93496ms step_avg:96.89ms -step:966/1695 train_time:93592ms step_avg:96.89ms -step:967/1695 train_time:93688ms step_avg:96.88ms -step:968/1695 train_time:93783ms step_avg:96.88ms -step:969/1695 train_time:93880ms step_avg:96.88ms -step:970/1695 train_time:93976ms step_avg:96.88ms -step:971/1695 train_time:94073ms step_avg:96.88ms -step:972/1695 train_time:94168ms step_avg:96.88ms -step:973/1695 train_time:94263ms step_avg:96.88ms -step:974/1695 train_time:94360ms step_avg:96.88ms -step:975/1695 train_time:94456ms step_avg:96.88ms -step:976/1695 train_time:94553ms step_avg:96.88ms -step:977/1695 train_time:94648ms step_avg:96.88ms -step:978/1695 train_time:94743ms step_avg:96.87ms -step:979/1695 train_time:94840ms step_avg:96.87ms -step:980/1695 train_time:94937ms step_avg:96.87ms -step:981/1695 train_time:95033ms step_avg:96.87ms -step:982/1695 train_time:95129ms step_avg:96.87ms -step:983/1695 train_time:95224ms step_avg:96.87ms -step:984/1695 train_time:95320ms step_avg:96.87ms -step:985/1695 train_time:95416ms step_avg:96.87ms -step:986/1695 train_time:95512ms step_avg:96.87ms -step:987/1695 train_time:95608ms step_avg:96.87ms -step:988/1695 train_time:95703ms step_avg:96.87ms -step:989/1695 train_time:95799ms step_avg:96.86ms -step:990/1695 train_time:95895ms step_avg:96.86ms -step:991/1695 
train_time:95991ms step_avg:96.86ms -step:992/1695 train_time:96087ms step_avg:96.86ms -step:993/1695 train_time:96183ms step_avg:96.86ms -step:994/1695 train_time:96280ms step_avg:96.86ms -step:995/1695 train_time:96375ms step_avg:96.86ms -step:996/1695 train_time:96471ms step_avg:96.86ms -step:997/1695 train_time:96566ms step_avg:96.86ms -step:998/1695 train_time:96663ms step_avg:96.86ms -step:999/1695 train_time:96758ms step_avg:96.85ms -step:1000/1695 train_time:96854ms step_avg:96.85ms -step:1000/1695 val_loss:3.4839 train_time:96948ms step_avg:96.95ms -step:1001/1695 train_time:96972ms step_avg:96.88ms -step:1002/1695 train_time:97055ms step_avg:96.86ms -step:1003/1695 train_time:97153ms step_avg:96.86ms -step:1004/1695 train_time:97250ms step_avg:96.86ms -step:1005/1695 train_time:97345ms step_avg:96.86ms -step:1006/1695 train_time:97440ms step_avg:96.86ms -step:1007/1695 train_time:97534ms step_avg:96.86ms -step:1008/1695 train_time:97629ms step_avg:96.85ms -step:1009/1695 train_time:97724ms step_avg:96.85ms -step:1010/1695 train_time:97819ms step_avg:96.85ms -step:1011/1695 train_time:97914ms step_avg:96.85ms -step:1012/1695 train_time:98013ms step_avg:96.85ms -step:1013/1695 train_time:98112ms step_avg:96.85ms -step:1014/1695 train_time:98211ms step_avg:96.86ms -step:1015/1695 train_time:98309ms step_avg:96.86ms -step:1016/1695 train_time:98406ms step_avg:96.86ms -step:1017/1695 train_time:98501ms step_avg:96.85ms -step:1018/1695 train_time:98596ms step_avg:96.85ms -step:1019/1695 train_time:98690ms step_avg:96.85ms -step:1020/1695 train_time:98786ms step_avg:96.85ms -step:1021/1695 train_time:98883ms step_avg:96.85ms -step:1022/1695 train_time:98979ms step_avg:96.85ms -step:1023/1695 train_time:99077ms step_avg:96.85ms -step:1024/1695 train_time:99173ms step_avg:96.85ms -step:1025/1695 train_time:99270ms step_avg:96.85ms -step:1026/1695 train_time:99366ms step_avg:96.85ms -step:1027/1695 train_time:99462ms step_avg:96.85ms -step:1028/1695 train_time:99556ms step_avg:96.84ms -step:1029/1695 train_time:99651ms step_avg:96.84ms -step:1030/1695 train_time:99746ms step_avg:96.84ms -step:1031/1695 train_time:99842ms step_avg:96.84ms -step:1032/1695 train_time:99938ms step_avg:96.84ms -step:1033/1695 train_time:100035ms step_avg:96.84ms -step:1034/1695 train_time:100131ms step_avg:96.84ms -step:1035/1695 train_time:100228ms step_avg:96.84ms -step:1036/1695 train_time:100552ms step_avg:97.06ms -step:1037/1695 train_time:100740ms step_avg:97.15ms -step:1038/1695 train_time:100833ms step_avg:97.14ms -step:1039/1695 train_time:100928ms step_avg:97.14ms -step:1040/1695 train_time:101024ms step_avg:97.14ms -step:1041/1695 train_time:101118ms step_avg:97.14ms -step:1042/1695 train_time:101212ms step_avg:97.13ms -step:1043/1695 train_time:101307ms step_avg:97.13ms -step:1044/1695 train_time:101402ms step_avg:97.13ms -step:1045/1695 train_time:101497ms step_avg:97.13ms -step:1046/1695 train_time:101598ms step_avg:97.13ms -step:1047/1695 train_time:101696ms step_avg:97.13ms -step:1048/1695 train_time:101794ms step_avg:97.13ms -step:1049/1695 train_time:101890ms step_avg:97.13ms -step:1050/1695 train_time:101986ms step_avg:97.13ms -step:1051/1695 train_time:102083ms step_avg:97.13ms -step:1052/1695 train_time:102178ms step_avg:97.13ms -step:1053/1695 train_time:102272ms step_avg:97.12ms -step:1054/1695 train_time:102367ms step_avg:97.12ms -step:1055/1695 train_time:102462ms step_avg:97.12ms -step:1056/1695 train_time:102560ms step_avg:97.12ms -step:1057/1695 train_time:102656ms step_avg:97.12ms 
-step:1058/1695 train_time:102753ms step_avg:97.12ms -step:1059/1695 train_time:102850ms step_avg:97.12ms -step:1060/1695 train_time:102947ms step_avg:97.12ms -step:1061/1695 train_time:103044ms step_avg:97.12ms -step:1062/1695 train_time:103140ms step_avg:97.12ms -step:1063/1695 train_time:103235ms step_avg:97.12ms -step:1064/1695 train_time:103330ms step_avg:97.11ms -step:1065/1695 train_time:103425ms step_avg:97.11ms -step:1066/1695 train_time:103521ms step_avg:97.11ms -step:1067/1695 train_time:103617ms step_avg:97.11ms -step:1068/1695 train_time:103714ms step_avg:97.11ms -step:1069/1695 train_time:103810ms step_avg:97.11ms -step:1070/1695 train_time:103907ms step_avg:97.11ms -step:1071/1695 train_time:104004ms step_avg:97.11ms -step:1072/1695 train_time:104100ms step_avg:97.11ms -step:1073/1695 train_time:104195ms step_avg:97.11ms -step:1074/1695 train_time:104289ms step_avg:97.10ms -step:1075/1695 train_time:104385ms step_avg:97.10ms -step:1076/1695 train_time:104481ms step_avg:97.10ms -step:1077/1695 train_time:104578ms step_avg:97.10ms -step:1078/1695 train_time:104673ms step_avg:97.10ms -step:1079/1695 train_time:104769ms step_avg:97.10ms -step:1080/1695 train_time:104865ms step_avg:97.10ms -step:1081/1695 train_time:104961ms step_avg:97.10ms -step:1082/1695 train_time:105056ms step_avg:97.09ms -step:1083/1695 train_time:105152ms step_avg:97.09ms -step:1084/1695 train_time:105248ms step_avg:97.09ms -step:1085/1695 train_time:105344ms step_avg:97.09ms -step:1086/1695 train_time:105439ms step_avg:97.09ms -step:1087/1695 train_time:105535ms step_avg:97.09ms -step:1088/1695 train_time:105631ms step_avg:97.09ms -step:1089/1695 train_time:105729ms step_avg:97.09ms -step:1090/1695 train_time:105825ms step_avg:97.09ms -step:1091/1695 train_time:105922ms step_avg:97.09ms -step:1092/1695 train_time:106018ms step_avg:97.09ms -step:1093/1695 train_time:106114ms step_avg:97.08ms -step:1094/1695 train_time:106209ms step_avg:97.08ms -step:1095/1695 train_time:106306ms step_avg:97.08ms -step:1096/1695 train_time:106402ms step_avg:97.08ms -step:1097/1695 train_time:106497ms step_avg:97.08ms -step:1098/1695 train_time:106592ms step_avg:97.08ms -step:1099/1695 train_time:106688ms step_avg:97.08ms -step:1100/1695 train_time:106785ms step_avg:97.08ms -step:1101/1695 train_time:106882ms step_avg:97.08ms -step:1102/1695 train_time:106978ms step_avg:97.08ms -step:1103/1695 train_time:107073ms step_avg:97.07ms -step:1104/1695 train_time:107169ms step_avg:97.07ms -step:1105/1695 train_time:107265ms step_avg:97.07ms -step:1106/1695 train_time:107362ms step_avg:97.07ms -step:1107/1695 train_time:107457ms step_avg:97.07ms -step:1108/1695 train_time:107552ms step_avg:97.07ms -step:1109/1695 train_time:107648ms step_avg:97.07ms -step:1110/1695 train_time:107745ms step_avg:97.07ms -step:1111/1695 train_time:107841ms step_avg:97.07ms -step:1112/1695 train_time:107938ms step_avg:97.07ms -step:1113/1695 train_time:108034ms step_avg:97.07ms -step:1114/1695 train_time:108130ms step_avg:97.06ms -step:1115/1695 train_time:108227ms step_avg:97.06ms -step:1116/1695 train_time:108323ms step_avg:97.06ms -step:1117/1695 train_time:108419ms step_avg:97.06ms -step:1118/1695 train_time:108515ms step_avg:97.06ms -step:1119/1695 train_time:108611ms step_avg:97.06ms -step:1120/1695 train_time:108708ms step_avg:97.06ms -step:1121/1695 train_time:108804ms step_avg:97.06ms -step:1122/1695 train_time:108901ms step_avg:97.06ms -step:1123/1695 train_time:108995ms step_avg:97.06ms -step:1124/1695 train_time:109090ms step_avg:97.06ms 
-step:1125/1695 train_time:109187ms step_avg:97.06ms -step:1125/1695 val_loss:3.4370 train_time:109281ms step_avg:97.14ms -step:1126/1695 train_time:109306ms step_avg:97.08ms -step:1127/1695 train_time:109389ms step_avg:97.06ms -step:1128/1695 train_time:109486ms step_avg:97.06ms -step:1129/1695 train_time:109581ms step_avg:97.06ms -step:1130/1695 train_time:109676ms step_avg:97.06ms -step:1131/1695 train_time:109771ms step_avg:97.06ms -step:1132/1695 train_time:109865ms step_avg:97.05ms -step:1133/1695 train_time:109961ms step_avg:97.05ms -step:1134/1695 train_time:110058ms step_avg:97.05ms -step:1135/1695 train_time:110154ms step_avg:97.05ms -step:1136/1695 train_time:110253ms step_avg:97.05ms -step:1137/1695 train_time:110354ms step_avg:97.06ms -step:1138/1695 train_time:110454ms step_avg:97.06ms -step:1139/1695 train_time:110552ms step_avg:97.06ms -step:1140/1695 train_time:110650ms step_avg:97.06ms -step:1141/1695 train_time:110746ms step_avg:97.06ms -step:1142/1695 train_time:110842ms step_avg:97.06ms -step:1143/1695 train_time:110939ms step_avg:97.06ms -step:1144/1695 train_time:111035ms step_avg:97.06ms -step:1145/1695 train_time:111133ms step_avg:97.06ms -step:1146/1695 train_time:111230ms step_avg:97.06ms -step:1147/1695 train_time:111329ms step_avg:97.06ms -step:1148/1695 train_time:111428ms step_avg:97.06ms -step:1149/1695 train_time:111526ms step_avg:97.06ms -step:1150/1695 train_time:111624ms step_avg:97.06ms -step:1151/1695 train_time:111721ms step_avg:97.06ms -step:1152/1695 train_time:111819ms step_avg:97.06ms -step:1153/1695 train_time:111916ms step_avg:97.06ms -step:1154/1695 train_time:112013ms step_avg:97.06ms -step:1155/1695 train_time:112110ms step_avg:97.06ms -step:1156/1695 train_time:112207ms step_avg:97.07ms -step:1157/1695 train_time:112305ms step_avg:97.07ms -step:1158/1695 train_time:112403ms step_avg:97.07ms -step:1159/1695 train_time:112502ms step_avg:97.07ms -step:1160/1695 train_time:112601ms step_avg:97.07ms -step:1161/1695 train_time:112699ms step_avg:97.07ms -step:1162/1695 train_time:112796ms step_avg:97.07ms -step:1163/1695 train_time:112894ms step_avg:97.07ms -step:1164/1695 train_time:112991ms step_avg:97.07ms -step:1165/1695 train_time:113088ms step_avg:97.07ms -step:1166/1695 train_time:113185ms step_avg:97.07ms -step:1167/1695 train_time:113282ms step_avg:97.07ms -step:1168/1695 train_time:113380ms step_avg:97.07ms -step:1169/1695 train_time:113479ms step_avg:97.07ms -step:1170/1695 train_time:113579ms step_avg:97.08ms -step:1171/1695 train_time:113676ms step_avg:97.08ms -step:1172/1695 train_time:113774ms step_avg:97.08ms -step:1173/1695 train_time:113870ms step_avg:97.08ms -step:1174/1695 train_time:113966ms step_avg:97.08ms -step:1175/1695 train_time:114063ms step_avg:97.07ms -step:1176/1695 train_time:114161ms step_avg:97.08ms -step:1177/1695 train_time:114259ms step_avg:97.08ms -step:1178/1695 train_time:114357ms step_avg:97.08ms -step:1179/1695 train_time:114455ms step_avg:97.08ms -step:1180/1695 train_time:114553ms step_avg:97.08ms -step:1181/1695 train_time:114651ms step_avg:97.08ms -step:1182/1695 train_time:114748ms step_avg:97.08ms -step:1183/1695 train_time:114845ms step_avg:97.08ms -step:1184/1695 train_time:114942ms step_avg:97.08ms -step:1185/1695 train_time:115041ms step_avg:97.08ms -step:1186/1695 train_time:115140ms step_avg:97.08ms -step:1187/1695 train_time:115238ms step_avg:97.08ms -step:1188/1695 train_time:115337ms step_avg:97.09ms -step:1189/1695 train_time:115435ms step_avg:97.09ms -step:1190/1695 train_time:115533ms 
step_avg:97.09ms -step:1191/1695 train_time:115632ms step_avg:97.09ms -step:1192/1695 train_time:115730ms step_avg:97.09ms -step:1193/1695 train_time:115826ms step_avg:97.09ms -step:1194/1695 train_time:115923ms step_avg:97.09ms -step:1195/1695 train_time:116021ms step_avg:97.09ms -step:1196/1695 train_time:116119ms step_avg:97.09ms -step:1197/1695 train_time:116216ms step_avg:97.09ms -step:1198/1695 train_time:116313ms step_avg:97.09ms -step:1199/1695 train_time:116411ms step_avg:97.09ms -step:1200/1695 train_time:116508ms step_avg:97.09ms -step:1201/1695 train_time:116605ms step_avg:97.09ms -step:1202/1695 train_time:116704ms step_avg:97.09ms -step:1203/1695 train_time:116801ms step_avg:97.09ms -step:1204/1695 train_time:116899ms step_avg:97.09ms -step:1205/1695 train_time:116997ms step_avg:97.09ms -step:1206/1695 train_time:117095ms step_avg:97.09ms -step:1207/1695 train_time:117193ms step_avg:97.09ms -step:1208/1695 train_time:117515ms step_avg:97.28ms -step:1209/1695 train_time:117719ms step_avg:97.37ms -step:1210/1695 train_time:117814ms step_avg:97.37ms -step:1211/1695 train_time:117911ms step_avg:97.37ms -step:1212/1695 train_time:118008ms step_avg:97.37ms -step:1213/1695 train_time:118103ms step_avg:97.36ms -step:1214/1695 train_time:118200ms step_avg:97.36ms -step:1215/1695 train_time:118297ms step_avg:97.36ms -step:1216/1695 train_time:118393ms step_avg:97.36ms -step:1217/1695 train_time:118491ms step_avg:97.36ms -step:1218/1695 train_time:118592ms step_avg:97.37ms -step:1219/1695 train_time:118695ms step_avg:97.37ms -step:1220/1695 train_time:118794ms step_avg:97.37ms -step:1221/1695 train_time:118893ms step_avg:97.37ms -step:1222/1695 train_time:118990ms step_avg:97.37ms -step:1223/1695 train_time:119087ms step_avg:97.37ms -step:1224/1695 train_time:119183ms step_avg:97.37ms -step:1225/1695 train_time:119281ms step_avg:97.37ms -step:1226/1695 train_time:119378ms step_avg:97.37ms -step:1227/1695 train_time:119475ms step_avg:97.37ms -step:1228/1695 train_time:119573ms step_avg:97.37ms -step:1229/1695 train_time:119673ms step_avg:97.37ms -step:1230/1695 train_time:119772ms step_avg:97.38ms -step:1231/1695 train_time:119870ms step_avg:97.38ms -step:1232/1695 train_time:119966ms step_avg:97.38ms -step:1233/1695 train_time:120064ms step_avg:97.38ms -step:1234/1695 train_time:120162ms step_avg:97.38ms -step:1235/1695 train_time:120259ms step_avg:97.38ms -step:1236/1695 train_time:120357ms step_avg:97.38ms -step:1237/1695 train_time:120454ms step_avg:97.38ms -step:1238/1695 train_time:120552ms step_avg:97.38ms -step:1239/1695 train_time:120650ms step_avg:97.38ms -step:1240/1695 train_time:120748ms step_avg:97.38ms -step:1241/1695 train_time:120845ms step_avg:97.38ms -step:1242/1695 train_time:120944ms step_avg:97.38ms -step:1243/1695 train_time:121041ms step_avg:97.38ms -step:1244/1695 train_time:121139ms step_avg:97.38ms -step:1245/1695 train_time:121235ms step_avg:97.38ms -step:1246/1695 train_time:121332ms step_avg:97.38ms -step:1247/1695 train_time:121429ms step_avg:97.38ms -step:1248/1695 train_time:121527ms step_avg:97.38ms -step:1249/1695 train_time:121625ms step_avg:97.38ms -step:1250/1695 train_time:121724ms step_avg:97.38ms -step:1250/1695 val_loss:3.3885 train_time:121819ms step_avg:97.46ms -step:1251/1695 train_time:121843ms step_avg:97.40ms -step:1252/1695 train_time:121929ms step_avg:97.39ms -step:1253/1695 train_time:122027ms step_avg:97.39ms -step:1254/1695 train_time:122123ms step_avg:97.39ms -step:1255/1695 train_time:122220ms step_avg:97.39ms -step:1256/1695 
train_time:122317ms step_avg:97.39ms -step:1257/1695 train_time:122414ms step_avg:97.39ms -step:1258/1695 train_time:122510ms step_avg:97.38ms -step:1259/1695 train_time:122606ms step_avg:97.38ms -step:1260/1695 train_time:122703ms step_avg:97.38ms -step:1261/1695 train_time:122803ms step_avg:97.39ms -step:1262/1695 train_time:122905ms step_avg:97.39ms -step:1263/1695 train_time:123004ms step_avg:97.39ms -step:1264/1695 train_time:123101ms step_avg:97.39ms -step:1265/1695 train_time:123199ms step_avg:97.39ms -step:1266/1695 train_time:123296ms step_avg:97.39ms -step:1267/1695 train_time:123393ms step_avg:97.39ms -step:1268/1695 train_time:123490ms step_avg:97.39ms -step:1269/1695 train_time:123587ms step_avg:97.39ms -step:1270/1695 train_time:123686ms step_avg:97.39ms -step:1271/1695 train_time:123781ms step_avg:97.39ms -step:1272/1695 train_time:123881ms step_avg:97.39ms -step:1273/1695 train_time:123982ms step_avg:97.39ms -step:1274/1695 train_time:124082ms step_avg:97.40ms -step:1275/1695 train_time:124180ms step_avg:97.40ms -step:1276/1695 train_time:124278ms step_avg:97.40ms -step:1277/1695 train_time:124376ms step_avg:97.40ms -step:1278/1695 train_time:124474ms step_avg:97.40ms -step:1279/1695 train_time:124570ms step_avg:97.40ms -step:1280/1695 train_time:124666ms step_avg:97.40ms -step:1281/1695 train_time:124763ms step_avg:97.40ms -step:1282/1695 train_time:124862ms step_avg:97.40ms -step:1283/1695 train_time:124961ms step_avg:97.40ms -step:1284/1695 train_time:125061ms step_avg:97.40ms -step:1285/1695 train_time:125160ms step_avg:97.40ms -step:1286/1695 train_time:125258ms step_avg:97.40ms -step:1287/1695 train_time:125356ms step_avg:97.40ms -step:1288/1695 train_time:125454ms step_avg:97.40ms -step:1289/1695 train_time:125552ms step_avg:97.40ms -step:1290/1695 train_time:125649ms step_avg:97.40ms -step:1291/1695 train_time:125746ms step_avg:97.40ms -step:1292/1695 train_time:125843ms step_avg:97.40ms -step:1293/1695 train_time:125941ms step_avg:97.40ms -step:1294/1695 train_time:126039ms step_avg:97.40ms -step:1295/1695 train_time:126138ms step_avg:97.40ms -step:1296/1695 train_time:126237ms step_avg:97.40ms -step:1297/1695 train_time:126335ms step_avg:97.41ms -step:1298/1695 train_time:126432ms step_avg:97.41ms -step:1299/1695 train_time:126529ms step_avg:97.41ms -step:1300/1695 train_time:126626ms step_avg:97.40ms -step:1301/1695 train_time:126723ms step_avg:97.40ms -step:1302/1695 train_time:126821ms step_avg:97.40ms -step:1303/1695 train_time:126919ms step_avg:97.41ms -step:1304/1695 train_time:127018ms step_avg:97.41ms -step:1305/1695 train_time:127116ms step_avg:97.41ms -step:1306/1695 train_time:127214ms step_avg:97.41ms -step:1307/1695 train_time:127312ms step_avg:97.41ms -step:1308/1695 train_time:127410ms step_avg:97.41ms -step:1309/1695 train_time:127508ms step_avg:97.41ms -step:1310/1695 train_time:127605ms step_avg:97.41ms -step:1311/1695 train_time:127703ms step_avg:97.41ms -step:1312/1695 train_time:127800ms step_avg:97.41ms -step:1313/1695 train_time:127898ms step_avg:97.41ms -step:1314/1695 train_time:127996ms step_avg:97.41ms -step:1315/1695 train_time:128093ms step_avg:97.41ms -step:1316/1695 train_time:128191ms step_avg:97.41ms -step:1317/1695 train_time:128289ms step_avg:97.41ms -step:1318/1695 train_time:128386ms step_avg:97.41ms -step:1319/1695 train_time:128483ms step_avg:97.41ms -step:1320/1695 train_time:128582ms step_avg:97.41ms -step:1321/1695 train_time:128679ms step_avg:97.41ms -step:1322/1695 train_time:128777ms step_avg:97.41ms -step:1323/1695 
train_time:128875ms step_avg:97.41ms -step:1324/1695 train_time:128973ms step_avg:97.41ms -step:1325/1695 train_time:129070ms step_avg:97.41ms -step:1326/1695 train_time:129168ms step_avg:97.41ms -step:1327/1695 train_time:129265ms step_avg:97.41ms -step:1328/1695 train_time:129363ms step_avg:97.41ms -step:1329/1695 train_time:129461ms step_avg:97.41ms -step:1330/1695 train_time:129559ms step_avg:97.41ms -step:1331/1695 train_time:129658ms step_avg:97.41ms -step:1332/1695 train_time:129756ms step_avg:97.41ms -step:1333/1695 train_time:129854ms step_avg:97.42ms -step:1334/1695 train_time:129952ms step_avg:97.42ms -step:1335/1695 train_time:130049ms step_avg:97.41ms -step:1336/1695 train_time:130146ms step_avg:97.41ms -step:1337/1695 train_time:130244ms step_avg:97.42ms -step:1338/1695 train_time:130342ms step_avg:97.42ms -step:1339/1695 train_time:130440ms step_avg:97.42ms -step:1340/1695 train_time:130539ms step_avg:97.42ms -step:1341/1695 train_time:130637ms step_avg:97.42ms -step:1342/1695 train_time:130735ms step_avg:97.42ms -step:1343/1695 train_time:130832ms step_avg:97.42ms -step:1344/1695 train_time:130929ms step_avg:97.42ms -step:1345/1695 train_time:131026ms step_avg:97.42ms -step:1346/1695 train_time:131124ms step_avg:97.42ms -step:1347/1695 train_time:131220ms step_avg:97.42ms -step:1348/1695 train_time:131318ms step_avg:97.42ms -step:1349/1695 train_time:131417ms step_avg:97.42ms -step:1350/1695 train_time:131515ms step_avg:97.42ms -step:1351/1695 train_time:131613ms step_avg:97.42ms -step:1352/1695 train_time:131710ms step_avg:97.42ms -step:1353/1695 train_time:131808ms step_avg:97.42ms -step:1354/1695 train_time:131905ms step_avg:97.42ms -step:1355/1695 train_time:132003ms step_avg:97.42ms -step:1356/1695 train_time:132102ms step_avg:97.42ms -step:1357/1695 train_time:132200ms step_avg:97.42ms -step:1358/1695 train_time:132297ms step_avg:97.42ms -step:1359/1695 train_time:132395ms step_avg:97.42ms -step:1360/1695 train_time:132492ms step_avg:97.42ms -step:1361/1695 train_time:132590ms step_avg:97.42ms -step:1362/1695 train_time:132688ms step_avg:97.42ms -step:1363/1695 train_time:132786ms step_avg:97.42ms -step:1364/1695 train_time:132883ms step_avg:97.42ms -step:1365/1695 train_time:132981ms step_avg:97.42ms -step:1366/1695 train_time:133080ms step_avg:97.42ms -step:1367/1695 train_time:133179ms step_avg:97.42ms -step:1368/1695 train_time:133277ms step_avg:97.42ms -step:1369/1695 train_time:133375ms step_avg:97.43ms -step:1370/1695 train_time:133473ms step_avg:97.43ms -step:1371/1695 train_time:133571ms step_avg:97.43ms -step:1372/1695 train_time:133669ms step_avg:97.43ms -step:1373/1695 train_time:133766ms step_avg:97.43ms -step:1374/1695 train_time:133863ms step_avg:97.43ms -step:1375/1695 train_time:133961ms step_avg:97.43ms -step:1375/1695 val_loss:3.3505 train_time:134057ms step_avg:97.50ms -step:1376/1695 train_time:134084ms step_avg:97.45ms -step:1377/1695 train_time:134163ms step_avg:97.43ms -step:1378/1695 train_time:134261ms step_avg:97.43ms -step:1379/1695 train_time:134359ms step_avg:97.43ms -step:1380/1695 train_time:134456ms step_avg:97.43ms -step:1381/1695 train_time:134781ms step_avg:97.60ms -step:1382/1695 train_time:134987ms step_avg:97.68ms -step:1383/1695 train_time:135083ms step_avg:97.67ms -step:1384/1695 train_time:135179ms step_avg:97.67ms -step:1385/1695 train_time:135276ms step_avg:97.67ms -step:1386/1695 train_time:135374ms step_avg:97.67ms -step:1387/1695 train_time:135471ms step_avg:97.67ms -step:1388/1695 train_time:135568ms step_avg:97.67ms 
-step:1389/1695 train_time:135665ms step_avg:97.67ms -step:1390/1695 train_time:135763ms step_avg:97.67ms -step:1391/1695 train_time:135867ms step_avg:97.68ms -step:1392/1695 train_time:135968ms step_avg:97.68ms -step:1393/1695 train_time:136065ms step_avg:97.68ms -step:1394/1695 train_time:136163ms step_avg:97.68ms -step:1395/1695 train_time:136260ms step_avg:97.68ms -step:1396/1695 train_time:136356ms step_avg:97.68ms -step:1397/1695 train_time:136453ms step_avg:97.68ms -step:1398/1695 train_time:136549ms step_avg:97.67ms -step:1399/1695 train_time:136646ms step_avg:97.67ms -step:1400/1695 train_time:136743ms step_avg:97.67ms -step:1401/1695 train_time:136841ms step_avg:97.67ms -step:1402/1695 train_time:136940ms step_avg:97.67ms -step:1403/1695 train_time:137039ms step_avg:97.68ms -step:1404/1695 train_time:137137ms step_avg:97.68ms -step:1405/1695 train_time:137234ms step_avg:97.68ms -step:1406/1695 train_time:137332ms step_avg:97.68ms -step:1407/1695 train_time:137429ms step_avg:97.67ms -step:1408/1695 train_time:137525ms step_avg:97.67ms -step:1409/1695 train_time:137622ms step_avg:97.67ms -step:1410/1695 train_time:137719ms step_avg:97.67ms -step:1411/1695 train_time:137817ms step_avg:97.67ms -step:1412/1695 train_time:137916ms step_avg:97.67ms -step:1413/1695 train_time:138014ms step_avg:97.67ms -step:1414/1695 train_time:138113ms step_avg:97.68ms -step:1415/1695 train_time:138212ms step_avg:97.68ms -step:1416/1695 train_time:138309ms step_avg:97.68ms -step:1417/1695 train_time:138405ms step_avg:97.67ms -step:1418/1695 train_time:138501ms step_avg:97.67ms -step:1419/1695 train_time:138598ms step_avg:97.67ms -step:1420/1695 train_time:138696ms step_avg:97.67ms -step:1421/1695 train_time:138794ms step_avg:97.67ms -step:1422/1695 train_time:138893ms step_avg:97.67ms -step:1423/1695 train_time:138990ms step_avg:97.67ms -step:1424/1695 train_time:139089ms step_avg:97.67ms -step:1425/1695 train_time:139188ms step_avg:97.68ms -step:1426/1695 train_time:139286ms step_avg:97.68ms -step:1427/1695 train_time:139382ms step_avg:97.68ms -step:1428/1695 train_time:139479ms step_avg:97.67ms -step:1429/1695 train_time:139577ms step_avg:97.67ms -step:1430/1695 train_time:139675ms step_avg:97.67ms -step:1431/1695 train_time:139772ms step_avg:97.67ms -step:1432/1695 train_time:139871ms step_avg:97.68ms -step:1433/1695 train_time:139969ms step_avg:97.68ms -step:1434/1695 train_time:140067ms step_avg:97.68ms -step:1435/1695 train_time:140164ms step_avg:97.68ms -step:1436/1695 train_time:140261ms step_avg:97.68ms -step:1437/1695 train_time:140358ms step_avg:97.67ms -step:1438/1695 train_time:140455ms step_avg:97.67ms -step:1439/1695 train_time:140553ms step_avg:97.67ms -step:1440/1695 train_time:140651ms step_avg:97.67ms -step:1441/1695 train_time:140749ms step_avg:97.67ms -step:1442/1695 train_time:140847ms step_avg:97.67ms -step:1443/1695 train_time:140945ms step_avg:97.68ms -step:1444/1695 train_time:141043ms step_avg:97.68ms -step:1445/1695 train_time:141140ms step_avg:97.67ms -step:1446/1695 train_time:141238ms step_avg:97.67ms -step:1447/1695 train_time:141335ms step_avg:97.67ms -step:1448/1695 train_time:141433ms step_avg:97.67ms -step:1449/1695 train_time:141530ms step_avg:97.67ms -step:1450/1695 train_time:141628ms step_avg:97.67ms -step:1451/1695 train_time:141725ms step_avg:97.67ms -step:1452/1695 train_time:141822ms step_avg:97.67ms -step:1453/1695 train_time:141920ms step_avg:97.67ms -step:1454/1695 train_time:142019ms step_avg:97.67ms -step:1455/1695 train_time:142118ms step_avg:97.68ms 
-step:1456/1695 train_time:142215ms step_avg:97.68ms -step:1457/1695 train_time:142313ms step_avg:97.68ms -step:1458/1695 train_time:142410ms step_avg:97.68ms -step:1459/1695 train_time:142507ms step_avg:97.67ms -step:1460/1695 train_time:142604ms step_avg:97.67ms -step:1461/1695 train_time:142701ms step_avg:97.67ms -step:1462/1695 train_time:142798ms step_avg:97.67ms -step:1463/1695 train_time:142896ms step_avg:97.67ms -step:1464/1695 train_time:142994ms step_avg:97.67ms -step:1465/1695 train_time:143092ms step_avg:97.67ms -step:1466/1695 train_time:143190ms step_avg:97.67ms -step:1467/1695 train_time:143287ms step_avg:97.67ms -step:1468/1695 train_time:143385ms step_avg:97.67ms -step:1469/1695 train_time:143482ms step_avg:97.67ms -step:1470/1695 train_time:143579ms step_avg:97.67ms -step:1471/1695 train_time:143678ms step_avg:97.67ms -step:1472/1695 train_time:143776ms step_avg:97.67ms -step:1473/1695 train_time:143873ms step_avg:97.67ms -step:1474/1695 train_time:143972ms step_avg:97.67ms -step:1475/1695 train_time:144070ms step_avg:97.67ms -step:1476/1695 train_time:144167ms step_avg:97.67ms -step:1477/1695 train_time:144265ms step_avg:97.67ms -step:1478/1695 train_time:144361ms step_avg:97.67ms -step:1479/1695 train_time:144459ms step_avg:97.67ms -step:1480/1695 train_time:144555ms step_avg:97.67ms -step:1481/1695 train_time:144653ms step_avg:97.67ms -step:1482/1695 train_time:144751ms step_avg:97.67ms -step:1483/1695 train_time:144849ms step_avg:97.67ms -step:1484/1695 train_time:144947ms step_avg:97.67ms -step:1485/1695 train_time:145044ms step_avg:97.67ms -step:1486/1695 train_time:145141ms step_avg:97.67ms -step:1487/1695 train_time:145239ms step_avg:97.67ms -step:1488/1695 train_time:145337ms step_avg:97.67ms -step:1489/1695 train_time:145435ms step_avg:97.67ms -step:1490/1695 train_time:145533ms step_avg:97.67ms -step:1491/1695 train_time:145630ms step_avg:97.67ms -step:1492/1695 train_time:145728ms step_avg:97.67ms -step:1493/1695 train_time:145825ms step_avg:97.67ms -step:1494/1695 train_time:145922ms step_avg:97.67ms -step:1495/1695 train_time:146020ms step_avg:97.67ms -step:1496/1695 train_time:146118ms step_avg:97.67ms -step:1497/1695 train_time:146216ms step_avg:97.67ms -step:1498/1695 train_time:146315ms step_avg:97.67ms -step:1499/1695 train_time:146413ms step_avg:97.67ms -step:1500/1695 train_time:146511ms step_avg:97.67ms -step:1500/1695 val_loss:3.3179 train_time:146606ms step_avg:97.74ms -step:1501/1695 train_time:146632ms step_avg:97.69ms -step:1502/1695 train_time:146715ms step_avg:97.68ms -step:1503/1695 train_time:146816ms step_avg:97.68ms -step:1504/1695 train_time:146914ms step_avg:97.68ms -step:1505/1695 train_time:147011ms step_avg:97.68ms -step:1506/1695 train_time:147108ms step_avg:97.68ms -step:1507/1695 train_time:147205ms step_avg:97.68ms -step:1508/1695 train_time:147301ms step_avg:97.68ms -step:1509/1695 train_time:147397ms step_avg:97.68ms -step:1510/1695 train_time:147494ms step_avg:97.68ms -step:1511/1695 train_time:147596ms step_avg:97.68ms -step:1512/1695 train_time:147698ms step_avg:97.68ms -step:1513/1695 train_time:147797ms step_avg:97.68ms -step:1514/1695 train_time:147896ms step_avg:97.69ms -step:1515/1695 train_time:147994ms step_avg:97.69ms -step:1516/1695 train_time:148092ms step_avg:97.69ms -step:1517/1695 train_time:148189ms step_avg:97.69ms -step:1518/1695 train_time:148286ms step_avg:97.68ms -step:1519/1695 train_time:148382ms step_avg:97.68ms -step:1520/1695 train_time:148478ms step_avg:97.68ms -step:1521/1695 train_time:148577ms 
step_avg:97.68ms -step:1522/1695 train_time:148678ms step_avg:97.69ms -step:1523/1695 train_time:148777ms step_avg:97.69ms -step:1524/1695 train_time:148876ms step_avg:97.69ms -step:1525/1695 train_time:148975ms step_avg:97.69ms -step:1526/1695 train_time:149074ms step_avg:97.69ms -step:1527/1695 train_time:149172ms step_avg:97.69ms -step:1528/1695 train_time:149269ms step_avg:97.69ms -step:1529/1695 train_time:149367ms step_avg:97.69ms -step:1530/1695 train_time:149465ms step_avg:97.69ms -step:1531/1695 train_time:149562ms step_avg:97.69ms -step:1532/1695 train_time:149660ms step_avg:97.69ms -step:1533/1695 train_time:149758ms step_avg:97.69ms -step:1534/1695 train_time:149856ms step_avg:97.69ms -step:1535/1695 train_time:149955ms step_avg:97.69ms -step:1536/1695 train_time:150054ms step_avg:97.69ms -step:1537/1695 train_time:150152ms step_avg:97.69ms -step:1538/1695 train_time:150249ms step_avg:97.69ms -step:1539/1695 train_time:150346ms step_avg:97.69ms -step:1540/1695 train_time:150443ms step_avg:97.69ms -step:1541/1695 train_time:150540ms step_avg:97.69ms -step:1542/1695 train_time:150638ms step_avg:97.69ms -step:1543/1695 train_time:150736ms step_avg:97.69ms -step:1544/1695 train_time:150835ms step_avg:97.69ms -step:1545/1695 train_time:150932ms step_avg:97.69ms -step:1546/1695 train_time:151030ms step_avg:97.69ms -step:1547/1695 train_time:151128ms step_avg:97.69ms -step:1548/1695 train_time:151225ms step_avg:97.69ms -step:1549/1695 train_time:151322ms step_avg:97.69ms -step:1550/1695 train_time:151419ms step_avg:97.69ms -step:1551/1695 train_time:151517ms step_avg:97.69ms -step:1552/1695 train_time:151888ms step_avg:97.87ms -step:1553/1695 train_time:151963ms step_avg:97.85ms -step:1554/1695 train_time:152058ms step_avg:97.85ms -step:1555/1695 train_time:152155ms step_avg:97.85ms -step:1556/1695 train_time:152252ms step_avg:97.85ms -step:1557/1695 train_time:152349ms step_avg:97.85ms -step:1558/1695 train_time:152445ms step_avg:97.85ms -step:1559/1695 train_time:152541ms step_avg:97.85ms -step:1560/1695 train_time:152638ms step_avg:97.84ms -step:1561/1695 train_time:152735ms step_avg:97.84ms -step:1562/1695 train_time:152839ms step_avg:97.85ms -step:1563/1695 train_time:152939ms step_avg:97.85ms -step:1564/1695 train_time:153039ms step_avg:97.85ms -step:1565/1695 train_time:153136ms step_avg:97.85ms -step:1566/1695 train_time:153235ms step_avg:97.85ms -step:1567/1695 train_time:153331ms step_avg:97.85ms -step:1568/1695 train_time:153428ms step_avg:97.85ms -step:1569/1695 train_time:153525ms step_avg:97.85ms -step:1570/1695 train_time:153622ms step_avg:97.85ms -step:1571/1695 train_time:153720ms step_avg:97.85ms -step:1572/1695 train_time:153819ms step_avg:97.85ms -step:1573/1695 train_time:153917ms step_avg:97.85ms -step:1574/1695 train_time:154017ms step_avg:97.85ms -step:1575/1695 train_time:154115ms step_avg:97.85ms -step:1576/1695 train_time:154215ms step_avg:97.85ms -step:1577/1695 train_time:154312ms step_avg:97.85ms -step:1578/1695 train_time:154410ms step_avg:97.85ms -step:1579/1695 train_time:154507ms step_avg:97.85ms -step:1580/1695 train_time:154604ms step_avg:97.85ms -step:1581/1695 train_time:154701ms step_avg:97.85ms -step:1582/1695 train_time:154798ms step_avg:97.85ms -step:1583/1695 train_time:154897ms step_avg:97.85ms -step:1584/1695 train_time:154996ms step_avg:97.85ms -step:1585/1695 train_time:155096ms step_avg:97.85ms -step:1586/1695 train_time:155195ms step_avg:97.85ms -step:1587/1695 train_time:155293ms step_avg:97.85ms -step:1588/1695 train_time:155391ms 
step_avg:97.85ms -step:1589/1695 train_time:155489ms step_avg:97.85ms -step:1590/1695 train_time:155586ms step_avg:97.85ms -step:1591/1695 train_time:155682ms step_avg:97.85ms -step:1592/1695 train_time:155779ms step_avg:97.85ms -step:1593/1695 train_time:155876ms step_avg:97.85ms -step:1594/1695 train_time:155976ms step_avg:97.85ms -step:1595/1695 train_time:156075ms step_avg:97.85ms -step:1596/1695 train_time:156175ms step_avg:97.85ms -step:1597/1695 train_time:156275ms step_avg:97.86ms -step:1598/1695 train_time:156375ms step_avg:97.86ms -step:1599/1695 train_time:156473ms step_avg:97.86ms -step:1600/1695 train_time:156571ms step_avg:97.86ms -step:1601/1695 train_time:156669ms step_avg:97.86ms -step:1602/1695 train_time:156767ms step_avg:97.86ms -step:1603/1695 train_time:156864ms step_avg:97.86ms -step:1604/1695 train_time:156962ms step_avg:97.86ms -step:1605/1695 train_time:157060ms step_avg:97.86ms -step:1606/1695 train_time:157159ms step_avg:97.86ms -step:1607/1695 train_time:157257ms step_avg:97.86ms -step:1608/1695 train_time:157354ms step_avg:97.86ms -step:1609/1695 train_time:157453ms step_avg:97.86ms -step:1610/1695 train_time:157551ms step_avg:97.86ms -step:1611/1695 train_time:157649ms step_avg:97.86ms -step:1612/1695 train_time:157747ms step_avg:97.86ms -step:1613/1695 train_time:157844ms step_avg:97.86ms -step:1614/1695 train_time:157942ms step_avg:97.86ms -step:1615/1695 train_time:158039ms step_avg:97.86ms -step:1616/1695 train_time:158137ms step_avg:97.86ms -step:1617/1695 train_time:158235ms step_avg:97.86ms -step:1618/1695 train_time:158333ms step_avg:97.86ms -step:1619/1695 train_time:158432ms step_avg:97.86ms -step:1620/1695 train_time:158530ms step_avg:97.86ms -step:1621/1695 train_time:158628ms step_avg:97.86ms -step:1622/1695 train_time:158726ms step_avg:97.86ms -step:1623/1695 train_time:158824ms step_avg:97.86ms -step:1624/1695 train_time:158922ms step_avg:97.86ms -step:1625/1695 train_time:159019ms step_avg:97.86ms -step:1625/1695 val_loss:3.2905 train_time:159114ms step_avg:97.92ms -step:1626/1695 train_time:159139ms step_avg:97.87ms -step:1627/1695 train_time:159222ms step_avg:97.86ms -step:1628/1695 train_time:159320ms step_avg:97.86ms -step:1629/1695 train_time:159418ms step_avg:97.86ms -step:1630/1695 train_time:159515ms step_avg:97.86ms -step:1631/1695 train_time:159612ms step_avg:97.86ms -step:1632/1695 train_time:159709ms step_avg:97.86ms -step:1633/1695 train_time:159806ms step_avg:97.86ms -step:1634/1695 train_time:159902ms step_avg:97.86ms -step:1635/1695 train_time:159999ms step_avg:97.86ms -step:1636/1695 train_time:160099ms step_avg:97.86ms -step:1637/1695 train_time:160200ms step_avg:97.86ms -step:1638/1695 train_time:160300ms step_avg:97.86ms -step:1639/1695 train_time:160398ms step_avg:97.86ms -step:1640/1695 train_time:160495ms step_avg:97.86ms -step:1641/1695 train_time:160592ms step_avg:97.86ms -step:1642/1695 train_time:160689ms step_avg:97.86ms -step:1643/1695 train_time:160786ms step_avg:97.86ms -step:1644/1695 train_time:160882ms step_avg:97.86ms -step:1645/1695 train_time:160980ms step_avg:97.86ms -step:1646/1695 train_time:161078ms step_avg:97.86ms -step:1647/1695 train_time:161177ms step_avg:97.86ms -step:1648/1695 train_time:161277ms step_avg:97.86ms -step:1649/1695 train_time:161378ms step_avg:97.86ms -step:1650/1695 train_time:161476ms step_avg:97.86ms -step:1651/1695 train_time:161573ms step_avg:97.86ms -step:1652/1695 train_time:161670ms step_avg:97.86ms -step:1653/1695 train_time:161767ms step_avg:97.86ms -step:1654/1695 
train_time:161864ms step_avg:97.86ms -step:1655/1695 train_time:161961ms step_avg:97.86ms -step:1656/1695 train_time:162059ms step_avg:97.86ms -step:1657/1695 train_time:162157ms step_avg:97.86ms -step:1658/1695 train_time:162256ms step_avg:97.86ms -step:1659/1695 train_time:162356ms step_avg:97.86ms -step:1660/1695 train_time:162456ms step_avg:97.86ms -step:1661/1695 train_time:162555ms step_avg:97.87ms -step:1662/1695 train_time:162654ms step_avg:97.87ms -step:1663/1695 train_time:162751ms step_avg:97.87ms -step:1664/1695 train_time:162849ms step_avg:97.87ms -step:1665/1695 train_time:162946ms step_avg:97.87ms -step:1666/1695 train_time:163044ms step_avg:97.87ms -step:1667/1695 train_time:163142ms step_avg:97.87ms -step:1668/1695 train_time:163239ms step_avg:97.87ms -step:1669/1695 train_time:163337ms step_avg:97.87ms -step:1670/1695 train_time:163435ms step_avg:97.87ms -step:1671/1695 train_time:163534ms step_avg:97.87ms -step:1672/1695 train_time:163633ms step_avg:97.87ms -step:1673/1695 train_time:163731ms step_avg:97.87ms -step:1674/1695 train_time:163829ms step_avg:97.87ms -step:1675/1695 train_time:163927ms step_avg:97.87ms -step:1676/1695 train_time:164024ms step_avg:97.87ms -step:1677/1695 train_time:164122ms step_avg:97.87ms -step:1678/1695 train_time:164219ms step_avg:97.87ms -step:1679/1695 train_time:164317ms step_avg:97.87ms -step:1680/1695 train_time:164415ms step_avg:97.87ms -step:1681/1695 train_time:164513ms step_avg:97.87ms -step:1682/1695 train_time:164612ms step_avg:97.87ms -step:1683/1695 train_time:164710ms step_avg:97.87ms -step:1684/1695 train_time:164809ms step_avg:97.87ms -step:1685/1695 train_time:164906ms step_avg:97.87ms -step:1686/1695 train_time:165004ms step_avg:97.87ms -step:1687/1695 train_time:165101ms step_avg:97.87ms -step:1688/1695 train_time:165199ms step_avg:97.87ms -step:1689/1695 train_time:165296ms step_avg:97.87ms -step:1690/1695 train_time:165393ms step_avg:97.87ms -step:1691/1695 train_time:165491ms step_avg:97.87ms -step:1692/1695 train_time:165589ms step_avg:97.87ms -step:1693/1695 train_time:165686ms step_avg:97.87ms -step:1694/1695 train_time:165783ms step_avg:97.86ms -step:1695/1695 train_time:165881ms step_avg:97.86ms -step:1695/1695 val_loss:3.2790 train_time:165977ms step_avg:97.92ms -peak memory allocated: 34000 MiB reserved: 49756 MiB
diff --git a/records/082725_FA3/README.md b/records/082725_FA3/README.md
deleted file mode 100644
index a4079630d..000000000
--- a/records/082725_FA3/README.md
+++ /dev/null
@@ -1,147 +0,0 @@
-# New record 08/27/25
-
-This submission includes recent WR changes by
-@ClassicLarry [(08/23/25)](https://github.com/ClassicLarry/modded-nanogpt/tree/master/records/082325_SparseAttnGate)
-and @byronxu99 [(07/18/25)](https://github.com/KellerJordan/modded-nanogpt/pull/109).
-
-The main idea of this record is to use input tensors with `batch_size > 1` throughout our training run.
-Increasing `batch_size` increases GPU utilization and allows us to use shorter input sequences for training.
-However, since Flex Attention is inefficient for `batch_size > 1`, we use [Flash Attention v3](https://github.com/Dao-AILab/flash-attention).
-The official version of this module is incompatible with `torch.compile` and causes graph breaks.
-Fortunately, a [recent PR](https://github.com/Dao-AILab/flash-attention/pull/1769) by
-[@guilhermeleobas](https://github.com/guilhermeleobas) addresses this issue.
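A minimal sketch of the usage pattern this record depends on: calling the FA3 kernel from inside a `torch.compile`d function. The module name `flash_attn_interface` matches the import used in the training script; the tensor shapes, the `window_size` keyword, and the tuple handling are my assumptions about the FA3 interface, not code from this record:

```
import torch
from flash_attn_interface import flash_attn_func  # FA3 built from the fa3-compile PR branch

@torch.compile  # graph-breaks on the official FA3 build; compiles cleanly on the PR branch
def sliding_window_attn(q, k, v, window_tokens: int):
    # FA3 takes (batch, seq_len, n_heads, head_dim) tensors in bf16
    out = flash_attn_func(q, k, v, causal=True, window_size=(window_tokens, 0))
    return out[0] if isinstance(out, tuple) else out  # some FA3 builds also return the LSE

q = k = v = torch.randn(24, 2048, 6, 128, device="cuda", dtype=torch.bfloat16)
y = sliding_window_attn(q, k, v, window_tokens=3 * 128)  # e.g. 3 blocks at bandwidth=128
```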
-
-
-## Timing and Validation
-
-Validated over 7 runs:
-- In 1695 training steps, this run achieves a loss <3.28 (`p=0.0008`)
-- In 166.10 seconds on average, or <166.25 seconds (`p=0.0024`):
-
-```
-import scipy.stats
-import numpy as np
-
-accs = [
-    3.2769, 3.2782, 3.2790, 3.2791, 3.2791, 3.2780, 3.2782
-]
-
-times = [
-    166.247, 166.117, 165.977, 166.135, 166.045, 166.044, 166.157
-]
-
-print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue)
-# p=0.0008
-
-print('p=%.4f' % scipy.stats.ttest_1samp(times, 166.25, alternative='less').pvalue)
-# p=0.0024
-
-print(f"{np.mean(times):.4f}")
-# 166.1031
-```
-
-In my timing, this is a 2.1 second mean improvement over [PR#117](https://github.com/KellerJordan/modded-nanogpt/pull/117).
-The number of steps can also probably be brought down by 5-15 while still achieving loss <3.28.
-
-I used SXM5 8 x H100 via Prime Intellect for validation compute.
-
-## Further Details
-
-### Motivation
-
-PyTorch's Flex Attention experiences a wallclock slowdown of >10% for inputs with `batch_size > 1`.
-As such, previous records would train on very long sequence lengths (`48 * 1024`) with no batch dimension.
-Attention is approximately `O(|seq_len|^2 x |batch_size|)`, so this is theoretically bad,
-but it was mitigated by aggressive block masking.
-Attention used a `block_mask` that grew to at most `1664` tokens (and was often shorter due to document masking).
-However, GPU utilization for attention is higher when tokens are distributed along the batch dimension.
-
-Additionally, increasing the batch size allows us to decrease sequence length while maintaining the total
-number of tokens processed per step.
-WR#26 by @ClassicLarry found that validation loss decreases when we train only
-on sequences beginning with the Beginning of Sequence (BOS) token.
-Decreasing the sequence length makes it more likely that BOS is present in the attention window.
-In order to generate batches where each sequence begins with BOS, I have created the helper class
-`EOSBatchFinder`. This class pre-indexes shards with the locations of BOS tokens for slight speedups;
-a sketch of the idea is given below.
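The class itself is not reproduced in this README, so the following is only a hypothetical sketch of the pre-indexing idea behind `EOSBatchFinder`. The token id, shard format, and method names are assumptions; the real implementation in the training script may differ:

```
import numpy as np

BOS_ID = 50256  # assumed: the GPT-2 end-of-text id reused as a BOS/document separator

class EOSBatchFinder:
    """Sketch: pre-index BOS positions so batch assembly is a slice, not a scan."""
    def __init__(self, shard_tokens: np.ndarray):
        self.tokens = shard_tokens
        self.bos_pos = np.flatnonzero(shard_tokens == BOS_ID)  # computed once per shard

    def next_batch(self, cursor: int, batch_size: int, seq_len: int):
        # each row starts exactly at a BOS token (shard-boundary handling omitted)
        starts = self.bos_pos[cursor : cursor + batch_size]
        batch = np.stack([self.tokens[s : s + seq_len] for s in starts])
        return batch, cursor + batch_size
```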
- You can build the wheel like so: - -``` -pip install -U pip wheel setuptools ninja numpy packaging psutil - -git clone https://github.com/guilhermeleobas/flash-attention.git -cd flash-attention/hopper -git switch guilhermeleobas/fa3-compile - -export MAX_JOBS=32 # Can increase based on machine -export FLASH_ATTENTION_FORCE_BUILD=TRUE # skip prebuilt wheel fetch -export FLASH_ATTENTION_DISABLE_SM80=TRUE # Hopper-only -export FLASH_ATTENTION_DISABLE_FP16=TRUE # leave BF16, FP8 -export FLASH_ATTENTION_DISABLE_HDIM64=TRUE # NanoGPT only uses HDIM = 128 -export FLASH_ATTENTION_DISABLE_HDIM96=TRUE -export FLASH_ATTENTION_DISABLE_HDIM192=TRUE -export FLASH_ATTENTION_DISABLE_HDIM256=TRUE - -python setup.py bdist_wheel -``` - -Additionally, I have uploaded a prebuilt wheel -[here](https://github.com/varunneal/flash-attention/releases/tag/v3.0.0b1-alpha), -though it will likely be faster to build it yourself than to download this wheel. - -For exact reproduction, I recommend that you install Torch Nightly 2.9.0.dev20250718 and -install the FA3 wheel afterward: - -``` -pip install --pre "torch==2.9.0.dev20250718+cu126" --index-url https://download.pytorch.org/whl/nightly/cu126 - -# typical path to FA3 Wheel -pip install flash-attention/hopper/dist/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl -``` - -For me, Torch Nightly 2.9.0.dev20250713 was incompatible with PR#109. - -### Attention Masks - -Unfortunately, Flash Attention does not support complex Block Masks like Flex Attention. -Therefore, `create_blockmasks` was removed. Instead, we are only given the parameter `window_size`, -which specifies how many tokens to the left each token may attend to. - -I kept the existing long-short sliding window block mask pattern, as well as the idea -that the window sizes should linearly increase over the length of the training run. -To aid with this, I modified `get_lr(step)` to instead be `get_lr_and_ws(step)`. -Additionally, I added a hyperparameter `ws_schedule` which specifies what the -longer window size should be during each portion of the run. I also added the -size of the blocks in a window as a hyperparameter, `bandwidth=128`. - -I have picked a linear schedule with three steps: `ws_schedule=(3, 7, 11)`. -Currently, `torch.compile` creates a new compilation graph for each step in `ws_schedule`. -Therefore, each graph needs to be warmed up separately. I have increased the number -of warmup steps from `10` to `60`. The compile time is dominated by the first iteration, -so warmup will take approximately `len(ws_schedule)` times longer than before. - -Removing document masking had a noticeably negative impact on validation loss; -however, the benefits of a short sequence length counteract this. - -### Potential Improvements - -- Batch size scheduling: Previously, the block mask acted as a proxy for batch size. -Now batch size can be controlled explicitly and sequenced according to critical batch -size theory. I have added code in `distributed_data_generator` that allows for changing the -batch size and sequence length yielded after the generator is created (see the sketch after this list). -- The current block mask window schedule `(3, 7, 11)` can almost certainly be improved upon. -- Hyperparameter tuning might change with smaller sequence length. Rotary base, validation sequence length, learning rates, -etc. should be re-tuned. I haven't done that for this run. -- FA3 has additional features over Flex Attention that may be useful.
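-
-As a sketch of the batch-size-scheduling hook mentioned above (schedule values here are
-hypothetical; `distributed_data_generator` and its `.send()` handling are defined in the
-training script in this record):
-
-```
-# Assumes torchrun with world_size=8, so both global batch sizes stay divisible by it.
-train_loader = distributed_data_generator("data/fineweb10B/fineweb_train_*.bin",
-                                          batch_size=24 * 8, seq_len=2048)
-inputs, targets = next(train_loader)  # per-rank tensors of shape (24, 2048)
-# Later in the run: double the batch size, halve the sequence length, keeping
-# tokens per step constant; .send() returns the next batch in the new shape.
-inputs, targets = train_loader.send((48 * 8, 1024))
-```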
\ No newline at end of file diff --git a/records/082725_FA3/ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt b/records/082725_FA3/ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt deleted file mode 100644 index 7a5ed0b1c..000000000 --- a/records/082725_FA3/ba9be2f3-1e6f-4a1a-827e-a47a702c67b0.txt +++ /dev/null @@ -1,2808 +0,0 @@ -import os -import sys -with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time -import copy -import glob -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F -import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np -import triton -import triton.language as tl -from flash_attn_interface import flash_attn_func -import torch._dynamo as dynamo -dynamo.config.recompile_limit = 64 - -# ----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, 
None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & 
(offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = 
C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
- """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) - grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size - for base_i in range(0, len(params), world_size): - if base_i + rank < len(params): - grad = params[base_i + rank].grad - # This gives strange dynamo warnings - reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) - - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -class DistAdam(torch.optim.Optimizer): - def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # DistributedAdam implementation by @vagrawal - - @torch.compile - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - grad_slices = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - for base_i in range(len(params)): - grad = params[base_i].grad - rank_size = grad.shape[0] // world_size - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, 
op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for base in range(len(params)): - reduce_scatter_futures[idx].wait() - p = params[base] - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - g_slice = grad_slices[idx] - # State init - if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = 
num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# 
----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - f.seek(256 * 4) - nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of 
shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. - tokens, pos = _load_data_shard(next(file_iter)), 0 - finder = EOSBatchFinder(tokens, world_size=world_size) - continue - - bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] - buf = torch.stack(bufs, dim=0) - _inputs = buf[:, :-1] - _targets = buf[:, 1:] - else: - batch_span = num_tokens_global - start_pos_local = pos + rank * (batch_size_local * seq_len) - end_pos_local = start_pos_local + (batch_size_local * seq_len) - - buf = tokens[start_pos_local: end_pos_local + 1] - - _inputs = buf[:-1].view(batch_size_local, seq_len) - _targets = buf[1:].view(batch_size_local, seq_len) - - new_params = yield ( - _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), - _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) - ) - - pos += batch_span - - if new_params is not None: - # makes it possible for generator to recieve new (batch_size, seq_len) via .send() - new_batch_size, new_seq_len = new_params - assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" - batch_size = new_batch_size - seq_len = new_seq_len - - -# ----------------------------------------------------------------------------- -# int main - -@dataclass -class Hyperparameters: - # data - train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on - val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on - val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons - train_seq_len: int = 1024 * 2 - train_batch_size: int = 24 * 8 - val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. - # optimization - num_iterations: int = 1695 # number of iterations to run - cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate - # evaluation and logging - run_id: str = str(uuid.uuid4()) - val_loss_every: int = 125 # every how many steps to evaluate val loss? 
0 for only at the end - save_checkpoint: bool = False - # attention masking - bandwidth: int = 128 - ws_schedule: tuple = (3, 7, 11) - -args = Hyperparameters() - -data_path = os.environ.get("DATA_PATH", ".") -args.train_files = os.path.join(data_path, args.train_files) -args.val_files = os.path.join(data_path, args.val_files) - -# torchrun sets these env variables -rank = int(os.environ["RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) -assert 8 % world_size == 0, "world_size must be a divisor of 8" -grad_accum_steps = 8 // world_size -assert torch.cuda.is_available() -device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) -torch.cuda.set_device(device) -dist.init_process_group(backend="nccl", device_id=device) -dist.barrier() -master_process = (rank == 0) # this process will do logging, checkpointing etc. - -# begin logging -logfile = None -if master_process: - run_id = args.run_id - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{run_id}.txt" - print(logfile) -def print0(s, console=False): - if master_process: - with open(logfile, "a") as f: - if console: - print(s) - print(s, file=f) - -# begin by printing this file (the Python code) -print0(code) -print0("="*100) -# log information about the hardware/software environment this is running on -print0(f"Running Python {sys.version}") -print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") -print0(f"Running Triton version {triton.__version__}") - -def nvidia_smi(): - import subprocess # avoid top level import - return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout -print0(nvidia_smi()) -print0("="*100) - -model: nn.Module = GPT( - vocab_size=50257, - num_layers=12, - num_heads=6, - model_dim=768, - max_seq_len=max(args.train_seq_len, args.val_seq_len) -).cuda() -for m in model.modules(): - if isinstance(m, nn.Embedding): - m.bfloat16() -for param in model.parameters(): - dist.broadcast(param.detach(), 0) - -# collect the parameters to optimize -hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] -embed_params = [p for n, p in model.named_parameters() if "embed" in n] -scalar_params = [p for p in model.parameters() if p.ndim < 2] -head_params = [model.lm_head.weight] - -# init the optimizer(s) -# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence -# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 -optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) -optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) -optimizers = [optimizer1, optimizer2] -for opt in optimizers: - for group in opt.param_groups: - group["initial_lr"] = group["lr"] - -# learning rate schedule: stable then decay -def get_lr_and_ws(step: int): - x = step / (1 + args.num_iterations) # progress in training - assert 0 <= x < 1 - lr = 1.0 - if x >= 1 - args.cooldown_frac: - w = (1 - x) / args.cooldown_frac - lr = w * 1.0 + (1 - w) * 0.1 - ws_idx = int(len(args.ws_schedule) * x) - return lr, args.ws_schedule[ws_idx] - -model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) - -######################################## -# Warmup kernels # -######################################## - -# Warmup the training kernels, then re-initialize the state so we aren't cheating -warmup_steps = 60 -initial_state = dict(model=copy.deepcopy(model.state_dict()), - optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state -train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -for step in range(warmup_steps): - inputs, targets = next(train_loader) - ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each - model(inputs, targets, ws, ws // 2).backward() - for opt in optimizers: - opt.step() - model.zero_grad(set_to_none=True) -model.load_state_dict(initial_state["model"]) -for opt, opt_state in zip(optimizers, initial_state["optimizers"]): - opt.load_state_dict(opt_state) -del train_loader, initial_state - -######################################## -# Training and validation # -######################################## - -train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -training_time_ms = 0 -# start the clock -torch.cuda.synchronize() -t0 = time.perf_counter() -# begin training -train_steps = args.num_iterations -for step in range(train_steps + 1): - last_step = (step == train_steps) - lr, ws = get_lr_and_ws(step) - - # --------------- VALIDATION SECTION ----------------- - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): - # stop the clock - torch.cuda.synchronize() - training_time_ms += 1000 * (time.perf_counter() - t0) - model.eval() - assert args.val_tokens % (world_size * args.val_seq_len) == 0 - val_steps = args.val_tokens // (world_size * args.val_seq_len) - val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False) - val_loss = 0 - with torch.no_grad(): - for _ in range(val_steps): - inputs, targets = next(val_loader) - val_loss += model(inputs, targets, ws, ws // 2) - val_loss /= val_steps - del val_loader - dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) - print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) - model.train() - # start the clock again - torch.cuda.synchronize() - t0 = time.perf_counter() - - if last_step: - if master_process and args.save_checkpoint: - log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) -
os.makedirs(f"logs/{run_id}", exist_ok=True) - torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") - # the last step only has the validation loop, so break to avoid training - break - - # --------------- TRAINING SECTION ----------------- - for _ in range(grad_accum_steps): - inputs, targets = next(train_loader) - model(inputs, targets, ws, ws // 2).backward() - # set optimization hyperparameters - for opt in optimizers: - for group in opt.param_groups: - group["lr"] = group["initial_lr"] * lr - for group in optimizer2.param_groups: - frac = min(step / 300, 1) # momentum warmup for muon - group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 - # step the optimizers - for opt in optimizers: - opt.step() - # null the gradients - model.zero_grad(set_to_none=True) - # logging - approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) - print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) - -print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " - f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) -dist.destroy_process_group() -==================================================================================================== -Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] -Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 -Running Triton version 3.4.0 -Wed Aug 27 03:58:09 2025 -+---------------------------------------------------------------------------------------+ -| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.6 | -|-----------------------------------------+----------------------+----------------------+ -| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | -| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | -| | | MIG M. 
| -|=========================================+======================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:00:0B.0 Off | Off | -| N/A 30C P0 115W / 700W | 5858MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:00:0C.0 Off | Off | -| N/A 33C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:00:0D.0 Off | Off | -| N/A 34C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:00:0E.0 Off | Off | -| N/A 31C P0 113W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:00:0F.0 Off | Off | -| N/A 30C P0 110W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:00:10.0 Off | Off | -| N/A 35C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:00:11.0 Off | Off | -| N/A 33C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:00:12.0 Off | Off | -| N/A 32C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | -| | | Disabled | -+-----------------------------------------+----------------------+----------------------+ - -+---------------------------------------------------------------------------------------+ -| Processes: | -| GPU GI CI PID Type Process name GPU Memory | -| ID ID Usage | -|=======================================================================================| -+---------------------------------------------------------------------------------------+ - -==================================================================================================== -step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.02ms -step:1/1695 train_time:516ms step_avg:515.52ms -step:2/1695 train_time:539ms step_avg:269.65ms -step:3/1695 train_time:612ms step_avg:203.90ms -step:4/1695 train_time:704ms step_avg:175.97ms -step:5/1695 train_time:797ms step_avg:159.42ms -step:6/1695 train_time:891ms step_avg:148.48ms -step:7/1695 train_time:984ms step_avg:140.60ms -step:8/1695 train_time:1078ms step_avg:134.78ms -step:9/1695 train_time:1172ms step_avg:130.23ms -step:10/1695 train_time:1265ms step_avg:126.49ms -step:11/1695 train_time:1359ms step_avg:123.52ms -step:12/1695 train_time:1457ms step_avg:121.44ms -step:13/1695 train_time:1555ms step_avg:119.64ms -step:14/1695 train_time:1650ms step_avg:117.89ms -step:15/1695 train_time:1745ms step_avg:116.30ms -step:16/1695 train_time:1839ms step_avg:114.95ms -step:17/1695 train_time:1933ms step_avg:113.72ms -step:18/1695 train_time:2027ms step_avg:112.62ms -step:19/1695 train_time:2122ms step_avg:111.67ms -step:20/1695 train_time:2216ms step_avg:110.82ms -step:21/1695 train_time:2311ms step_avg:110.03ms -step:22/1695 train_time:2405ms step_avg:109.33ms 
-step:23/1695 train_time:2501ms step_avg:108.73ms -step:24/1695 train_time:2597ms step_avg:108.21ms -step:25/1695 train_time:2693ms step_avg:107.73ms -step:26/1695 train_time:2788ms step_avg:107.23ms -step:27/1695 train_time:2882ms step_avg:106.74ms -step:28/1695 train_time:2977ms step_avg:106.33ms -step:29/1695 train_time:3071ms step_avg:105.90ms -step:30/1695 train_time:3165ms step_avg:105.50ms -step:31/1695 train_time:3259ms step_avg:105.14ms -step:32/1695 train_time:3355ms step_avg:104.84ms -step:33/1695 train_time:3449ms step_avg:104.53ms -step:34/1695 train_time:3545ms step_avg:104.26ms -step:35/1695 train_time:3640ms step_avg:104.01ms -step:36/1695 train_time:3736ms step_avg:103.78ms -step:37/1695 train_time:3831ms step_avg:103.53ms -step:38/1695 train_time:3925ms step_avg:103.28ms -step:39/1695 train_time:4019ms step_avg:103.06ms -step:40/1695 train_time:4113ms step_avg:102.84ms -step:41/1695 train_time:4207ms step_avg:102.60ms -step:42/1695 train_time:4301ms step_avg:102.40ms -step:43/1695 train_time:4396ms step_avg:102.23ms -step:44/1695 train_time:4491ms step_avg:102.07ms -step:45/1695 train_time:4587ms step_avg:101.92ms -step:46/1695 train_time:4681ms step_avg:101.76ms -step:47/1695 train_time:4777ms step_avg:101.63ms -step:48/1695 train_time:4873ms step_avg:101.53ms -step:49/1695 train_time:4966ms step_avg:101.35ms -step:50/1695 train_time:5060ms step_avg:101.21ms -step:51/1695 train_time:5155ms step_avg:101.07ms -step:52/1695 train_time:5249ms step_avg:100.94ms -step:53/1695 train_time:5343ms step_avg:100.82ms -step:54/1695 train_time:5439ms step_avg:100.72ms -step:55/1695 train_time:5534ms step_avg:100.62ms -step:56/1695 train_time:5629ms step_avg:100.51ms -step:57/1695 train_time:5723ms step_avg:100.41ms -step:58/1695 train_time:5818ms step_avg:100.32ms -step:59/1695 train_time:5913ms step_avg:100.23ms -step:60/1695 train_time:6007ms step_avg:100.11ms -step:61/1695 train_time:6100ms step_avg:100.00ms -step:62/1695 train_time:6196ms step_avg:99.93ms -step:63/1695 train_time:6290ms step_avg:99.84ms -step:64/1695 train_time:6384ms step_avg:99.75ms -step:65/1695 train_time:6479ms step_avg:99.68ms -step:66/1695 train_time:6573ms step_avg:99.59ms -step:67/1695 train_time:6667ms step_avg:99.51ms -step:68/1695 train_time:6762ms step_avg:99.44ms -step:69/1695 train_time:6856ms step_avg:99.37ms -step:70/1695 train_time:6950ms step_avg:99.29ms -step:71/1695 train_time:7044ms step_avg:99.21ms -step:72/1695 train_time:7139ms step_avg:99.16ms -step:73/1695 train_time:7234ms step_avg:99.10ms -step:74/1695 train_time:7329ms step_avg:99.04ms -step:75/1695 train_time:7423ms step_avg:98.98ms -step:76/1695 train_time:7519ms step_avg:98.94ms -step:77/1695 train_time:7614ms step_avg:98.88ms -step:78/1695 train_time:7709ms step_avg:98.83ms -step:79/1695 train_time:7803ms step_avg:98.77ms -step:80/1695 train_time:7897ms step_avg:98.71ms -step:81/1695 train_time:7991ms step_avg:98.66ms -step:82/1695 train_time:8085ms step_avg:98.60ms -step:83/1695 train_time:8179ms step_avg:98.55ms -step:84/1695 train_time:8274ms step_avg:98.50ms -step:85/1695 train_time:8368ms step_avg:98.45ms -step:86/1695 train_time:8462ms step_avg:98.39ms -step:87/1695 train_time:8558ms step_avg:98.36ms -step:88/1695 train_time:8653ms step_avg:98.33ms -step:89/1695 train_time:8747ms step_avg:98.28ms -step:90/1695 train_time:8841ms step_avg:98.24ms -step:91/1695 train_time:8936ms step_avg:98.20ms -step:92/1695 train_time:9031ms step_avg:98.16ms -step:93/1695 train_time:9125ms step_avg:98.12ms -step:94/1695 train_time:9219ms 
step_avg:98.08ms -step:95/1695 train_time:9313ms step_avg:98.03ms -step:96/1695 train_time:9406ms step_avg:97.98ms -step:97/1695 train_time:9500ms step_avg:97.94ms -step:98/1695 train_time:9596ms step_avg:97.91ms -step:99/1695 train_time:9690ms step_avg:97.88ms -step:100/1695 train_time:9784ms step_avg:97.84ms -step:101/1695 train_time:9878ms step_avg:97.80ms -step:102/1695 train_time:9973ms step_avg:97.78ms -step:103/1695 train_time:10067ms step_avg:97.74ms -step:104/1695 train_time:10162ms step_avg:97.71ms -step:105/1695 train_time:10256ms step_avg:97.67ms -step:106/1695 train_time:10350ms step_avg:97.64ms -step:107/1695 train_time:10444ms step_avg:97.60ms -step:108/1695 train_time:10540ms step_avg:97.59ms -step:109/1695 train_time:10635ms step_avg:97.57ms -step:110/1695 train_time:10731ms step_avg:97.55ms -step:111/1695 train_time:10825ms step_avg:97.52ms -step:112/1695 train_time:10920ms step_avg:97.50ms -step:113/1695 train_time:11014ms step_avg:97.47ms -step:114/1695 train_time:11108ms step_avg:97.44ms -step:115/1695 train_time:11201ms step_avg:97.40ms -step:116/1695 train_time:11297ms step_avg:97.38ms -step:117/1695 train_time:11391ms step_avg:97.36ms -step:118/1695 train_time:11485ms step_avg:97.33ms -step:119/1695 train_time:11580ms step_avg:97.31ms -step:120/1695 train_time:11675ms step_avg:97.29ms -step:121/1695 train_time:11769ms step_avg:97.27ms -step:122/1695 train_time:11863ms step_avg:97.24ms -step:123/1695 train_time:11959ms step_avg:97.22ms -step:124/1695 train_time:12054ms step_avg:97.21ms -step:125/1695 train_time:12148ms step_avg:97.18ms -step:125/1695 val_loss:4.3195 train_time:12239ms step_avg:97.92ms -step:126/1695 train_time:12266ms step_avg:97.35ms -step:127/1695 train_time:12343ms step_avg:97.19ms -step:128/1695 train_time:12442ms step_avg:97.21ms -step:129/1695 train_time:12537ms step_avg:97.19ms -step:130/1695 train_time:12631ms step_avg:97.16ms -step:131/1695 train_time:12725ms step_avg:97.14ms -step:132/1695 train_time:12818ms step_avg:97.11ms -step:133/1695 train_time:12911ms step_avg:97.07ms -step:134/1695 train_time:13005ms step_avg:97.05ms -step:135/1695 train_time:13098ms step_avg:97.02ms -step:136/1695 train_time:13191ms step_avg:97.00ms -step:137/1695 train_time:13287ms step_avg:96.99ms -step:138/1695 train_time:13385ms step_avg:97.00ms -step:139/1695 train_time:13481ms step_avg:96.98ms -step:140/1695 train_time:13575ms step_avg:96.97ms -step:141/1695 train_time:13669ms step_avg:96.94ms -step:142/1695 train_time:13763ms step_avg:96.92ms -step:143/1695 train_time:13856ms step_avg:96.90ms -step:144/1695 train_time:13949ms step_avg:96.87ms -step:145/1695 train_time:14043ms step_avg:96.85ms -step:146/1695 train_time:14136ms step_avg:96.82ms -step:147/1695 train_time:14230ms step_avg:96.80ms -step:148/1695 train_time:14326ms step_avg:96.79ms -step:149/1695 train_time:14422ms step_avg:96.79ms -step:150/1695 train_time:14517ms step_avg:96.78ms -step:151/1695 train_time:14612ms step_avg:96.77ms -step:152/1695 train_time:14707ms step_avg:96.75ms -step:153/1695 train_time:14801ms step_avg:96.74ms -step:154/1695 train_time:14895ms step_avg:96.72ms -step:155/1695 train_time:14989ms step_avg:96.70ms -step:156/1695 train_time:15083ms step_avg:96.69ms -step:157/1695 train_time:15176ms step_avg:96.66ms -step:158/1695 train_time:15270ms step_avg:96.65ms -step:159/1695 train_time:15365ms step_avg:96.63ms -step:160/1695 train_time:15461ms step_avg:96.63ms -step:161/1695 train_time:15556ms step_avg:96.62ms -step:162/1695 train_time:15650ms step_avg:96.60ms -step:163/1695 
train_time:15744ms step_avg:96.59ms
-[... steps 164-250: per-step timing lines elided; train_time 15839ms → 24184ms, step_avg ≈96.5-97.9ms ...]
-step:250/1695 val_loss:3.9759 train_time:24277ms step_avg:97.11ms
-[... steps 251-375: train_time 24301ms → 36187ms, step_avg ≈96.1-96.8ms ...]
-step:375/1695 val_loss:3.8237 train_time:36278ms step_avg:96.74ms
-[... steps 376-500: train_time 36303ms → 47954ms, step_avg ≈95.9-96.6ms ...]
-step:500/1695 val_loss:3.7206 train_time:48046ms step_avg:96.09ms
-[... steps 501-625: train_time 48071ms → 60176ms, step_avg ≈95.9-96.5ms ...]
-step:625/1695 val_loss:3.6228 train_time:60270ms step_avg:96.43ms
-[... steps 626-750: train_time 60294ms → 72540ms, step_avg ≈96.3-96.8ms ...]
-step:750/1695 val_loss:3.5691 train_time:72633ms step_avg:96.84ms
-[... steps 751-875: train_time 72658ms → 84904ms, step_avg ≈96.7-97.1ms ...]
-step:875/1695 val_loss:3.5252 train_time:84998ms step_avg:97.14ms
-[... steps 876-1000: train_time 85024ms → 96912ms, step_avg ≈96.9-97.1ms ...]
-step:1000/1695 val_loss:3.4845 train_time:97007ms step_avg:97.01ms
-[... steps 1001-1125: train_time 97031ms → 109222ms, step_avg ≈96.9-97.2ms ...]
-step:1125/1695 val_loss:3.4375 train_time:109315ms step_avg:97.17ms
-[... steps 1126-1250: train_time 109340ms → 121769ms, step_avg ≈97.1-97.4ms ...]
-step:1250/1695 val_loss:3.3901 train_time:121865ms step_avg:97.49ms
-[... steps 1251-1375: train_time 121890ms → 134027ms, step_avg ≈97.4-97.5ms ...]
-step:1375/1695 val_loss:3.3508 train_time:134124ms step_avg:97.54ms
-[... steps 1376-1500: train_time 134149ms → 146592ms, step_avg ≈97.5-97.7ms ...]
-step:1500/1695 val_loss:3.3185 train_time:146688ms step_avg:97.79ms
-[... steps 1501-1520: train_time 146713ms → 148562ms, step_avg ≈97.7ms ...]
-step:1521/1695 train_time:148659ms
step_avg:97.74ms -step:1522/1695 train_time:148758ms step_avg:97.74ms -step:1523/1695 train_time:148858ms step_avg:97.74ms -step:1524/1695 train_time:148957ms step_avg:97.74ms -step:1525/1695 train_time:149056ms step_avg:97.74ms -step:1526/1695 train_time:149154ms step_avg:97.74ms -step:1527/1695 train_time:149252ms step_avg:97.74ms -step:1528/1695 train_time:149349ms step_avg:97.74ms -step:1529/1695 train_time:149446ms step_avg:97.74ms -step:1530/1695 train_time:149544ms step_avg:97.74ms -step:1531/1695 train_time:149641ms step_avg:97.74ms -step:1532/1695 train_time:149739ms step_avg:97.74ms -step:1533/1695 train_time:149836ms step_avg:97.74ms -step:1534/1695 train_time:149935ms step_avg:97.74ms -step:1535/1695 train_time:150033ms step_avg:97.74ms -step:1536/1695 train_time:150131ms step_avg:97.74ms -step:1537/1695 train_time:150228ms step_avg:97.74ms -step:1538/1695 train_time:150325ms step_avg:97.74ms -step:1539/1695 train_time:150423ms step_avg:97.74ms -step:1540/1695 train_time:150520ms step_avg:97.74ms -step:1541/1695 train_time:150617ms step_avg:97.74ms -step:1542/1695 train_time:150715ms step_avg:97.74ms -step:1543/1695 train_time:150815ms step_avg:97.74ms -step:1544/1695 train_time:150914ms step_avg:97.74ms -step:1545/1695 train_time:151013ms step_avg:97.74ms -step:1546/1695 train_time:151111ms step_avg:97.74ms -step:1547/1695 train_time:151209ms step_avg:97.74ms -step:1548/1695 train_time:151307ms step_avg:97.74ms -step:1549/1695 train_time:151404ms step_avg:97.74ms -step:1550/1695 train_time:151502ms step_avg:97.74ms -step:1551/1695 train_time:151599ms step_avg:97.74ms -step:1552/1695 train_time:152044ms step_avg:97.97ms -step:1553/1695 train_time:152117ms step_avg:97.95ms -step:1554/1695 train_time:152213ms step_avg:97.95ms -step:1555/1695 train_time:152309ms step_avg:97.95ms -step:1556/1695 train_time:152406ms step_avg:97.95ms -step:1557/1695 train_time:152502ms step_avg:97.95ms -step:1558/1695 train_time:152598ms step_avg:97.94ms -step:1559/1695 train_time:152694ms step_avg:97.94ms -step:1560/1695 train_time:152791ms step_avg:97.94ms -step:1561/1695 train_time:152888ms step_avg:97.94ms -step:1562/1695 train_time:152990ms step_avg:97.95ms -step:1563/1695 train_time:153092ms step_avg:97.95ms -step:1564/1695 train_time:153194ms step_avg:97.95ms -step:1565/1695 train_time:153294ms step_avg:97.95ms -step:1566/1695 train_time:153391ms step_avg:97.95ms -step:1567/1695 train_time:153490ms step_avg:97.95ms -step:1568/1695 train_time:153587ms step_avg:97.95ms -step:1569/1695 train_time:153684ms step_avg:97.95ms -step:1570/1695 train_time:153780ms step_avg:97.95ms -step:1571/1695 train_time:153877ms step_avg:97.95ms -step:1572/1695 train_time:153976ms step_avg:97.95ms -step:1573/1695 train_time:154077ms step_avg:97.95ms -step:1574/1695 train_time:154176ms step_avg:97.95ms -step:1575/1695 train_time:154276ms step_avg:97.95ms -step:1576/1695 train_time:154375ms step_avg:97.95ms -step:1577/1695 train_time:154473ms step_avg:97.95ms -step:1578/1695 train_time:154571ms step_avg:97.95ms -step:1579/1695 train_time:154669ms step_avg:97.95ms -step:1580/1695 train_time:154766ms step_avg:97.95ms -step:1581/1695 train_time:154863ms step_avg:97.95ms -step:1582/1695 train_time:154961ms step_avg:97.95ms -step:1583/1695 train_time:155059ms step_avg:97.95ms -step:1584/1695 train_time:155156ms step_avg:97.95ms -step:1585/1695 train_time:155255ms step_avg:97.95ms -step:1586/1695 train_time:155354ms step_avg:97.95ms -step:1587/1695 train_time:155454ms step_avg:97.95ms -step:1588/1695 train_time:155553ms 
step_avg:97.96ms -step:1589/1695 train_time:155651ms step_avg:97.96ms -step:1590/1695 train_time:155749ms step_avg:97.96ms -step:1591/1695 train_time:155847ms step_avg:97.96ms -step:1592/1695 train_time:155946ms step_avg:97.96ms -step:1593/1695 train_time:156045ms step_avg:97.96ms -step:1594/1695 train_time:156143ms step_avg:97.96ms -step:1595/1695 train_time:156241ms step_avg:97.96ms -step:1596/1695 train_time:156339ms step_avg:97.96ms -step:1597/1695 train_time:156436ms step_avg:97.96ms -step:1598/1695 train_time:156534ms step_avg:97.96ms -step:1599/1695 train_time:156632ms step_avg:97.96ms -step:1600/1695 train_time:156731ms step_avg:97.96ms -step:1601/1695 train_time:156829ms step_avg:97.96ms -step:1602/1695 train_time:156928ms step_avg:97.96ms -step:1603/1695 train_time:157027ms step_avg:97.96ms -step:1604/1695 train_time:157126ms step_avg:97.96ms -step:1605/1695 train_time:157224ms step_avg:97.96ms -step:1606/1695 train_time:157323ms step_avg:97.96ms -step:1607/1695 train_time:157420ms step_avg:97.96ms -step:1608/1695 train_time:157518ms step_avg:97.96ms -step:1609/1695 train_time:157615ms step_avg:97.96ms -step:1610/1695 train_time:157712ms step_avg:97.96ms -step:1611/1695 train_time:157810ms step_avg:97.96ms -step:1612/1695 train_time:157909ms step_avg:97.96ms -step:1613/1695 train_time:158007ms step_avg:97.96ms -step:1614/1695 train_time:158106ms step_avg:97.96ms -step:1615/1695 train_time:158204ms step_avg:97.96ms -step:1616/1695 train_time:158302ms step_avg:97.96ms -step:1617/1695 train_time:158400ms step_avg:97.96ms -step:1618/1695 train_time:158498ms step_avg:97.96ms -step:1619/1695 train_time:158594ms step_avg:97.96ms -step:1620/1695 train_time:158691ms step_avg:97.96ms -step:1621/1695 train_time:158789ms step_avg:97.96ms -step:1622/1695 train_time:158887ms step_avg:97.96ms -step:1623/1695 train_time:158984ms step_avg:97.96ms -step:1624/1695 train_time:159083ms step_avg:97.96ms -step:1625/1695 train_time:159181ms step_avg:97.96ms -step:1625/1695 val_loss:3.2909 train_time:159277ms step_avg:98.02ms -step:1626/1695 train_time:159302ms step_avg:97.97ms -step:1627/1695 train_time:159385ms step_avg:97.96ms -step:1628/1695 train_time:159484ms step_avg:97.96ms -step:1629/1695 train_time:159582ms step_avg:97.96ms -step:1630/1695 train_time:159678ms step_avg:97.96ms -step:1631/1695 train_time:159775ms step_avg:97.96ms -step:1632/1695 train_time:159872ms step_avg:97.96ms -step:1633/1695 train_time:159970ms step_avg:97.96ms -step:1634/1695 train_time:160066ms step_avg:97.96ms -step:1635/1695 train_time:160163ms step_avg:97.96ms -step:1636/1695 train_time:160263ms step_avg:97.96ms -step:1637/1695 train_time:160363ms step_avg:97.96ms -step:1638/1695 train_time:160463ms step_avg:97.96ms -step:1639/1695 train_time:160561ms step_avg:97.96ms -step:1640/1695 train_time:160658ms step_avg:97.96ms -step:1641/1695 train_time:160756ms step_avg:97.96ms -step:1642/1695 train_time:160853ms step_avg:97.96ms -step:1643/1695 train_time:160949ms step_avg:97.96ms -step:1644/1695 train_time:161046ms step_avg:97.96ms -step:1645/1695 train_time:161142ms step_avg:97.96ms -step:1646/1695 train_time:161240ms step_avg:97.96ms -step:1647/1695 train_time:161339ms step_avg:97.96ms -step:1648/1695 train_time:161438ms step_avg:97.96ms -step:1649/1695 train_time:161536ms step_avg:97.96ms -step:1650/1695 train_time:161634ms step_avg:97.96ms -step:1651/1695 train_time:161732ms step_avg:97.96ms -step:1652/1695 train_time:161831ms step_avg:97.96ms -step:1653/1695 train_time:161930ms step_avg:97.96ms -step:1654/1695 
train_time:162027ms step_avg:97.96ms -step:1655/1695 train_time:162124ms step_avg:97.96ms -step:1656/1695 train_time:162223ms step_avg:97.96ms -step:1657/1695 train_time:162320ms step_avg:97.96ms -step:1658/1695 train_time:162418ms step_avg:97.96ms -step:1659/1695 train_time:162516ms step_avg:97.96ms -step:1660/1695 train_time:162614ms step_avg:97.96ms -step:1661/1695 train_time:162712ms step_avg:97.96ms -step:1662/1695 train_time:162811ms step_avg:97.96ms -step:1663/1695 train_time:162909ms step_avg:97.96ms -step:1664/1695 train_time:163007ms step_avg:97.96ms -step:1665/1695 train_time:163105ms step_avg:97.96ms -step:1666/1695 train_time:163204ms step_avg:97.96ms -step:1667/1695 train_time:163303ms step_avg:97.96ms -step:1668/1695 train_time:163402ms step_avg:97.96ms -step:1669/1695 train_time:163499ms step_avg:97.96ms -step:1670/1695 train_time:163597ms step_avg:97.96ms -step:1671/1695 train_time:163694ms step_avg:97.96ms -step:1672/1695 train_time:163791ms step_avg:97.96ms -step:1673/1695 train_time:163889ms step_avg:97.96ms -step:1674/1695 train_time:163987ms step_avg:97.96ms -step:1675/1695 train_time:164085ms step_avg:97.96ms -step:1676/1695 train_time:164183ms step_avg:97.96ms -step:1677/1695 train_time:164281ms step_avg:97.96ms -step:1678/1695 train_time:164379ms step_avg:97.96ms -step:1679/1695 train_time:164476ms step_avg:97.96ms -step:1680/1695 train_time:164573ms step_avg:97.96ms -step:1681/1695 train_time:164671ms step_avg:97.96ms -step:1682/1695 train_time:164768ms step_avg:97.96ms -step:1683/1695 train_time:164866ms step_avg:97.96ms -step:1684/1695 train_time:164963ms step_avg:97.96ms -step:1685/1695 train_time:165060ms step_avg:97.96ms -step:1686/1695 train_time:165157ms step_avg:97.96ms -step:1687/1695 train_time:165255ms step_avg:97.96ms -step:1688/1695 train_time:165354ms step_avg:97.96ms -step:1689/1695 train_time:165453ms step_avg:97.96ms -step:1690/1695 train_time:165551ms step_avg:97.96ms -step:1691/1695 train_time:165650ms step_avg:97.96ms -step:1692/1695 train_time:165748ms step_avg:97.96ms -step:1693/1695 train_time:165845ms step_avg:97.96ms -step:1694/1695 train_time:165943ms step_avg:97.96ms -step:1695/1695 train_time:166040ms step_avg:97.96ms -step:1695/1695 val_loss:3.2791 train_time:166135ms step_avg:98.01ms -peak memory allocated: 34000 MiB reserved: 49416 MiB diff --git a/records/082725_FA3/bb331245-5e49-4366-b902-6caff64ed8d6.txt b/records/082725_FA3/bb331245-5e49-4366-b902-6caff64ed8d6.txt deleted file mode 100644 index f68fe219a..000000000 --- a/records/082725_FA3/bb331245-5e49-4366-b902-6caff64ed8d6.txt +++ /dev/null @@ -1,2808 +0,0 @@ -import os -import sys -with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time -import copy -import glob -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F -import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np -import triton -import triton.language as tl -from flash_attn_interface import flash_attn_func -import torch._dynamo as dynamo -dynamo.config.recompile_limit = 64 - -# 
----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - 
batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - 
configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. 
- rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) - grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size - for base_i in range(0, len(params), world_size): - if base_i + rank < len(params): - grad = params[base_i + rank].grad - # This gives strange dynamo warnings - reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) - - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -class DistAdam(torch.optim.Optimizer): - def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # DistributedAdam implementation by @vagrawal - - @torch.compile - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - grad_slices = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - for base_i in range(len(params)): - grad = params[base_i].grad - rank_size = grad.shape[0] // world_size - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for base in range(len(params)): - reduce_scatter_futures[idx].wait() - p = params[base] - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - g_slice = grad_slices[idx] - # State init - if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - 
state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# ----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - 
f.seek(256 * 4) - nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. 
- tokens, pos = _load_data_shard(next(file_iter)), 0 - finder = EOSBatchFinder(tokens, world_size=world_size) - continue - - bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] - buf = torch.stack(bufs, dim=0) - _inputs = buf[:, :-1] - _targets = buf[:, 1:] - else: - batch_span = num_tokens_global - start_pos_local = pos + rank * (batch_size_local * seq_len) - end_pos_local = start_pos_local + (batch_size_local * seq_len) - - buf = tokens[start_pos_local: end_pos_local + 1] - - _inputs = buf[:-1].view(batch_size_local, seq_len) - _targets = buf[1:].view(batch_size_local, seq_len) - - new_params = yield ( - _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), - _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) - ) - - pos += batch_span - - if new_params is not None: - # makes it possible for generator to receive new (batch_size, seq_len) via .send() - new_batch_size, new_seq_len = new_params - assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" - batch_size = new_batch_size - seq_len = new_seq_len - - -# ----------------------------------------------------------------------------- -# int main - -@dataclass -class Hyperparameters: - # data - train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on - val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on - val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons - train_seq_len: int = 1024 * 2 - train_batch_size: int = 24 * 8 - val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. - # optimization - num_iterations: int = 1695 # number of iterations to run - cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate - # evaluation and logging - run_id: str = str(uuid.uuid4()) - val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end - save_checkpoint: bool = False - # attention masking - bandwidth: int = 128 - ws_schedule: tuple = (3, 7, 11) - -args = Hyperparameters() - -data_path = os.environ.get("DATA_PATH", ".") -args.train_files = os.path.join(data_path, args.train_files) -args.val_files = os.path.join(data_path, args.val_files) - -# torchrun sets these env variables -rank = int(os.environ["RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) -assert 8 % world_size == 0, "world_size must be a divisor of 8" -grad_accum_steps = 8 // world_size -assert torch.cuda.is_available() -device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) -torch.cuda.set_device(device) -dist.init_process_group(backend="nccl", device_id=device) -dist.barrier() -master_process = (rank == 0) # this process will do logging, checkpointing etc. 
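
As an aside on the data format: _load_data_shard above expects each .bin shard to be a 256-word int32 header (magic number 20240520, version 1, claimed token count) followed by the tokens as raw uint16 in native byte order. A minimal writer matching that layout, a hypothetical helper sketched here rather than anything shipped with this record, could look like:

import numpy as np

def write_data_shard(path: str, tokens) -> None:
    # hypothetical sketch: emit the header/payload layout that _load_data_shard reads
    header = np.zeros(256, dtype=np.int32)
    header[0] = 20240520        # magic number the loader asserts on
    header[1] = 1               # version the loader asserts on
    header[2] = len(tokens)     # token count (the loader checks 2 bytes per token follow)
    with open(path, "wb") as f:
        f.write(header.tobytes())
        f.write(np.asarray(tokens, dtype=np.uint16).tobytes())
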
- -# begin logging -logfile = None -if master_process: - run_id = args.run_id - os.makedirs("logs", exist_ok=True) - logfile = f"logs/{run_id}.txt" - print(logfile) -def print0(s, console=False): - if master_process: - with open(logfile, "a") as f: - if console: - print(s) - print(s, file=f) - -# begin by printing this file (the Python code) -print0(code) -print0("="*100) -# log information about the hardware/software environment this is running on -print0(f"Running Python {sys.version}") -print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") -print0(f"Running Triton version {triton.__version__}") - -def nvidia_smi(): - import subprocess # avoid top level import - return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout -print0(nvidia_smi()) -print0("="*100) - -model: nn.Module = GPT( - vocab_size=50257, - num_layers=12, - num_heads=6, - model_dim=768, - max_seq_len=max(args.train_seq_len, args.val_seq_len) -).cuda() -for m in model.modules(): - if isinstance(m, nn.Embedding): - m.bfloat16() -for param in model.parameters(): - dist.broadcast(param.detach(), 0) - -# collect the parameters to optimize -hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] -embed_params = [p for n, p in model.named_parameters() if "embed" in n] -scalar_params = [p for p in model.parameters() if p.ndim < 2] -head_params = [model.lm_head.weight] - -# init the optimizer(s) -# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence -# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 -optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) -optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) -optimizers = [optimizer1, optimizer2] -for opt in optimizers: - for group in opt.param_groups: - group["initial_lr"] = group["lr"] - -# learning rate schedule: stable then decay -def get_lr_and_ws(step: int): - x = step / (1 + args.num_iterations) # progress in training - assert 0 <= x < 1 - lr = 1.0 - if x >= 1 - args.cooldown_frac: - w = (1 - x) / args.cooldown_frac - lr = w * 1.0 + (1 - w) * 0.1 - ws_idx = int(len(args.ws_schedule) * x) - return lr, args.ws_schedule[ws_idx] - -model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) - -######################################## -# Warmup kernels # -######################################## - -# Warmup the training kernels, then re-initialize the state so we aren't cheating -warmup_steps = 60 -initial_state = dict(model=copy.deepcopy(model.state_dict()), - optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state -train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len) -for step in range(warmup_steps): - inputs, targets = next(train_loader) - ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each - model(inputs, targets, ws, ws // 2).backward() - for opt in optimizers: - opt.step() - model.zero_grad(set_to_none=True) -model.load_state_dict(initial_state["model"]) -for opt, opt_state in zip(optimizers, initial_state["optimizers"]): - opt.load_state_dict(opt_state) -del train_loader, initial_state - -######################################## -# Training and validation # -######################################## - 
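
Concretely, with the hyperparameters above (num_iterations=1695, cooldown_frac=0.45), get_lr_and_ws holds the LR multiplier at 1.0 for the first ~55% of training and then decays it linearly toward 0.1, while the window-size schedule steps through ws_schedule=(3, 7, 11) over equal thirds of training; each unit is args.bandwidth=128 tokens, and the training loop below passes ws for the long-window layers and ws // 2 for the short-window ones. A quick illustrative probe, assuming the definitions above are in scope:

for s in (0, 500, 1000, 1400, 1694):
    lr_mult, ws = get_lr_and_ws(s)  # probe the stable-then-decay schedule at a few steps
    print(f"step {s:4d}: lr_mult={lr_mult:.3f}  ws={ws}  long window={ws * args.bandwidth} tokens")
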
-
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
-training_time_ms = 0
-# start the clock
-torch.cuda.synchronize()
-t0 = time.perf_counter()
-# begin training
-train_steps = args.num_iterations
-for step in range(train_steps + 1):
-    last_step = (step == train_steps)
-    lr, ws = get_lr_and_ws(step)
-
-    # --------------- VALIDATION SECTION -----------------
-    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
-        # stop the clock
-        torch.cuda.synchronize()
-        training_time_ms += 1000 * (time.perf_counter() - t0)
-        model.eval()
-        assert args.val_tokens % (world_size * args.val_seq_len) == 0
-        val_steps = args.val_tokens // (world_size * args.val_seq_len)
-        val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False)
-        val_loss = 0
-        with torch.no_grad():
-            for _ in range(val_steps):
-                inputs, targets = next(val_loader)
-                val_loss += model(inputs, targets, ws, ws // 2)
-        val_loss /= val_steps
-        del val_loader
-        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
-        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
-        model.train()
-        # start the clock again
-        torch.cuda.synchronize()
-        t0 = time.perf_counter()
-
-    if last_step:
-        if master_process and args.save_checkpoint:
-            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
-            os.makedirs(f"logs/{run_id}", exist_ok=True)
-            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
-        # the last step only has the validation loop, so break to avoid training
-        break
-
-    # --------------- TRAINING SECTION -----------------
-    for _ in range(grad_accum_steps):
-        inputs, targets = next(train_loader)
-        model(inputs, targets, ws, ws // 2).backward()
-    # set optimization hyperparameters
-    for opt in optimizers:
-        for group in opt.param_groups:
-            group["lr"] = group["initial_lr"] * lr
-    for group in optimizer2.param_groups:
-        frac = min(step / 300, 1) # momentum warmup for muon
-        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
-    # step the optimizers
-    for opt in optimizers:
-        opt.step()
-    # null the gradients
-    model.zero_grad(set_to_none=True)
-    # logging
-    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
-    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
-
-print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
-       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
-dist.destroy_process_group()
-====================================================================================================
-Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
-Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6
-Running Triton version 3.4.0
-Wed Aug 27 04:10:32 2025
-+---------------------------------------------------------------------------------------+
-| NVIDIA-SMI 535.183.06             Driver Version: 535.183.06   CUDA Version: 12.6     |
-|-----------------------------------------+----------------------+----------------------+
-| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
-| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
-|                                         |                      |               MIG M. |
-|=========================================+======================+======================|
-|   0  NVIDIA H100 80GB HBM3          On  | 00000000:00:0B.0 Off |                  Off |
-| N/A   29C    P0             114W / 700W |   5858MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   1  NVIDIA H100 80GB HBM3          On  | 00000000:00:0C.0 Off |                  Off |
-| N/A   31C    P0             112W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   2  NVIDIA H100 80GB HBM3          On  | 00000000:00:0D.0 Off |                  Off |
-| N/A   32C    P0             114W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   3  NVIDIA H100 80GB HBM3          On  | 00000000:00:0E.0 Off |                  Off |
-| N/A   30C    P0             113W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   4  NVIDIA H100 80GB HBM3          On  | 00000000:00:0F.0 Off |                  Off |
-| N/A   29C    P0             111W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   5  NVIDIA H100 80GB HBM3          On  | 00000000:00:10.0 Off |                  Off |
-| N/A   33C    P0             116W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   6  NVIDIA H100 80GB HBM3          On  | 00000000:00:11.0 Off |                  Off |
-| N/A   32C    P0             110W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   7  NVIDIA H100 80GB HBM3          On  | 00000000:00:12.0 Off |                  Off |
-| N/A   31C    P0             114W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-
-+---------------------------------------------------------------------------------------+
-| Processes:                                                                            |
-|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
-|        ID   ID                                                             Usage      |
-|=======================================================================================|
-+---------------------------------------------------------------------------------------+
-
-====================================================================================================
-step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms
-step:1/1695 train_time:507ms step_avg:507.01ms
-step:2/1695 train_time:531ms step_avg:265.61ms
-step:3/1695 train_time:603ms step_avg:201.15ms
-step:4/1695 train_time:695ms step_avg:173.81ms
-step:5/1695 train_time:789ms step_avg:157.80ms
-step:6/1695 train_time:883ms step_avg:147.20ms
-step:7/1695 train_time:976ms step_avg:139.42ms
-step:8/1695 train_time:1069ms step_avg:133.63ms
-step:9/1695 train_time:1162ms step_avg:129.16ms
-step:10/1695 train_time:1256ms step_avg:125.61ms
-step:11/1695 train_time:1350ms step_avg:122.74ms
-step:12/1695 train_time:1446ms step_avg:120.48ms
-step:13/1695 train_time:1542ms step_avg:118.60ms
-step:14/1695 train_time:1637ms step_avg:116.89ms
-step:15/1695 train_time:1731ms step_avg:115.41ms
-step:16/1695 train_time:1827ms step_avg:114.16ms
-step:17/1695 train_time:1921ms step_avg:112.98ms
-step:18/1695 train_time:2014ms step_avg:111.90ms
-step:19/1695 train_time:2108ms step_avg:110.97ms
-step:20/1695 train_time:2204ms step_avg:110.18ms
-step:21/1695 train_time:2297ms step_avg:109.37ms
-step:22/1695 train_time:2392ms step_avg:108.71ms
-step:23/1695 train_time:2487ms step_avg:108.14ms -step:24/1695 train_time:2583ms step_avg:107.62ms -step:25/1695 train_time:2677ms step_avg:107.07ms -step:26/1695 train_time:2771ms step_avg:106.58ms -step:27/1695 train_time:2866ms step_avg:106.14ms -step:28/1695 train_time:2960ms step_avg:105.72ms -step:29/1695 train_time:3054ms step_avg:105.31ms -step:30/1695 train_time:3149ms step_avg:104.96ms -step:31/1695 train_time:3243ms step_avg:104.62ms -step:32/1695 train_time:3337ms step_avg:104.27ms -step:33/1695 train_time:3432ms step_avg:104.01ms -step:34/1695 train_time:3529ms step_avg:103.78ms -step:35/1695 train_time:3624ms step_avg:103.56ms -step:36/1695 train_time:3719ms step_avg:103.30ms -step:37/1695 train_time:3814ms step_avg:103.07ms -step:38/1695 train_time:3909ms step_avg:102.86ms -step:39/1695 train_time:4004ms step_avg:102.66ms -step:40/1695 train_time:4098ms step_avg:102.45ms -step:41/1695 train_time:4192ms step_avg:102.24ms -step:42/1695 train_time:4286ms step_avg:102.04ms -step:43/1695 train_time:4380ms step_avg:101.86ms -step:44/1695 train_time:4474ms step_avg:101.68ms -step:45/1695 train_time:4569ms step_avg:101.53ms -step:46/1695 train_time:4664ms step_avg:101.39ms -step:47/1695 train_time:4757ms step_avg:101.22ms -step:48/1695 train_time:4852ms step_avg:101.07ms -step:49/1695 train_time:4947ms step_avg:100.95ms -step:50/1695 train_time:5041ms step_avg:100.82ms -step:51/1695 train_time:5135ms step_avg:100.68ms -step:52/1695 train_time:5229ms step_avg:100.56ms -step:53/1695 train_time:5324ms step_avg:100.45ms -step:54/1695 train_time:5418ms step_avg:100.33ms -step:55/1695 train_time:5513ms step_avg:100.23ms -step:56/1695 train_time:5608ms step_avg:100.14ms -step:57/1695 train_time:5703ms step_avg:100.05ms -step:58/1695 train_time:5796ms step_avg:99.94ms -step:59/1695 train_time:5891ms step_avg:99.85ms -step:60/1695 train_time:5986ms step_avg:99.76ms -step:61/1695 train_time:6079ms step_avg:99.66ms -step:62/1695 train_time:6173ms step_avg:99.56ms -step:63/1695 train_time:6268ms step_avg:99.50ms -step:64/1695 train_time:6364ms step_avg:99.44ms -step:65/1695 train_time:6458ms step_avg:99.36ms -step:66/1695 train_time:6553ms step_avg:99.29ms -step:67/1695 train_time:6649ms step_avg:99.24ms -step:68/1695 train_time:6745ms step_avg:99.19ms -step:69/1695 train_time:6837ms step_avg:99.08ms -step:70/1695 train_time:6931ms step_avg:99.02ms -step:71/1695 train_time:7026ms step_avg:98.95ms -step:72/1695 train_time:7119ms step_avg:98.87ms -step:73/1695 train_time:7213ms step_avg:98.81ms -step:74/1695 train_time:7307ms step_avg:98.75ms -step:75/1695 train_time:7402ms step_avg:98.69ms -step:76/1695 train_time:7495ms step_avg:98.62ms -step:77/1695 train_time:7590ms step_avg:98.58ms -step:78/1695 train_time:7686ms step_avg:98.53ms -step:79/1695 train_time:7779ms step_avg:98.47ms -step:80/1695 train_time:7873ms step_avg:98.41ms -step:81/1695 train_time:7969ms step_avg:98.38ms -step:82/1695 train_time:8064ms step_avg:98.34ms -step:83/1695 train_time:8158ms step_avg:98.28ms -step:84/1695 train_time:8252ms step_avg:98.23ms -step:85/1695 train_time:8346ms step_avg:98.18ms -step:86/1695 train_time:8439ms step_avg:98.13ms -step:87/1695 train_time:8533ms step_avg:98.08ms -step:88/1695 train_time:8629ms step_avg:98.06ms -step:89/1695 train_time:8724ms step_avg:98.02ms -step:90/1695 train_time:8817ms step_avg:97.97ms -step:91/1695 train_time:8911ms step_avg:97.93ms -step:92/1695 train_time:9006ms step_avg:97.89ms -step:93/1695 train_time:9100ms step_avg:97.85ms -step:94/1695 train_time:9194ms 
step_avg:97.80ms -step:95/1695 train_time:9289ms step_avg:97.77ms -step:96/1695 train_time:9383ms step_avg:97.74ms -step:97/1695 train_time:9476ms step_avg:97.69ms -step:98/1695 train_time:9570ms step_avg:97.66ms -step:99/1695 train_time:9665ms step_avg:97.63ms -step:100/1695 train_time:9760ms step_avg:97.60ms -step:101/1695 train_time:9854ms step_avg:97.56ms -step:102/1695 train_time:9949ms step_avg:97.54ms -step:103/1695 train_time:10043ms step_avg:97.51ms -step:104/1695 train_time:10137ms step_avg:97.47ms -step:105/1695 train_time:10231ms step_avg:97.44ms -step:106/1695 train_time:10326ms step_avg:97.41ms -step:107/1695 train_time:10420ms step_avg:97.38ms -step:108/1695 train_time:10513ms step_avg:97.34ms -step:109/1695 train_time:10608ms step_avg:97.32ms -step:110/1695 train_time:10703ms step_avg:97.30ms -step:111/1695 train_time:10797ms step_avg:97.27ms -step:112/1695 train_time:10891ms step_avg:97.24ms -step:113/1695 train_time:10985ms step_avg:97.21ms -step:114/1695 train_time:11078ms step_avg:97.18ms -step:115/1695 train_time:11172ms step_avg:97.15ms -step:116/1695 train_time:11267ms step_avg:97.13ms -step:117/1695 train_time:11362ms step_avg:97.11ms -step:118/1695 train_time:11456ms step_avg:97.08ms -step:119/1695 train_time:11550ms step_avg:97.06ms -step:120/1695 train_time:11644ms step_avg:97.03ms -step:121/1695 train_time:11737ms step_avg:97.00ms -step:122/1695 train_time:11832ms step_avg:96.98ms -step:123/1695 train_time:11927ms step_avg:96.97ms -step:124/1695 train_time:12022ms step_avg:96.95ms -step:125/1695 train_time:12115ms step_avg:96.92ms -step:125/1695 val_loss:4.3104 train_time:12207ms step_avg:97.66ms -step:126/1695 train_time:12232ms step_avg:97.08ms -step:127/1695 train_time:12310ms step_avg:96.93ms -step:128/1695 train_time:12410ms step_avg:96.96ms -step:129/1695 train_time:12506ms step_avg:96.94ms -step:130/1695 train_time:12599ms step_avg:96.92ms -step:131/1695 train_time:12692ms step_avg:96.89ms -step:132/1695 train_time:12785ms step_avg:96.86ms -step:133/1695 train_time:12878ms step_avg:96.83ms -step:134/1695 train_time:12972ms step_avg:96.80ms -step:135/1695 train_time:13064ms step_avg:96.77ms -step:136/1695 train_time:13157ms step_avg:96.75ms -step:137/1695 train_time:13252ms step_avg:96.73ms -step:138/1695 train_time:13348ms step_avg:96.73ms -step:139/1695 train_time:13444ms step_avg:96.72ms -step:140/1695 train_time:13539ms step_avg:96.70ms -step:141/1695 train_time:13633ms step_avg:96.69ms -step:142/1695 train_time:13727ms step_avg:96.67ms -step:143/1695 train_time:13821ms step_avg:96.65ms -step:144/1695 train_time:13916ms step_avg:96.64ms -step:145/1695 train_time:14010ms step_avg:96.62ms -step:146/1695 train_time:14103ms step_avg:96.59ms -step:147/1695 train_time:14197ms step_avg:96.58ms -step:148/1695 train_time:14293ms step_avg:96.57ms -step:149/1695 train_time:14387ms step_avg:96.56ms -step:150/1695 train_time:14482ms step_avg:96.55ms -step:151/1695 train_time:14578ms step_avg:96.55ms -step:152/1695 train_time:14673ms step_avg:96.53ms -step:153/1695 train_time:14766ms step_avg:96.51ms -step:154/1695 train_time:14860ms step_avg:96.49ms -step:155/1695 train_time:14953ms step_avg:96.47ms -step:156/1695 train_time:15046ms step_avg:96.45ms -step:157/1695 train_time:15140ms step_avg:96.43ms -step:158/1695 train_time:15233ms step_avg:96.41ms -step:159/1695 train_time:15327ms step_avg:96.40ms -step:160/1695 train_time:15422ms step_avg:96.39ms -step:161/1695 train_time:15517ms step_avg:96.38ms -step:162/1695 train_time:15613ms step_avg:96.38ms -step:163/1695 
train_time:15706ms step_avg:96.36ms -step:164/1695 train_time:15800ms step_avg:96.34ms -step:165/1695 train_time:15893ms step_avg:96.32ms -step:166/1695 train_time:15987ms step_avg:96.30ms -step:167/1695 train_time:16080ms step_avg:96.29ms -step:168/1695 train_time:16174ms step_avg:96.27ms -step:169/1695 train_time:16267ms step_avg:96.26ms -step:170/1695 train_time:16361ms step_avg:96.24ms -step:171/1695 train_time:16455ms step_avg:96.23ms -step:172/1695 train_time:16551ms step_avg:96.22ms -step:173/1695 train_time:16932ms step_avg:97.87ms -step:174/1695 train_time:17007ms step_avg:97.74ms -step:175/1695 train_time:17099ms step_avg:97.71ms -step:176/1695 train_time:17192ms step_avg:97.68ms -step:177/1695 train_time:17285ms step_avg:97.65ms -step:178/1695 train_time:17378ms step_avg:97.63ms -step:179/1695 train_time:17472ms step_avg:97.61ms -step:180/1695 train_time:17565ms step_avg:97.58ms -step:181/1695 train_time:17658ms step_avg:97.56ms -step:182/1695 train_time:17752ms step_avg:97.54ms -step:183/1695 train_time:17847ms step_avg:97.53ms -step:184/1695 train_time:17944ms step_avg:97.52ms -step:185/1695 train_time:18041ms step_avg:97.52ms -step:186/1695 train_time:18135ms step_avg:97.50ms -step:187/1695 train_time:18229ms step_avg:97.48ms -step:188/1695 train_time:18322ms step_avg:97.46ms -step:189/1695 train_time:18416ms step_avg:97.44ms -step:190/1695 train_time:18510ms step_avg:97.42ms -step:191/1695 train_time:18602ms step_avg:97.40ms -step:192/1695 train_time:18696ms step_avg:97.38ms -step:193/1695 train_time:18790ms step_avg:97.36ms -step:194/1695 train_time:18885ms step_avg:97.34ms -step:195/1695 train_time:18980ms step_avg:97.33ms -step:196/1695 train_time:19075ms step_avg:97.32ms -step:197/1695 train_time:19169ms step_avg:97.31ms -step:198/1695 train_time:19263ms step_avg:97.29ms -step:199/1695 train_time:19357ms step_avg:97.27ms -step:200/1695 train_time:19451ms step_avg:97.26ms -step:201/1695 train_time:19544ms step_avg:97.23ms -step:202/1695 train_time:19638ms step_avg:97.22ms -step:203/1695 train_time:19733ms step_avg:97.21ms -step:204/1695 train_time:19826ms step_avg:97.19ms -step:205/1695 train_time:19921ms step_avg:97.18ms -step:206/1695 train_time:20016ms step_avg:97.17ms -step:207/1695 train_time:20111ms step_avg:97.15ms -step:208/1695 train_time:20204ms step_avg:97.14ms -step:209/1695 train_time:20299ms step_avg:97.12ms -step:210/1695 train_time:20393ms step_avg:97.11ms -step:211/1695 train_time:20487ms step_avg:97.10ms -step:212/1695 train_time:20581ms step_avg:97.08ms -step:213/1695 train_time:20674ms step_avg:97.06ms -step:214/1695 train_time:20768ms step_avg:97.05ms -step:215/1695 train_time:20862ms step_avg:97.03ms -step:216/1695 train_time:20957ms step_avg:97.02ms -step:217/1695 train_time:21051ms step_avg:97.01ms -step:218/1695 train_time:21145ms step_avg:97.00ms -step:219/1695 train_time:21240ms step_avg:96.98ms -step:220/1695 train_time:21334ms step_avg:96.97ms -step:221/1695 train_time:21428ms step_avg:96.96ms -step:222/1695 train_time:21521ms step_avg:96.94ms -step:223/1695 train_time:21616ms step_avg:96.93ms -step:224/1695 train_time:21710ms step_avg:96.92ms -step:225/1695 train_time:21804ms step_avg:96.91ms -step:226/1695 train_time:21899ms step_avg:96.90ms -step:227/1695 train_time:21994ms step_avg:96.89ms -step:228/1695 train_time:22087ms step_avg:96.87ms -step:229/1695 train_time:22181ms step_avg:96.86ms -step:230/1695 train_time:22276ms step_avg:96.85ms -step:231/1695 train_time:22371ms step_avg:96.85ms -step:232/1695 train_time:22465ms step_avg:96.83ms 
-step:233/1695 train_time:22559ms step_avg:96.82ms -step:234/1695 train_time:22653ms step_avg:96.81ms -step:235/1695 train_time:22747ms step_avg:96.79ms -step:236/1695 train_time:22840ms step_avg:96.78ms -step:237/1695 train_time:22935ms step_avg:96.77ms -step:238/1695 train_time:23029ms step_avg:96.76ms -step:239/1695 train_time:23122ms step_avg:96.75ms -step:240/1695 train_time:23217ms step_avg:96.74ms -step:241/1695 train_time:23312ms step_avg:96.73ms -step:242/1695 train_time:23405ms step_avg:96.72ms -step:243/1695 train_time:23499ms step_avg:96.71ms -step:244/1695 train_time:23593ms step_avg:96.69ms -step:245/1695 train_time:23686ms step_avg:96.68ms -step:246/1695 train_time:23780ms step_avg:96.67ms -step:247/1695 train_time:23875ms step_avg:96.66ms -step:248/1695 train_time:23969ms step_avg:96.65ms -step:249/1695 train_time:24063ms step_avg:96.64ms -step:250/1695 train_time:24157ms step_avg:96.63ms -step:250/1695 val_loss:3.9654 train_time:24251ms step_avg:97.00ms -step:251/1695 train_time:24276ms step_avg:96.72ms -step:252/1695 train_time:24355ms step_avg:96.65ms -step:253/1695 train_time:24453ms step_avg:96.65ms -step:254/1695 train_time:24548ms step_avg:96.65ms -step:255/1695 train_time:24642ms step_avg:96.64ms -step:256/1695 train_time:24735ms step_avg:96.62ms -step:257/1695 train_time:24828ms step_avg:96.61ms -step:258/1695 train_time:24921ms step_avg:96.59ms -step:259/1695 train_time:25014ms step_avg:96.58ms -step:260/1695 train_time:25108ms step_avg:96.57ms -step:261/1695 train_time:25201ms step_avg:96.55ms -step:262/1695 train_time:25295ms step_avg:96.55ms -step:263/1695 train_time:25391ms step_avg:96.54ms -step:264/1695 train_time:25487ms step_avg:96.54ms -step:265/1695 train_time:25582ms step_avg:96.54ms -step:266/1695 train_time:25676ms step_avg:96.52ms -step:267/1695 train_time:25769ms step_avg:96.51ms -step:268/1695 train_time:25862ms step_avg:96.50ms -step:269/1695 train_time:25955ms step_avg:96.49ms -step:270/1695 train_time:26048ms step_avg:96.47ms -step:271/1695 train_time:26141ms step_avg:96.46ms -step:272/1695 train_time:26235ms step_avg:96.45ms -step:273/1695 train_time:26330ms step_avg:96.45ms -step:274/1695 train_time:26426ms step_avg:96.45ms -step:275/1695 train_time:26522ms step_avg:96.44ms -step:276/1695 train_time:26616ms step_avg:96.44ms -step:277/1695 train_time:26710ms step_avg:96.43ms -step:278/1695 train_time:26803ms step_avg:96.42ms -step:279/1695 train_time:26897ms step_avg:96.40ms -step:280/1695 train_time:26990ms step_avg:96.39ms -step:281/1695 train_time:27083ms step_avg:96.38ms -step:282/1695 train_time:27176ms step_avg:96.37ms -step:283/1695 train_time:27270ms step_avg:96.36ms -step:284/1695 train_time:27365ms step_avg:96.36ms -step:285/1695 train_time:27459ms step_avg:96.35ms -step:286/1695 train_time:27553ms step_avg:96.34ms -step:287/1695 train_time:27648ms step_avg:96.34ms -step:288/1695 train_time:27742ms step_avg:96.33ms -step:289/1695 train_time:27836ms step_avg:96.32ms -step:290/1695 train_time:27930ms step_avg:96.31ms -step:291/1695 train_time:28024ms step_avg:96.30ms -step:292/1695 train_time:28118ms step_avg:96.29ms -step:293/1695 train_time:28211ms step_avg:96.28ms -step:294/1695 train_time:28305ms step_avg:96.28ms -step:295/1695 train_time:28399ms step_avg:96.27ms -step:296/1695 train_time:28493ms step_avg:96.26ms -step:297/1695 train_time:28587ms step_avg:96.25ms -step:298/1695 train_time:28682ms step_avg:96.25ms -step:299/1695 train_time:28776ms step_avg:96.24ms -step:300/1695 train_time:28870ms step_avg:96.23ms -step:301/1695 
train_time:28964ms step_avg:96.23ms -step:302/1695 train_time:29057ms step_avg:96.22ms -step:303/1695 train_time:29151ms step_avg:96.21ms -step:304/1695 train_time:29245ms step_avg:96.20ms -step:305/1695 train_time:29340ms step_avg:96.20ms -step:306/1695 train_time:29434ms step_avg:96.19ms -step:307/1695 train_time:29528ms step_avg:96.18ms -step:308/1695 train_time:29623ms step_avg:96.18ms -step:309/1695 train_time:29717ms step_avg:96.17ms -step:310/1695 train_time:29812ms step_avg:96.17ms -step:311/1695 train_time:29907ms step_avg:96.16ms -step:312/1695 train_time:30001ms step_avg:96.16ms -step:313/1695 train_time:30094ms step_avg:96.15ms -step:314/1695 train_time:30188ms step_avg:96.14ms -step:315/1695 train_time:30282ms step_avg:96.13ms -step:316/1695 train_time:30375ms step_avg:96.12ms -step:317/1695 train_time:30469ms step_avg:96.12ms -step:318/1695 train_time:30563ms step_avg:96.11ms -step:319/1695 train_time:30658ms step_avg:96.11ms -step:320/1695 train_time:30751ms step_avg:96.10ms -step:321/1695 train_time:30846ms step_avg:96.09ms -step:322/1695 train_time:30941ms step_avg:96.09ms -step:323/1695 train_time:31034ms step_avg:96.08ms -step:324/1695 train_time:31128ms step_avg:96.07ms -step:325/1695 train_time:31223ms step_avg:96.07ms -step:326/1695 train_time:31317ms step_avg:96.06ms -step:327/1695 train_time:31410ms step_avg:96.06ms -step:328/1695 train_time:31504ms step_avg:96.05ms -step:329/1695 train_time:31597ms step_avg:96.04ms -step:330/1695 train_time:31691ms step_avg:96.03ms -step:331/1695 train_time:31785ms step_avg:96.03ms -step:332/1695 train_time:31881ms step_avg:96.03ms -step:333/1695 train_time:31974ms step_avg:96.02ms -step:334/1695 train_time:32069ms step_avg:96.01ms -step:335/1695 train_time:32163ms step_avg:96.01ms -step:336/1695 train_time:32257ms step_avg:96.00ms -step:337/1695 train_time:32350ms step_avg:95.99ms -step:338/1695 train_time:32446ms step_avg:95.99ms -step:339/1695 train_time:32541ms step_avg:95.99ms -step:340/1695 train_time:32634ms step_avg:95.98ms -step:341/1695 train_time:32728ms step_avg:95.98ms -step:342/1695 train_time:32823ms step_avg:95.97ms -step:343/1695 train_time:32916ms step_avg:95.97ms -step:344/1695 train_time:33010ms step_avg:95.96ms -step:345/1695 train_time:33345ms step_avg:96.65ms -step:346/1695 train_time:33427ms step_avg:96.61ms -step:347/1695 train_time:33519ms step_avg:96.60ms -step:348/1695 train_time:33612ms step_avg:96.59ms -step:349/1695 train_time:33705ms step_avg:96.58ms -step:350/1695 train_time:33798ms step_avg:96.57ms -step:351/1695 train_time:33891ms step_avg:96.55ms -step:352/1695 train_time:33984ms step_avg:96.54ms -step:353/1695 train_time:34077ms step_avg:96.53ms -step:354/1695 train_time:34169ms step_avg:96.52ms -step:355/1695 train_time:34266ms step_avg:96.52ms -step:356/1695 train_time:34363ms step_avg:96.52ms -step:357/1695 train_time:34458ms step_avg:96.52ms -step:358/1695 train_time:34552ms step_avg:96.51ms -step:359/1695 train_time:34646ms step_avg:96.51ms -step:360/1695 train_time:34741ms step_avg:96.50ms -step:361/1695 train_time:34834ms step_avg:96.49ms -step:362/1695 train_time:34927ms step_avg:96.48ms -step:363/1695 train_time:35020ms step_avg:96.47ms -step:364/1695 train_time:35113ms step_avg:96.46ms -step:365/1695 train_time:35207ms step_avg:96.46ms -step:366/1695 train_time:35303ms step_avg:96.45ms -step:367/1695 train_time:35398ms step_avg:96.45ms -step:368/1695 train_time:35492ms step_avg:96.45ms -step:369/1695 train_time:35587ms step_avg:96.44ms -step:370/1695 train_time:35681ms step_avg:96.44ms 
-step:371/1695 train_time:35775ms step_avg:96.43ms -step:372/1695 train_time:35868ms step_avg:96.42ms -step:373/1695 train_time:35961ms step_avg:96.41ms -step:374/1695 train_time:36054ms step_avg:96.40ms -step:375/1695 train_time:36147ms step_avg:96.39ms -step:375/1695 val_loss:3.8151 train_time:36240ms step_avg:96.64ms -step:376/1695 train_time:36265ms step_avg:96.45ms -step:377/1695 train_time:36343ms step_avg:96.40ms -step:378/1695 train_time:36441ms step_avg:96.41ms -step:379/1695 train_time:36535ms step_avg:96.40ms -step:380/1695 train_time:36628ms step_avg:96.39ms -step:381/1695 train_time:36721ms step_avg:96.38ms -step:382/1695 train_time:36814ms step_avg:96.37ms -step:383/1695 train_time:36907ms step_avg:96.36ms -step:384/1695 train_time:36999ms step_avg:96.35ms -step:385/1695 train_time:37092ms step_avg:96.34ms -step:386/1695 train_time:37185ms step_avg:96.33ms -step:387/1695 train_time:37280ms step_avg:96.33ms -step:388/1695 train_time:37376ms step_avg:96.33ms -step:389/1695 train_time:37474ms step_avg:96.33ms -step:390/1695 train_time:37570ms step_avg:96.33ms -step:391/1695 train_time:37665ms step_avg:96.33ms -step:392/1695 train_time:37758ms step_avg:96.32ms -step:393/1695 train_time:37851ms step_avg:96.31ms -step:394/1695 train_time:37945ms step_avg:96.31ms -step:395/1695 train_time:38038ms step_avg:96.30ms -step:396/1695 train_time:38132ms step_avg:96.29ms -step:397/1695 train_time:38226ms step_avg:96.29ms -step:398/1695 train_time:38320ms step_avg:96.28ms -step:399/1695 train_time:38414ms step_avg:96.28ms -step:400/1695 train_time:38510ms step_avg:96.27ms -step:401/1695 train_time:38606ms step_avg:96.27ms -step:402/1695 train_time:38700ms step_avg:96.27ms -step:403/1695 train_time:38793ms step_avg:96.26ms -step:404/1695 train_time:38888ms step_avg:96.26ms -step:405/1695 train_time:38981ms step_avg:96.25ms -step:406/1695 train_time:39074ms step_avg:96.24ms -step:407/1695 train_time:39168ms step_avg:96.24ms -step:408/1695 train_time:39262ms step_avg:96.23ms -step:409/1695 train_time:39356ms step_avg:96.22ms -step:410/1695 train_time:39450ms step_avg:96.22ms -step:411/1695 train_time:39545ms step_avg:96.22ms -step:412/1695 train_time:39639ms step_avg:96.21ms -step:413/1695 train_time:39733ms step_avg:96.21ms -step:414/1695 train_time:39827ms step_avg:96.20ms -step:415/1695 train_time:39920ms step_avg:96.19ms -step:416/1695 train_time:40013ms step_avg:96.19ms -step:417/1695 train_time:40107ms step_avg:96.18ms -step:418/1695 train_time:40200ms step_avg:96.17ms -step:419/1695 train_time:40294ms step_avg:96.17ms -step:420/1695 train_time:40388ms step_avg:96.16ms -step:421/1695 train_time:40482ms step_avg:96.16ms -step:422/1695 train_time:40577ms step_avg:96.15ms -step:423/1695 train_time:40672ms step_avg:96.15ms -step:424/1695 train_time:40767ms step_avg:96.15ms -step:425/1695 train_time:40862ms step_avg:96.15ms -step:426/1695 train_time:40955ms step_avg:96.14ms -step:427/1695 train_time:41049ms step_avg:96.13ms -step:428/1695 train_time:41143ms step_avg:96.13ms -step:429/1695 train_time:41236ms step_avg:96.12ms -step:430/1695 train_time:41330ms step_avg:96.12ms -step:431/1695 train_time:41424ms step_avg:96.11ms -step:432/1695 train_time:41518ms step_avg:96.11ms -step:433/1695 train_time:41612ms step_avg:96.10ms -step:434/1695 train_time:41707ms step_avg:96.10ms -step:435/1695 train_time:41800ms step_avg:96.09ms -step:436/1695 train_time:41894ms step_avg:96.09ms -step:437/1695 train_time:41988ms step_avg:96.08ms -step:438/1695 train_time:42081ms step_avg:96.08ms -step:439/1695 
train_time:42175ms step_avg:96.07ms -step:440/1695 train_time:42269ms step_avg:96.07ms -step:441/1695 train_time:42365ms step_avg:96.06ms -step:442/1695 train_time:42458ms step_avg:96.06ms -step:443/1695 train_time:42553ms step_avg:96.06ms -step:444/1695 train_time:42647ms step_avg:96.05ms -step:445/1695 train_time:42741ms step_avg:96.05ms -step:446/1695 train_time:42835ms step_avg:96.04ms -step:447/1695 train_time:42929ms step_avg:96.04ms -step:448/1695 train_time:43023ms step_avg:96.03ms -step:449/1695 train_time:43116ms step_avg:96.03ms -step:450/1695 train_time:43210ms step_avg:96.02ms -step:451/1695 train_time:43305ms step_avg:96.02ms -step:452/1695 train_time:43398ms step_avg:96.01ms -step:453/1695 train_time:43492ms step_avg:96.01ms -step:454/1695 train_time:43587ms step_avg:96.01ms -step:455/1695 train_time:43680ms step_avg:96.00ms -step:456/1695 train_time:43774ms step_avg:96.00ms -step:457/1695 train_time:43869ms step_avg:95.99ms -step:458/1695 train_time:43964ms step_avg:95.99ms -step:459/1695 train_time:44057ms step_avg:95.99ms -step:460/1695 train_time:44151ms step_avg:95.98ms -step:461/1695 train_time:44245ms step_avg:95.98ms -step:462/1695 train_time:44338ms step_avg:95.97ms -step:463/1695 train_time:44432ms step_avg:95.97ms -step:464/1695 train_time:44526ms step_avg:95.96ms -step:465/1695 train_time:44620ms step_avg:95.96ms -step:466/1695 train_time:44714ms step_avg:95.95ms -step:467/1695 train_time:44808ms step_avg:95.95ms -step:468/1695 train_time:44902ms step_avg:95.95ms -step:469/1695 train_time:44995ms step_avg:95.94ms -step:470/1695 train_time:45090ms step_avg:95.94ms -step:471/1695 train_time:45184ms step_avg:95.93ms -step:472/1695 train_time:45277ms step_avg:95.93ms -step:473/1695 train_time:45371ms step_avg:95.92ms -step:474/1695 train_time:45466ms step_avg:95.92ms -step:475/1695 train_time:45559ms step_avg:95.91ms -step:476/1695 train_time:45653ms step_avg:95.91ms -step:477/1695 train_time:45748ms step_avg:95.91ms -step:478/1695 train_time:45843ms step_avg:95.91ms -step:479/1695 train_time:45936ms step_avg:95.90ms -step:480/1695 train_time:46031ms step_avg:95.90ms -step:481/1695 train_time:46125ms step_avg:95.89ms -step:482/1695 train_time:46218ms step_avg:95.89ms -step:483/1695 train_time:46312ms step_avg:95.88ms -step:484/1695 train_time:46406ms step_avg:95.88ms -step:485/1695 train_time:46499ms step_avg:95.87ms -step:486/1695 train_time:46593ms step_avg:95.87ms -step:487/1695 train_time:46687ms step_avg:95.87ms -step:488/1695 train_time:46780ms step_avg:95.86ms -step:489/1695 train_time:46874ms step_avg:95.86ms -step:490/1695 train_time:46969ms step_avg:95.86ms -step:491/1695 train_time:47064ms step_avg:95.85ms -step:492/1695 train_time:47157ms step_avg:95.85ms -step:493/1695 train_time:47251ms step_avg:95.84ms -step:494/1695 train_time:47345ms step_avg:95.84ms -step:495/1695 train_time:47439ms step_avg:95.84ms -step:496/1695 train_time:47533ms step_avg:95.83ms -step:497/1695 train_time:47628ms step_avg:95.83ms -step:498/1695 train_time:47722ms step_avg:95.83ms -step:499/1695 train_time:47814ms step_avg:95.82ms -step:500/1695 train_time:47908ms step_avg:95.82ms -step:500/1695 val_loss:3.7156 train_time:48001ms step_avg:96.00ms -step:501/1695 train_time:48026ms step_avg:95.86ms -step:502/1695 train_time:48105ms step_avg:95.83ms -step:503/1695 train_time:48205ms step_avg:95.83ms -step:504/1695 train_time:48300ms step_avg:95.83ms -step:505/1695 train_time:48393ms step_avg:95.83ms -step:506/1695 train_time:48486ms step_avg:95.82ms -step:507/1695 train_time:48579ms 
step_avg:95.82ms -step:508/1695 train_time:48673ms step_avg:95.81ms -step:509/1695 train_time:48765ms step_avg:95.81ms -step:510/1695 train_time:48859ms step_avg:95.80ms -step:511/1695 train_time:48952ms step_avg:95.80ms -step:512/1695 train_time:49046ms step_avg:95.79ms -step:513/1695 train_time:49142ms step_avg:95.79ms -step:514/1695 train_time:49239ms step_avg:95.80ms -step:515/1695 train_time:49334ms step_avg:95.79ms -step:516/1695 train_time:49427ms step_avg:95.79ms -step:517/1695 train_time:49520ms step_avg:95.78ms -step:518/1695 train_time:49614ms step_avg:95.78ms -step:519/1695 train_time:50068ms step_avg:96.47ms -step:520/1695 train_time:50139ms step_avg:96.42ms -step:521/1695 train_time:50231ms step_avg:96.41ms -step:522/1695 train_time:50324ms step_avg:96.41ms -step:523/1695 train_time:50417ms step_avg:96.40ms -step:524/1695 train_time:50510ms step_avg:96.39ms -step:525/1695 train_time:50602ms step_avg:96.39ms -step:526/1695 train_time:50695ms step_avg:96.38ms -step:527/1695 train_time:50789ms step_avg:96.37ms -step:528/1695 train_time:50882ms step_avg:96.37ms -step:529/1695 train_time:50977ms step_avg:96.37ms -step:530/1695 train_time:51074ms step_avg:96.37ms -step:531/1695 train_time:51171ms step_avg:96.37ms -step:532/1695 train_time:51264ms step_avg:96.36ms -step:533/1695 train_time:51358ms step_avg:96.36ms -step:534/1695 train_time:51451ms step_avg:96.35ms -step:535/1695 train_time:51543ms step_avg:96.34ms -step:536/1695 train_time:51636ms step_avg:96.34ms -step:537/1695 train_time:51730ms step_avg:96.33ms -step:538/1695 train_time:51823ms step_avg:96.32ms -step:539/1695 train_time:51916ms step_avg:96.32ms -step:540/1695 train_time:52011ms step_avg:96.32ms -step:541/1695 train_time:52105ms step_avg:96.31ms -step:542/1695 train_time:52200ms step_avg:96.31ms -step:543/1695 train_time:52295ms step_avg:96.31ms -step:544/1695 train_time:52389ms step_avg:96.30ms -step:545/1695 train_time:52482ms step_avg:96.30ms -step:546/1695 train_time:52577ms step_avg:96.29ms -step:547/1695 train_time:52671ms step_avg:96.29ms -step:548/1695 train_time:52764ms step_avg:96.28ms -step:549/1695 train_time:52858ms step_avg:96.28ms -step:550/1695 train_time:52951ms step_avg:96.27ms -step:551/1695 train_time:53045ms step_avg:96.27ms -step:552/1695 train_time:53139ms step_avg:96.27ms -step:553/1695 train_time:53233ms step_avg:96.26ms -step:554/1695 train_time:53326ms step_avg:96.26ms -step:555/1695 train_time:53420ms step_avg:96.25ms -step:556/1695 train_time:53514ms step_avg:96.25ms -step:557/1695 train_time:53608ms step_avg:96.24ms -step:558/1695 train_time:53701ms step_avg:96.24ms -step:559/1695 train_time:53795ms step_avg:96.24ms -step:560/1695 train_time:53890ms step_avg:96.23ms -step:561/1695 train_time:53983ms step_avg:96.23ms -step:562/1695 train_time:54077ms step_avg:96.22ms -step:563/1695 train_time:54172ms step_avg:96.22ms -step:564/1695 train_time:54265ms step_avg:96.21ms -step:565/1695 train_time:54359ms step_avg:96.21ms -step:566/1695 train_time:54453ms step_avg:96.21ms -step:567/1695 train_time:54546ms step_avg:96.20ms -step:568/1695 train_time:54641ms step_avg:96.20ms -step:569/1695 train_time:54737ms step_avg:96.20ms -step:570/1695 train_time:54833ms step_avg:96.20ms -step:571/1695 train_time:54929ms step_avg:96.20ms -step:572/1695 train_time:55024ms step_avg:96.20ms -step:573/1695 train_time:55120ms step_avg:96.20ms -step:574/1695 train_time:55216ms step_avg:96.20ms -step:575/1695 train_time:55312ms step_avg:96.19ms -step:576/1695 train_time:55407ms step_avg:96.19ms -step:577/1695 
train_time:55502ms step_avg:96.19ms -step:578/1695 train_time:55599ms step_avg:96.19ms -step:579/1695 train_time:55695ms step_avg:96.19ms -step:580/1695 train_time:55792ms step_avg:96.19ms -step:581/1695 train_time:55887ms step_avg:96.19ms -step:582/1695 train_time:55983ms step_avg:96.19ms -step:583/1695 train_time:56080ms step_avg:96.19ms -step:584/1695 train_time:56177ms step_avg:96.19ms -step:585/1695 train_time:56274ms step_avg:96.19ms -step:586/1695 train_time:56371ms step_avg:96.20ms -step:587/1695 train_time:56467ms step_avg:96.20ms -step:588/1695 train_time:56562ms step_avg:96.19ms -step:589/1695 train_time:56659ms step_avg:96.19ms -step:590/1695 train_time:56755ms step_avg:96.20ms -step:591/1695 train_time:56851ms step_avg:96.20ms -step:592/1695 train_time:56946ms step_avg:96.19ms -step:593/1695 train_time:57041ms step_avg:96.19ms -step:594/1695 train_time:57138ms step_avg:96.19ms -step:595/1695 train_time:57234ms step_avg:96.19ms -step:596/1695 train_time:57332ms step_avg:96.19ms -step:597/1695 train_time:57428ms step_avg:96.19ms -step:598/1695 train_time:57523ms step_avg:96.19ms -step:599/1695 train_time:57619ms step_avg:96.19ms -step:600/1695 train_time:57716ms step_avg:96.19ms -step:601/1695 train_time:57812ms step_avg:96.19ms -step:602/1695 train_time:57906ms step_avg:96.19ms -step:603/1695 train_time:58002ms step_avg:96.19ms -step:604/1695 train_time:58098ms step_avg:96.19ms -step:605/1695 train_time:58194ms step_avg:96.19ms -step:606/1695 train_time:58290ms step_avg:96.19ms -step:607/1695 train_time:58386ms step_avg:96.19ms -step:608/1695 train_time:58482ms step_avg:96.19ms -step:609/1695 train_time:58579ms step_avg:96.19ms -step:610/1695 train_time:58676ms step_avg:96.19ms -step:611/1695 train_time:58773ms step_avg:96.19ms -step:612/1695 train_time:58868ms step_avg:96.19ms -step:613/1695 train_time:58963ms step_avg:96.19ms -step:614/1695 train_time:59059ms step_avg:96.19ms -step:615/1695 train_time:59156ms step_avg:96.19ms -step:616/1695 train_time:59252ms step_avg:96.19ms -step:617/1695 train_time:59348ms step_avg:96.19ms -step:618/1695 train_time:59444ms step_avg:96.19ms -step:619/1695 train_time:59540ms step_avg:96.19ms -step:620/1695 train_time:59636ms step_avg:96.19ms -step:621/1695 train_time:59732ms step_avg:96.19ms -step:622/1695 train_time:59828ms step_avg:96.19ms -step:623/1695 train_time:59923ms step_avg:96.18ms -step:624/1695 train_time:60019ms step_avg:96.18ms -step:625/1695 train_time:60116ms step_avg:96.19ms -step:625/1695 val_loss:3.6216 train_time:60211ms step_avg:96.34ms -step:626/1695 train_time:60235ms step_avg:96.22ms -step:627/1695 train_time:60317ms step_avg:96.20ms -step:628/1695 train_time:60413ms step_avg:96.20ms -step:629/1695 train_time:60508ms step_avg:96.20ms -step:630/1695 train_time:60603ms step_avg:96.19ms -step:631/1695 train_time:60697ms step_avg:96.19ms -step:632/1695 train_time:60792ms step_avg:96.19ms -step:633/1695 train_time:60888ms step_avg:96.19ms -step:634/1695 train_time:60982ms step_avg:96.19ms -step:635/1695 train_time:61078ms step_avg:96.19ms -step:636/1695 train_time:61178ms step_avg:96.19ms -step:637/1695 train_time:61278ms step_avg:96.20ms -step:638/1695 train_time:61377ms step_avg:96.20ms -step:639/1695 train_time:61475ms step_avg:96.20ms -step:640/1695 train_time:61572ms step_avg:96.21ms -step:641/1695 train_time:61668ms step_avg:96.21ms -step:642/1695 train_time:61763ms step_avg:96.20ms -step:643/1695 train_time:61858ms step_avg:96.20ms -step:644/1695 train_time:61954ms step_avg:96.20ms -step:645/1695 train_time:62049ms 
step_avg:96.20ms -step:646/1695 train_time:62144ms step_avg:96.20ms -step:647/1695 train_time:62241ms step_avg:96.20ms -step:648/1695 train_time:62339ms step_avg:96.20ms -step:649/1695 train_time:62436ms step_avg:96.20ms -step:650/1695 train_time:62533ms step_avg:96.21ms -step:651/1695 train_time:62630ms step_avg:96.21ms -step:652/1695 train_time:62725ms step_avg:96.20ms -step:653/1695 train_time:62820ms step_avg:96.20ms -step:654/1695 train_time:62916ms step_avg:96.20ms -step:655/1695 train_time:63011ms step_avg:96.20ms -step:656/1695 train_time:63107ms step_avg:96.20ms -step:657/1695 train_time:63202ms step_avg:96.20ms -step:658/1695 train_time:63298ms step_avg:96.20ms -step:659/1695 train_time:63396ms step_avg:96.20ms -step:660/1695 train_time:63494ms step_avg:96.20ms -step:661/1695 train_time:63591ms step_avg:96.20ms -step:662/1695 train_time:63688ms step_avg:96.21ms -step:663/1695 train_time:63784ms step_avg:96.20ms -step:664/1695 train_time:63879ms step_avg:96.20ms -step:665/1695 train_time:63974ms step_avg:96.20ms -step:666/1695 train_time:64070ms step_avg:96.20ms -step:667/1695 train_time:64165ms step_avg:96.20ms -step:668/1695 train_time:64260ms step_avg:96.20ms -step:669/1695 train_time:64356ms step_avg:96.20ms -step:670/1695 train_time:64453ms step_avg:96.20ms -step:671/1695 train_time:64550ms step_avg:96.20ms -step:672/1695 train_time:64647ms step_avg:96.20ms -step:673/1695 train_time:64742ms step_avg:96.20ms -step:674/1695 train_time:64838ms step_avg:96.20ms -step:675/1695 train_time:64934ms step_avg:96.20ms -step:676/1695 train_time:65030ms step_avg:96.20ms -step:677/1695 train_time:65126ms step_avg:96.20ms -step:678/1695 train_time:65222ms step_avg:96.20ms -step:679/1695 train_time:65318ms step_avg:96.20ms -step:680/1695 train_time:65414ms step_avg:96.20ms -step:681/1695 train_time:65510ms step_avg:96.20ms -step:682/1695 train_time:65606ms step_avg:96.20ms -step:683/1695 train_time:65701ms step_avg:96.20ms -step:684/1695 train_time:65798ms step_avg:96.20ms -step:685/1695 train_time:65894ms step_avg:96.20ms -step:686/1695 train_time:65990ms step_avg:96.20ms -step:687/1695 train_time:66086ms step_avg:96.20ms -step:688/1695 train_time:66181ms step_avg:96.19ms -step:689/1695 train_time:66277ms step_avg:96.19ms -step:690/1695 train_time:66373ms step_avg:96.19ms -step:691/1695 train_time:66817ms step_avg:96.70ms -step:692/1695 train_time:66898ms step_avg:96.67ms -step:693/1695 train_time:66992ms step_avg:96.67ms -step:694/1695 train_time:67087ms step_avg:96.67ms -step:695/1695 train_time:67181ms step_avg:96.66ms -step:696/1695 train_time:67277ms step_avg:96.66ms -step:697/1695 train_time:67371ms step_avg:96.66ms -step:698/1695 train_time:67466ms step_avg:96.66ms -step:699/1695 train_time:67560ms step_avg:96.65ms -step:700/1695 train_time:67656ms step_avg:96.65ms -step:701/1695 train_time:67756ms step_avg:96.66ms -step:702/1695 train_time:67856ms step_avg:96.66ms -step:703/1695 train_time:67953ms step_avg:96.66ms -step:704/1695 train_time:68049ms step_avg:96.66ms -step:705/1695 train_time:68144ms step_avg:96.66ms -step:706/1695 train_time:68238ms step_avg:96.65ms -step:707/1695 train_time:68334ms step_avg:96.65ms -step:708/1695 train_time:68430ms step_avg:96.65ms -step:709/1695 train_time:68524ms step_avg:96.65ms -step:710/1695 train_time:68618ms step_avg:96.65ms -step:711/1695 train_time:68715ms step_avg:96.65ms -step:712/1695 train_time:68813ms step_avg:96.65ms -step:713/1695 train_time:68909ms step_avg:96.65ms -step:714/1695 train_time:69005ms step_avg:96.65ms -step:715/1695 
train_time:69101ms step_avg:96.64ms -step:716/1695 train_time:69196ms step_avg:96.64ms -step:717/1695 train_time:69293ms step_avg:96.64ms -step:718/1695 train_time:69390ms step_avg:96.64ms -step:719/1695 train_time:69485ms step_avg:96.64ms -step:720/1695 train_time:69580ms step_avg:96.64ms -step:721/1695 train_time:69676ms step_avg:96.64ms -step:722/1695 train_time:69774ms step_avg:96.64ms -step:723/1695 train_time:69871ms step_avg:96.64ms -step:724/1695 train_time:69969ms step_avg:96.64ms -step:725/1695 train_time:70065ms step_avg:96.64ms -step:726/1695 train_time:70160ms step_avg:96.64ms -step:727/1695 train_time:70256ms step_avg:96.64ms -step:728/1695 train_time:70353ms step_avg:96.64ms -step:729/1695 train_time:70450ms step_avg:96.64ms -step:730/1695 train_time:70547ms step_avg:96.64ms -step:731/1695 train_time:70642ms step_avg:96.64ms -step:732/1695 train_time:70737ms step_avg:96.64ms -step:733/1695 train_time:70834ms step_avg:96.64ms -step:734/1695 train_time:70933ms step_avg:96.64ms -step:735/1695 train_time:71029ms step_avg:96.64ms -step:736/1695 train_time:71125ms step_avg:96.64ms -step:737/1695 train_time:71220ms step_avg:96.63ms -step:738/1695 train_time:71316ms step_avg:96.63ms -step:739/1695 train_time:71413ms step_avg:96.63ms -step:740/1695 train_time:71508ms step_avg:96.63ms -step:741/1695 train_time:71603ms step_avg:96.63ms -step:742/1695 train_time:71698ms step_avg:96.63ms -step:743/1695 train_time:71794ms step_avg:96.63ms -step:744/1695 train_time:71891ms step_avg:96.63ms -step:745/1695 train_time:71988ms step_avg:96.63ms -step:746/1695 train_time:72083ms step_avg:96.63ms -step:747/1695 train_time:72179ms step_avg:96.63ms -step:748/1695 train_time:72275ms step_avg:96.62ms -step:749/1695 train_time:72371ms step_avg:96.62ms -step:750/1695 train_time:72467ms step_avg:96.62ms -step:750/1695 val_loss:3.5657 train_time:72560ms step_avg:96.75ms -step:751/1695 train_time:72585ms step_avg:96.65ms -step:752/1695 train_time:72667ms step_avg:96.63ms -step:753/1695 train_time:72765ms step_avg:96.63ms -step:754/1695 train_time:72860ms step_avg:96.63ms -step:755/1695 train_time:72956ms step_avg:96.63ms -step:756/1695 train_time:73050ms step_avg:96.63ms -step:757/1695 train_time:73145ms step_avg:96.62ms -step:758/1695 train_time:73239ms step_avg:96.62ms -step:759/1695 train_time:73334ms step_avg:96.62ms -step:760/1695 train_time:73429ms step_avg:96.62ms -step:761/1695 train_time:73526ms step_avg:96.62ms -step:762/1695 train_time:73625ms step_avg:96.62ms -step:763/1695 train_time:73723ms step_avg:96.62ms -step:764/1695 train_time:73819ms step_avg:96.62ms -step:765/1695 train_time:73916ms step_avg:96.62ms -step:766/1695 train_time:74011ms step_avg:96.62ms -step:767/1695 train_time:74106ms step_avg:96.62ms -step:768/1695 train_time:74201ms step_avg:96.62ms -step:769/1695 train_time:74296ms step_avg:96.61ms -step:770/1695 train_time:74391ms step_avg:96.61ms -step:771/1695 train_time:74487ms step_avg:96.61ms -step:772/1695 train_time:74584ms step_avg:96.61ms -step:773/1695 train_time:74682ms step_avg:96.61ms -step:774/1695 train_time:74779ms step_avg:96.61ms -step:775/1695 train_time:74875ms step_avg:96.61ms -step:776/1695 train_time:74971ms step_avg:96.61ms -step:777/1695 train_time:75066ms step_avg:96.61ms -step:778/1695 train_time:75161ms step_avg:96.61ms -step:779/1695 train_time:75257ms step_avg:96.61ms -step:780/1695 train_time:75352ms step_avg:96.61ms -step:781/1695 train_time:75446ms step_avg:96.60ms -step:782/1695 train_time:75542ms step_avg:96.60ms -step:783/1695 train_time:75639ms 
step_avg:96.60ms -step:784/1695 train_time:75736ms step_avg:96.60ms -step:785/1695 train_time:75833ms step_avg:96.60ms -step:786/1695 train_time:75930ms step_avg:96.60ms -step:787/1695 train_time:76024ms step_avg:96.60ms -step:788/1695 train_time:76119ms step_avg:96.60ms -step:789/1695 train_time:76215ms step_avg:96.60ms -step:790/1695 train_time:76309ms step_avg:96.59ms -step:791/1695 train_time:76404ms step_avg:96.59ms -step:792/1695 train_time:76500ms step_avg:96.59ms -step:793/1695 train_time:76597ms step_avg:96.59ms -step:794/1695 train_time:76694ms step_avg:96.59ms -step:795/1695 train_time:76790ms step_avg:96.59ms -step:796/1695 train_time:76886ms step_avg:96.59ms -step:797/1695 train_time:76982ms step_avg:96.59ms -step:798/1695 train_time:77077ms step_avg:96.59ms -step:799/1695 train_time:77173ms step_avg:96.59ms -step:800/1695 train_time:77268ms step_avg:96.58ms -step:801/1695 train_time:77362ms step_avg:96.58ms -step:802/1695 train_time:77458ms step_avg:96.58ms -step:803/1695 train_time:77555ms step_avg:96.58ms -step:804/1695 train_time:77651ms step_avg:96.58ms -step:805/1695 train_time:77747ms step_avg:96.58ms -step:806/1695 train_time:77843ms step_avg:96.58ms -step:807/1695 train_time:77939ms step_avg:96.58ms -step:808/1695 train_time:78036ms step_avg:96.58ms -step:809/1695 train_time:78132ms step_avg:96.58ms -step:810/1695 train_time:78228ms step_avg:96.58ms -step:811/1695 train_time:78322ms step_avg:96.58ms -step:812/1695 train_time:78418ms step_avg:96.57ms -step:813/1695 train_time:78514ms step_avg:96.57ms -step:814/1695 train_time:78611ms step_avg:96.57ms -step:815/1695 train_time:78706ms step_avg:96.57ms -step:816/1695 train_time:78802ms step_avg:96.57ms -step:817/1695 train_time:78899ms step_avg:96.57ms -step:818/1695 train_time:78995ms step_avg:96.57ms -step:819/1695 train_time:79091ms step_avg:96.57ms -step:820/1695 train_time:79186ms step_avg:96.57ms -step:821/1695 train_time:79281ms step_avg:96.57ms -step:822/1695 train_time:79378ms step_avg:96.57ms -step:823/1695 train_time:79474ms step_avg:96.57ms -step:824/1695 train_time:79570ms step_avg:96.57ms -step:825/1695 train_time:79665ms step_avg:96.56ms -step:826/1695 train_time:79761ms step_avg:96.56ms -step:827/1695 train_time:79857ms step_avg:96.56ms -step:828/1695 train_time:79953ms step_avg:96.56ms -step:829/1695 train_time:80049ms step_avg:96.56ms -step:830/1695 train_time:80145ms step_avg:96.56ms -step:831/1695 train_time:80241ms step_avg:96.56ms -step:832/1695 train_time:80337ms step_avg:96.56ms -step:833/1695 train_time:80434ms step_avg:96.56ms -step:834/1695 train_time:80529ms step_avg:96.56ms -step:835/1695 train_time:80624ms step_avg:96.56ms -step:836/1695 train_time:80720ms step_avg:96.55ms -step:837/1695 train_time:80816ms step_avg:96.55ms -step:838/1695 train_time:80912ms step_avg:96.55ms -step:839/1695 train_time:81007ms step_avg:96.55ms -step:840/1695 train_time:81103ms step_avg:96.55ms -step:841/1695 train_time:81199ms step_avg:96.55ms -step:842/1695 train_time:81294ms step_avg:96.55ms -step:843/1695 train_time:81390ms step_avg:96.55ms -step:844/1695 train_time:81485ms step_avg:96.55ms -step:845/1695 train_time:81581ms step_avg:96.55ms -step:846/1695 train_time:81676ms step_avg:96.54ms -step:847/1695 train_time:81772ms step_avg:96.54ms -step:848/1695 train_time:81867ms step_avg:96.54ms -step:849/1695 train_time:81963ms step_avg:96.54ms -step:850/1695 train_time:82059ms step_avg:96.54ms -step:851/1695 train_time:82155ms step_avg:96.54ms -step:852/1695 train_time:82251ms step_avg:96.54ms -step:853/1695 
train_time:82347ms step_avg:96.54ms -step:854/1695 train_time:82443ms step_avg:96.54ms -step:855/1695 train_time:82539ms step_avg:96.54ms -step:856/1695 train_time:82635ms step_avg:96.54ms -step:857/1695 train_time:82730ms step_avg:96.53ms -step:858/1695 train_time:82826ms step_avg:96.53ms -step:859/1695 train_time:82921ms step_avg:96.53ms -step:860/1695 train_time:83017ms step_avg:96.53ms -step:861/1695 train_time:83113ms step_avg:96.53ms -step:862/1695 train_time:83208ms step_avg:96.53ms -step:863/1695 train_time:83631ms step_avg:96.91ms -step:864/1695 train_time:83735ms step_avg:96.92ms -step:865/1695 train_time:83829ms step_avg:96.91ms -step:866/1695 train_time:83923ms step_avg:96.91ms -step:867/1695 train_time:84018ms step_avg:96.91ms -step:868/1695 train_time:84113ms step_avg:96.90ms -step:869/1695 train_time:84207ms step_avg:96.90ms -step:870/1695 train_time:84301ms step_avg:96.90ms -step:871/1695 train_time:84397ms step_avg:96.90ms -step:872/1695 train_time:84492ms step_avg:96.89ms -step:873/1695 train_time:84593ms step_avg:96.90ms -step:874/1695 train_time:84691ms step_avg:96.90ms -step:875/1695 train_time:84788ms step_avg:96.90ms -step:875/1695 val_loss:3.5240 train_time:84881ms step_avg:97.01ms -step:876/1695 train_time:84907ms step_avg:96.93ms -step:877/1695 train_time:84991ms step_avg:96.91ms -step:878/1695 train_time:85089ms step_avg:96.91ms -step:879/1695 train_time:85186ms step_avg:96.91ms -step:880/1695 train_time:85282ms step_avg:96.91ms -step:881/1695 train_time:85377ms step_avg:96.91ms -step:882/1695 train_time:85472ms step_avg:96.91ms -step:883/1695 train_time:85567ms step_avg:96.90ms -step:884/1695 train_time:85663ms step_avg:96.90ms -step:885/1695 train_time:85757ms step_avg:96.90ms -step:886/1695 train_time:85854ms step_avg:96.90ms -step:887/1695 train_time:85952ms step_avg:96.90ms -step:888/1695 train_time:86049ms step_avg:96.90ms -step:889/1695 train_time:86145ms step_avg:96.90ms -step:890/1695 train_time:86242ms step_avg:96.90ms -step:891/1695 train_time:86339ms step_avg:96.90ms -step:892/1695 train_time:86435ms step_avg:96.90ms -step:893/1695 train_time:86530ms step_avg:96.90ms -step:894/1695 train_time:86625ms step_avg:96.90ms -step:895/1695 train_time:86720ms step_avg:96.89ms -step:896/1695 train_time:86816ms step_avg:96.89ms -step:897/1695 train_time:86913ms step_avg:96.89ms -step:898/1695 train_time:87009ms step_avg:96.89ms -step:899/1695 train_time:87105ms step_avg:96.89ms -step:900/1695 train_time:87201ms step_avg:96.89ms -step:901/1695 train_time:87298ms step_avg:96.89ms -step:902/1695 train_time:87394ms step_avg:96.89ms -step:903/1695 train_time:87489ms step_avg:96.89ms -step:904/1695 train_time:87584ms step_avg:96.88ms -step:905/1695 train_time:87679ms step_avg:96.88ms -step:906/1695 train_time:87776ms step_avg:96.88ms -step:907/1695 train_time:87872ms step_avg:96.88ms -step:908/1695 train_time:87968ms step_avg:96.88ms -step:909/1695 train_time:88065ms step_avg:96.88ms -step:910/1695 train_time:88162ms step_avg:96.88ms -step:911/1695 train_time:88259ms step_avg:96.88ms -step:912/1695 train_time:88354ms step_avg:96.88ms -step:913/1695 train_time:88449ms step_avg:96.88ms -step:914/1695 train_time:88544ms step_avg:96.88ms -step:915/1695 train_time:88640ms step_avg:96.87ms -step:916/1695 train_time:88736ms step_avg:96.87ms -step:917/1695 train_time:88831ms step_avg:96.87ms -step:918/1695 train_time:88927ms step_avg:96.87ms -step:919/1695 train_time:89023ms step_avg:96.87ms -step:920/1695 train_time:89119ms step_avg:96.87ms -step:921/1695 train_time:89216ms 
step_avg:96.87ms -step:922/1695 train_time:89312ms step_avg:96.87ms -step:923/1695 train_time:89407ms step_avg:96.87ms -step:924/1695 train_time:89504ms step_avg:96.87ms -step:925/1695 train_time:89600ms step_avg:96.86ms -step:926/1695 train_time:89696ms step_avg:96.86ms -step:927/1695 train_time:89792ms step_avg:96.86ms -step:928/1695 train_time:89887ms step_avg:96.86ms -step:929/1695 train_time:89983ms step_avg:96.86ms -step:930/1695 train_time:90080ms step_avg:96.86ms -step:931/1695 train_time:90177ms step_avg:96.86ms -step:932/1695 train_time:90273ms step_avg:96.86ms -step:933/1695 train_time:90369ms step_avg:96.86ms -step:934/1695 train_time:90465ms step_avg:96.86ms -step:935/1695 train_time:90561ms step_avg:96.86ms -step:936/1695 train_time:90657ms step_avg:96.86ms -step:937/1695 train_time:90753ms step_avg:96.86ms -step:938/1695 train_time:90848ms step_avg:96.85ms -step:939/1695 train_time:90944ms step_avg:96.85ms -step:940/1695 train_time:91040ms step_avg:96.85ms -step:941/1695 train_time:91136ms step_avg:96.85ms -step:942/1695 train_time:91232ms step_avg:96.85ms -step:943/1695 train_time:91329ms step_avg:96.85ms -step:944/1695 train_time:91424ms step_avg:96.85ms -step:945/1695 train_time:91520ms step_avg:96.85ms -step:946/1695 train_time:91616ms step_avg:96.85ms -step:947/1695 train_time:91711ms step_avg:96.84ms -step:948/1695 train_time:91807ms step_avg:96.84ms -step:949/1695 train_time:91904ms step_avg:96.84ms -step:950/1695 train_time:92001ms step_avg:96.84ms -step:951/1695 train_time:92098ms step_avg:96.84ms -step:952/1695 train_time:92194ms step_avg:96.84ms -step:953/1695 train_time:92290ms step_avg:96.84ms -step:954/1695 train_time:92385ms step_avg:96.84ms -step:955/1695 train_time:92481ms step_avg:96.84ms -step:956/1695 train_time:92578ms step_avg:96.84ms -step:957/1695 train_time:92675ms step_avg:96.84ms -step:958/1695 train_time:92770ms step_avg:96.84ms -step:959/1695 train_time:92866ms step_avg:96.84ms -step:960/1695 train_time:92963ms step_avg:96.84ms -step:961/1695 train_time:93060ms step_avg:96.84ms -step:962/1695 train_time:93158ms step_avg:96.84ms -step:963/1695 train_time:93254ms step_avg:96.84ms -step:964/1695 train_time:93349ms step_avg:96.83ms -step:965/1695 train_time:93444ms step_avg:96.83ms -step:966/1695 train_time:93540ms step_avg:96.83ms -step:967/1695 train_time:93636ms step_avg:96.83ms -step:968/1695 train_time:93733ms step_avg:96.83ms -step:969/1695 train_time:93828ms step_avg:96.83ms -step:970/1695 train_time:93924ms step_avg:96.83ms -step:971/1695 train_time:94021ms step_avg:96.83ms -step:972/1695 train_time:94119ms step_avg:96.83ms -step:973/1695 train_time:94215ms step_avg:96.83ms -step:974/1695 train_time:94310ms step_avg:96.83ms -step:975/1695 train_time:94406ms step_avg:96.83ms -step:976/1695 train_time:94501ms step_avg:96.82ms -step:977/1695 train_time:94597ms step_avg:96.82ms -step:978/1695 train_time:94693ms step_avg:96.82ms -step:979/1695 train_time:94788ms step_avg:96.82ms -step:980/1695 train_time:94884ms step_avg:96.82ms -step:981/1695 train_time:94980ms step_avg:96.82ms -step:982/1695 train_time:95077ms step_avg:96.82ms -step:983/1695 train_time:95173ms step_avg:96.82ms -step:984/1695 train_time:95268ms step_avg:96.82ms -step:985/1695 train_time:95364ms step_avg:96.82ms -step:986/1695 train_time:95460ms step_avg:96.82ms -step:987/1695 train_time:95556ms step_avg:96.81ms -step:988/1695 train_time:95652ms step_avg:96.81ms -step:989/1695 train_time:95747ms step_avg:96.81ms -step:990/1695 train_time:95844ms step_avg:96.81ms -step:991/1695 
train_time:95940ms step_avg:96.81ms -step:992/1695 train_time:96037ms step_avg:96.81ms -step:993/1695 train_time:96133ms step_avg:96.81ms -step:994/1695 train_time:96229ms step_avg:96.81ms -step:995/1695 train_time:96325ms step_avg:96.81ms -step:996/1695 train_time:96422ms step_avg:96.81ms -step:997/1695 train_time:96518ms step_avg:96.81ms -step:998/1695 train_time:96613ms step_avg:96.81ms -step:999/1695 train_time:96708ms step_avg:96.80ms -step:1000/1695 train_time:96804ms step_avg:96.80ms -step:1000/1695 val_loss:3.4845 train_time:96898ms step_avg:96.90ms -step:1001/1695 train_time:96924ms step_avg:96.83ms -step:1002/1695 train_time:97001ms step_avg:96.81ms -step:1003/1695 train_time:97099ms step_avg:96.81ms -step:1004/1695 train_time:97195ms step_avg:96.81ms -step:1005/1695 train_time:97291ms step_avg:96.81ms -step:1006/1695 train_time:97386ms step_avg:96.81ms -step:1007/1695 train_time:97481ms step_avg:96.80ms -step:1008/1695 train_time:97576ms step_avg:96.80ms -step:1009/1695 train_time:97672ms step_avg:96.80ms -step:1010/1695 train_time:97766ms step_avg:96.80ms -step:1011/1695 train_time:97863ms step_avg:96.80ms -step:1012/1695 train_time:97961ms step_avg:96.80ms -step:1013/1695 train_time:98057ms step_avg:96.80ms -step:1014/1695 train_time:98153ms step_avg:96.80ms -step:1015/1695 train_time:98249ms step_avg:96.80ms -step:1016/1695 train_time:98346ms step_avg:96.80ms -step:1017/1695 train_time:98440ms step_avg:96.79ms -step:1018/1695 train_time:98535ms step_avg:96.79ms -step:1019/1695 train_time:98632ms step_avg:96.79ms -step:1020/1695 train_time:98728ms step_avg:96.79ms -step:1021/1695 train_time:98823ms step_avg:96.79ms -step:1022/1695 train_time:98919ms step_avg:96.79ms -step:1023/1695 train_time:99016ms step_avg:96.79ms -step:1024/1695 train_time:99114ms step_avg:96.79ms -step:1025/1695 train_time:99211ms step_avg:96.79ms -step:1026/1695 train_time:99307ms step_avg:96.79ms -step:1027/1695 train_time:99402ms step_avg:96.79ms -step:1028/1695 train_time:99497ms step_avg:96.79ms -step:1029/1695 train_time:99593ms step_avg:96.79ms -step:1030/1695 train_time:99690ms step_avg:96.79ms -step:1031/1695 train_time:99786ms step_avg:96.79ms -step:1032/1695 train_time:99882ms step_avg:96.78ms -step:1033/1695 train_time:99978ms step_avg:96.78ms -step:1034/1695 train_time:100075ms step_avg:96.78ms -step:1035/1695 train_time:100172ms step_avg:96.78ms -step:1036/1695 train_time:100506ms step_avg:97.01ms -step:1037/1695 train_time:100695ms step_avg:97.10ms -step:1038/1695 train_time:100788ms step_avg:97.10ms -step:1039/1695 train_time:100882ms step_avg:97.10ms -step:1040/1695 train_time:100977ms step_avg:97.09ms -step:1041/1695 train_time:101073ms step_avg:97.09ms -step:1042/1695 train_time:101168ms step_avg:97.09ms -step:1043/1695 train_time:101262ms step_avg:97.09ms -step:1044/1695 train_time:101357ms step_avg:97.09ms -step:1045/1695 train_time:101452ms step_avg:97.08ms -step:1046/1695 train_time:101548ms step_avg:97.08ms -step:1047/1695 train_time:101649ms step_avg:97.09ms -step:1048/1695 train_time:101747ms step_avg:97.09ms -step:1049/1695 train_time:101842ms step_avg:97.09ms -step:1050/1695 train_time:101937ms step_avg:97.08ms -step:1051/1695 train_time:102034ms step_avg:97.08ms -step:1052/1695 train_time:102130ms step_avg:97.08ms -step:1053/1695 train_time:102225ms step_avg:97.08ms -step:1054/1695 train_time:102320ms step_avg:97.08ms -step:1055/1695 train_time:102415ms step_avg:97.08ms -step:1056/1695 train_time:102511ms step_avg:97.07ms -step:1057/1695 train_time:102610ms step_avg:97.08ms 
-step:1058/1695 train_time:102707ms step_avg:97.08ms -step:1059/1695 train_time:102804ms step_avg:97.08ms -step:1060/1695 train_time:102900ms step_avg:97.08ms -step:1061/1695 train_time:102996ms step_avg:97.07ms -step:1062/1695 train_time:103092ms step_avg:97.07ms -step:1063/1695 train_time:103187ms step_avg:97.07ms -step:1064/1695 train_time:103282ms step_avg:97.07ms -step:1065/1695 train_time:103377ms step_avg:97.07ms -step:1066/1695 train_time:103473ms step_avg:97.07ms -step:1067/1695 train_time:103569ms step_avg:97.07ms -step:1068/1695 train_time:103666ms step_avg:97.07ms -step:1069/1695 train_time:103762ms step_avg:97.06ms -step:1070/1695 train_time:103858ms step_avg:97.06ms -step:1071/1695 train_time:103954ms step_avg:97.06ms -step:1072/1695 train_time:104050ms step_avg:97.06ms -step:1073/1695 train_time:104145ms step_avg:97.06ms -step:1074/1695 train_time:104241ms step_avg:97.06ms -step:1075/1695 train_time:104336ms step_avg:97.06ms -step:1076/1695 train_time:104432ms step_avg:97.06ms -step:1077/1695 train_time:104527ms step_avg:97.05ms -step:1078/1695 train_time:104624ms step_avg:97.05ms -step:1079/1695 train_time:104719ms step_avg:97.05ms -step:1080/1695 train_time:104817ms step_avg:97.05ms -step:1081/1695 train_time:104915ms step_avg:97.05ms -step:1082/1695 train_time:105011ms step_avg:97.05ms -step:1083/1695 train_time:105107ms step_avg:97.05ms -step:1084/1695 train_time:105203ms step_avg:97.05ms -step:1085/1695 train_time:105298ms step_avg:97.05ms -step:1086/1695 train_time:105394ms step_avg:97.05ms -step:1087/1695 train_time:105489ms step_avg:97.05ms -step:1088/1695 train_time:105585ms step_avg:97.05ms -step:1089/1695 train_time:105680ms step_avg:97.04ms -step:1090/1695 train_time:105776ms step_avg:97.04ms -step:1091/1695 train_time:105873ms step_avg:97.04ms -step:1092/1695 train_time:105970ms step_avg:97.04ms -step:1093/1695 train_time:106065ms step_avg:97.04ms -step:1094/1695 train_time:106160ms step_avg:97.04ms -step:1095/1695 train_time:106255ms step_avg:97.04ms -step:1096/1695 train_time:106351ms step_avg:97.04ms -step:1097/1695 train_time:106448ms step_avg:97.04ms -step:1098/1695 train_time:106544ms step_avg:97.03ms -step:1099/1695 train_time:106640ms step_avg:97.03ms -step:1100/1695 train_time:106735ms step_avg:97.03ms -step:1101/1695 train_time:106832ms step_avg:97.03ms -step:1102/1695 train_time:106929ms step_avg:97.03ms -step:1103/1695 train_time:107026ms step_avg:97.03ms -step:1104/1695 train_time:107121ms step_avg:97.03ms -step:1105/1695 train_time:107216ms step_avg:97.03ms -step:1106/1695 train_time:107312ms step_avg:97.03ms -step:1107/1695 train_time:107408ms step_avg:97.03ms -step:1108/1695 train_time:107503ms step_avg:97.02ms -step:1109/1695 train_time:107599ms step_avg:97.02ms -step:1110/1695 train_time:107695ms step_avg:97.02ms -step:1111/1695 train_time:107792ms step_avg:97.02ms -step:1112/1695 train_time:107889ms step_avg:97.02ms -step:1113/1695 train_time:107985ms step_avg:97.02ms -step:1114/1695 train_time:108080ms step_avg:97.02ms -step:1115/1695 train_time:108176ms step_avg:97.02ms -step:1116/1695 train_time:108272ms step_avg:97.02ms -step:1117/1695 train_time:108369ms step_avg:97.02ms -step:1118/1695 train_time:108464ms step_avg:97.02ms -step:1119/1695 train_time:108560ms step_avg:97.02ms -step:1120/1695 train_time:108655ms step_avg:97.01ms -step:1121/1695 train_time:108752ms step_avg:97.01ms -step:1122/1695 train_time:108848ms step_avg:97.01ms -step:1123/1695 train_time:108944ms step_avg:97.01ms -step:1124/1695 train_time:109040ms step_avg:97.01ms 
-step:1125/1695 train_time:109136ms step_avg:97.01ms -step:1125/1695 val_loss:3.4374 train_time:109230ms step_avg:97.09ms -step:1126/1695 train_time:109255ms step_avg:97.03ms -step:1127/1695 train_time:109339ms step_avg:97.02ms -step:1128/1695 train_time:109436ms step_avg:97.02ms -step:1129/1695 train_time:109533ms step_avg:97.02ms -step:1130/1695 train_time:109628ms step_avg:97.02ms -step:1131/1695 train_time:109724ms step_avg:97.01ms -step:1132/1695 train_time:109818ms step_avg:97.01ms -step:1133/1695 train_time:109915ms step_avg:97.01ms -step:1134/1695 train_time:110012ms step_avg:97.01ms -step:1135/1695 train_time:110108ms step_avg:97.01ms -step:1136/1695 train_time:110207ms step_avg:97.01ms -step:1137/1695 train_time:110306ms step_avg:97.02ms -step:1138/1695 train_time:110406ms step_avg:97.02ms -step:1139/1695 train_time:110505ms step_avg:97.02ms -step:1140/1695 train_time:110602ms step_avg:97.02ms -step:1141/1695 train_time:110699ms step_avg:97.02ms -step:1142/1695 train_time:110796ms step_avg:97.02ms -step:1143/1695 train_time:110894ms step_avg:97.02ms -step:1144/1695 train_time:110992ms step_avg:97.02ms -step:1145/1695 train_time:111089ms step_avg:97.02ms -step:1146/1695 train_time:111188ms step_avg:97.02ms -step:1147/1695 train_time:111286ms step_avg:97.02ms -step:1148/1695 train_time:111384ms step_avg:97.02ms -step:1149/1695 train_time:111483ms step_avg:97.03ms -step:1150/1695 train_time:111581ms step_avg:97.03ms -step:1151/1695 train_time:111678ms step_avg:97.03ms -step:1152/1695 train_time:111776ms step_avg:97.03ms -step:1153/1695 train_time:111873ms step_avg:97.03ms -step:1154/1695 train_time:111970ms step_avg:97.03ms -step:1155/1695 train_time:112067ms step_avg:97.03ms -step:1156/1695 train_time:112165ms step_avg:97.03ms -step:1157/1695 train_time:112264ms step_avg:97.03ms -step:1158/1695 train_time:112361ms step_avg:97.03ms -step:1159/1695 train_time:112460ms step_avg:97.03ms -step:1160/1695 train_time:112558ms step_avg:97.03ms -step:1161/1695 train_time:112656ms step_avg:97.03ms -step:1162/1695 train_time:112754ms step_avg:97.03ms -step:1163/1695 train_time:112852ms step_avg:97.04ms -step:1164/1695 train_time:112949ms step_avg:97.04ms -step:1165/1695 train_time:113046ms step_avg:97.04ms -step:1166/1695 train_time:113143ms step_avg:97.04ms -step:1167/1695 train_time:113241ms step_avg:97.04ms -step:1168/1695 train_time:113339ms step_avg:97.04ms -step:1169/1695 train_time:113437ms step_avg:97.04ms -step:1170/1695 train_time:113535ms step_avg:97.04ms -step:1171/1695 train_time:113634ms step_avg:97.04ms -step:1172/1695 train_time:113733ms step_avg:97.04ms -step:1173/1695 train_time:113830ms step_avg:97.04ms -step:1174/1695 train_time:113927ms step_avg:97.04ms -step:1175/1695 train_time:114025ms step_avg:97.04ms -step:1176/1695 train_time:114121ms step_avg:97.04ms -step:1177/1695 train_time:114219ms step_avg:97.04ms -step:1178/1695 train_time:114317ms step_avg:97.04ms -step:1179/1695 train_time:114416ms step_avg:97.05ms -step:1180/1695 train_time:114515ms step_avg:97.05ms -step:1181/1695 train_time:114613ms step_avg:97.05ms -step:1182/1695 train_time:114712ms step_avg:97.05ms -step:1183/1695 train_time:114809ms step_avg:97.05ms -step:1184/1695 train_time:114908ms step_avg:97.05ms -step:1185/1695 train_time:115005ms step_avg:97.05ms -step:1186/1695 train_time:115102ms step_avg:97.05ms -step:1187/1695 train_time:115199ms step_avg:97.05ms -step:1188/1695 train_time:115296ms step_avg:97.05ms -step:1189/1695 train_time:115395ms step_avg:97.05ms -step:1190/1695 train_time:115494ms 
step_avg:97.05ms -step:1191/1695 train_time:115593ms step_avg:97.06ms -step:1192/1695 train_time:115692ms step_avg:97.06ms -step:1193/1695 train_time:115790ms step_avg:97.06ms -step:1194/1695 train_time:115888ms step_avg:97.06ms -step:1195/1695 train_time:115986ms step_avg:97.06ms -step:1196/1695 train_time:116084ms step_avg:97.06ms -step:1197/1695 train_time:116183ms step_avg:97.06ms -step:1198/1695 train_time:116280ms step_avg:97.06ms -step:1199/1695 train_time:116378ms step_avg:97.06ms -step:1200/1695 train_time:116476ms step_avg:97.06ms -step:1201/1695 train_time:116574ms step_avg:97.06ms -step:1202/1695 train_time:116672ms step_avg:97.07ms -step:1203/1695 train_time:116770ms step_avg:97.07ms -step:1204/1695 train_time:116868ms step_avg:97.07ms -step:1205/1695 train_time:116966ms step_avg:97.07ms -step:1206/1695 train_time:117064ms step_avg:97.07ms -step:1207/1695 train_time:117162ms step_avg:97.07ms -step:1208/1695 train_time:117508ms step_avg:97.27ms -step:1209/1695 train_time:117691ms step_avg:97.35ms -step:1210/1695 train_time:117787ms step_avg:97.34ms -step:1211/1695 train_time:117883ms step_avg:97.34ms -step:1212/1695 train_time:117979ms step_avg:97.34ms -step:1213/1695 train_time:118076ms step_avg:97.34ms -step:1214/1695 train_time:118174ms step_avg:97.34ms -step:1215/1695 train_time:118270ms step_avg:97.34ms -step:1216/1695 train_time:118367ms step_avg:97.34ms -step:1217/1695 train_time:118463ms step_avg:97.34ms -step:1218/1695 train_time:118566ms step_avg:97.34ms -step:1219/1695 train_time:118667ms step_avg:97.35ms -step:1220/1695 train_time:118767ms step_avg:97.35ms -step:1221/1695 train_time:118863ms step_avg:97.35ms -step:1222/1695 train_time:118959ms step_avg:97.35ms -step:1223/1695 train_time:119056ms step_avg:97.35ms -step:1224/1695 train_time:119152ms step_avg:97.35ms -step:1225/1695 train_time:119249ms step_avg:97.35ms -step:1226/1695 train_time:119346ms step_avg:97.35ms -step:1227/1695 train_time:119443ms step_avg:97.35ms -step:1228/1695 train_time:119541ms step_avg:97.35ms -step:1229/1695 train_time:119640ms step_avg:97.35ms -step:1230/1695 train_time:119740ms step_avg:97.35ms -step:1231/1695 train_time:119838ms step_avg:97.35ms -step:1232/1695 train_time:119936ms step_avg:97.35ms -step:1233/1695 train_time:120033ms step_avg:97.35ms -step:1234/1695 train_time:120130ms step_avg:97.35ms -step:1235/1695 train_time:120227ms step_avg:97.35ms -step:1236/1695 train_time:120323ms step_avg:97.35ms -step:1237/1695 train_time:120420ms step_avg:97.35ms -step:1238/1695 train_time:120517ms step_avg:97.35ms -step:1239/1695 train_time:120616ms step_avg:97.35ms -step:1240/1695 train_time:120716ms step_avg:97.35ms -step:1241/1695 train_time:120815ms step_avg:97.35ms -step:1242/1695 train_time:120914ms step_avg:97.35ms -step:1243/1695 train_time:121012ms step_avg:97.36ms -step:1244/1695 train_time:121110ms step_avg:97.35ms -step:1245/1695 train_time:121206ms step_avg:97.35ms -step:1246/1695 train_time:121304ms step_avg:97.35ms -step:1247/1695 train_time:121400ms step_avg:97.35ms -step:1248/1695 train_time:121498ms step_avg:97.35ms -step:1249/1695 train_time:121596ms step_avg:97.35ms -step:1250/1695 train_time:121696ms step_avg:97.36ms -step:1250/1695 val_loss:3.3886 train_time:121792ms step_avg:97.43ms -step:1251/1695 train_time:121818ms step_avg:97.38ms -step:1252/1695 train_time:121899ms step_avg:97.36ms -step:1253/1695 train_time:121997ms step_avg:97.36ms -step:1254/1695 train_time:122094ms step_avg:97.36ms -step:1255/1695 train_time:122190ms step_avg:97.36ms -step:1256/1695 
train_time:122287ms step_avg:97.36ms -step:1257/1695 train_time:122383ms step_avg:97.36ms -step:1258/1695 train_time:122480ms step_avg:97.36ms -step:1259/1695 train_time:122576ms step_avg:97.36ms -step:1260/1695 train_time:122673ms step_avg:97.36ms -step:1261/1695 train_time:122774ms step_avg:97.36ms -step:1262/1695 train_time:122874ms step_avg:97.36ms -step:1263/1695 train_time:122972ms step_avg:97.36ms -step:1264/1695 train_time:123070ms step_avg:97.37ms -step:1265/1695 train_time:123167ms step_avg:97.37ms -step:1266/1695 train_time:123265ms step_avg:97.37ms -step:1267/1695 train_time:123361ms step_avg:97.36ms -step:1268/1695 train_time:123458ms step_avg:97.36ms -step:1269/1695 train_time:123554ms step_avg:97.36ms -step:1270/1695 train_time:123650ms step_avg:97.36ms -step:1271/1695 train_time:123749ms step_avg:97.36ms -step:1272/1695 train_time:123849ms step_avg:97.37ms -step:1273/1695 train_time:123947ms step_avg:97.37ms -step:1274/1695 train_time:124045ms step_avg:97.37ms -step:1275/1695 train_time:124143ms step_avg:97.37ms -step:1276/1695 train_time:124242ms step_avg:97.37ms -step:1277/1695 train_time:124340ms step_avg:97.37ms -step:1278/1695 train_time:124437ms step_avg:97.37ms -step:1279/1695 train_time:124534ms step_avg:97.37ms -step:1280/1695 train_time:124632ms step_avg:97.37ms -step:1281/1695 train_time:124730ms step_avg:97.37ms -step:1282/1695 train_time:124828ms step_avg:97.37ms -step:1283/1695 train_time:124926ms step_avg:97.37ms -step:1284/1695 train_time:125024ms step_avg:97.37ms -step:1285/1695 train_time:125122ms step_avg:97.37ms -step:1286/1695 train_time:125220ms step_avg:97.37ms -step:1287/1695 train_time:125318ms step_avg:97.37ms -step:1288/1695 train_time:125415ms step_avg:97.37ms -step:1289/1695 train_time:125512ms step_avg:97.37ms -step:1290/1695 train_time:125609ms step_avg:97.37ms -step:1291/1695 train_time:125707ms step_avg:97.37ms -step:1292/1695 train_time:125806ms step_avg:97.37ms -step:1293/1695 train_time:125904ms step_avg:97.37ms -step:1294/1695 train_time:126003ms step_avg:97.37ms -step:1295/1695 train_time:126102ms step_avg:97.38ms -step:1296/1695 train_time:126200ms step_avg:97.38ms -step:1297/1695 train_time:126297ms step_avg:97.38ms -step:1298/1695 train_time:126395ms step_avg:97.38ms -step:1299/1695 train_time:126492ms step_avg:97.38ms -step:1300/1695 train_time:126589ms step_avg:97.38ms -step:1301/1695 train_time:126686ms step_avg:97.38ms -step:1302/1695 train_time:126785ms step_avg:97.38ms -step:1303/1695 train_time:126883ms step_avg:97.38ms -step:1304/1695 train_time:126981ms step_avg:97.38ms -step:1305/1695 train_time:127079ms step_avg:97.38ms -step:1306/1695 train_time:127178ms step_avg:97.38ms -step:1307/1695 train_time:127276ms step_avg:97.38ms -step:1308/1695 train_time:127373ms step_avg:97.38ms -step:1309/1695 train_time:127470ms step_avg:97.38ms -step:1310/1695 train_time:127568ms step_avg:97.38ms -step:1311/1695 train_time:127665ms step_avg:97.38ms -step:1312/1695 train_time:127764ms step_avg:97.38ms -step:1313/1695 train_time:127862ms step_avg:97.38ms -step:1314/1695 train_time:127961ms step_avg:97.38ms -step:1315/1695 train_time:128059ms step_avg:97.38ms -step:1316/1695 train_time:128158ms step_avg:97.38ms -step:1317/1695 train_time:128255ms step_avg:97.38ms -step:1318/1695 train_time:128352ms step_avg:97.38ms -step:1319/1695 train_time:128449ms step_avg:97.38ms -step:1320/1695 train_time:128546ms step_avg:97.38ms -step:1321/1695 train_time:128644ms step_avg:97.38ms -step:1322/1695 train_time:128743ms step_avg:97.38ms -step:1323/1695 
train_time:128841ms step_avg:97.39ms -step:1324/1695 train_time:128939ms step_avg:97.39ms -step:1325/1695 train_time:129038ms step_avg:97.39ms -step:1326/1695 train_time:129136ms step_avg:97.39ms -step:1327/1695 train_time:129233ms step_avg:97.39ms -step:1328/1695 train_time:129332ms step_avg:97.39ms -step:1329/1695 train_time:129429ms step_avg:97.39ms -step:1330/1695 train_time:129525ms step_avg:97.39ms -step:1331/1695 train_time:129623ms step_avg:97.39ms -step:1332/1695 train_time:129721ms step_avg:97.39ms -step:1333/1695 train_time:129820ms step_avg:97.39ms -step:1334/1695 train_time:129917ms step_avg:97.39ms -step:1335/1695 train_time:130015ms step_avg:97.39ms -step:1336/1695 train_time:130112ms step_avg:97.39ms -step:1337/1695 train_time:130209ms step_avg:97.39ms -step:1338/1695 train_time:130307ms step_avg:97.39ms -step:1339/1695 train_time:130405ms step_avg:97.39ms -step:1340/1695 train_time:130503ms step_avg:97.39ms -step:1341/1695 train_time:130601ms step_avg:97.39ms -step:1342/1695 train_time:130699ms step_avg:97.39ms -step:1343/1695 train_time:130797ms step_avg:97.39ms -step:1344/1695 train_time:130894ms step_avg:97.39ms -step:1345/1695 train_time:130991ms step_avg:97.39ms -step:1346/1695 train_time:131088ms step_avg:97.39ms -step:1347/1695 train_time:131185ms step_avg:97.39ms -step:1348/1695 train_time:131283ms step_avg:97.39ms -step:1349/1695 train_time:131381ms step_avg:97.39ms -step:1350/1695 train_time:131481ms step_avg:97.39ms -step:1351/1695 train_time:131579ms step_avg:97.39ms -step:1352/1695 train_time:131676ms step_avg:97.39ms -step:1353/1695 train_time:131774ms step_avg:97.39ms -step:1354/1695 train_time:131871ms step_avg:97.39ms -step:1355/1695 train_time:131969ms step_avg:97.39ms -step:1356/1695 train_time:132066ms step_avg:97.39ms -step:1357/1695 train_time:132164ms step_avg:97.39ms -step:1358/1695 train_time:132262ms step_avg:97.39ms -step:1359/1695 train_time:132360ms step_avg:97.40ms -step:1360/1695 train_time:132459ms step_avg:97.40ms -step:1361/1695 train_time:132558ms step_avg:97.40ms -step:1362/1695 train_time:132655ms step_avg:97.40ms -step:1363/1695 train_time:132753ms step_avg:97.40ms -step:1364/1695 train_time:132850ms step_avg:97.40ms -step:1365/1695 train_time:132948ms step_avg:97.40ms -step:1366/1695 train_time:133046ms step_avg:97.40ms -step:1367/1695 train_time:133144ms step_avg:97.40ms -step:1368/1695 train_time:133242ms step_avg:97.40ms -step:1369/1695 train_time:133340ms step_avg:97.40ms -step:1370/1695 train_time:133439ms step_avg:97.40ms -step:1371/1695 train_time:133537ms step_avg:97.40ms -step:1372/1695 train_time:133634ms step_avg:97.40ms -step:1373/1695 train_time:133732ms step_avg:97.40ms -step:1374/1695 train_time:133829ms step_avg:97.40ms -step:1375/1695 train_time:133926ms step_avg:97.40ms -step:1375/1695 val_loss:3.3508 train_time:134022ms step_avg:97.47ms -step:1376/1695 train_time:134049ms step_avg:97.42ms -step:1377/1695 train_time:134131ms step_avg:97.41ms -step:1378/1695 train_time:134229ms step_avg:97.41ms -step:1379/1695 train_time:134327ms step_avg:97.41ms -step:1380/1695 train_time:134424ms step_avg:97.41ms -step:1381/1695 train_time:134877ms step_avg:97.67ms -step:1382/1695 train_time:134952ms step_avg:97.65ms -step:1383/1695 train_time:135047ms step_avg:97.65ms -step:1384/1695 train_time:135144ms step_avg:97.65ms -step:1385/1695 train_time:135241ms step_avg:97.65ms -step:1386/1695 train_time:135338ms step_avg:97.65ms -step:1387/1695 train_time:135434ms step_avg:97.65ms -step:1388/1695 train_time:135530ms step_avg:97.64ms 
-step:1389/1695 train_time:135626ms step_avg:97.64ms -step:1390/1695 train_time:135723ms step_avg:97.64ms -step:1391/1695 train_time:135831ms step_avg:97.65ms -step:1392/1695 train_time:135931ms step_avg:97.65ms -step:1393/1695 train_time:136030ms step_avg:97.65ms -step:1394/1695 train_time:136128ms step_avg:97.65ms -step:1395/1695 train_time:136225ms step_avg:97.65ms -step:1396/1695 train_time:136322ms step_avg:97.65ms -step:1397/1695 train_time:136419ms step_avg:97.65ms -step:1398/1695 train_time:136516ms step_avg:97.65ms -step:1399/1695 train_time:136613ms step_avg:97.65ms -step:1400/1695 train_time:136711ms step_avg:97.65ms -step:1401/1695 train_time:136810ms step_avg:97.65ms -step:1402/1695 train_time:136910ms step_avg:97.65ms -step:1403/1695 train_time:137008ms step_avg:97.65ms -step:1404/1695 train_time:137106ms step_avg:97.65ms -step:1405/1695 train_time:137204ms step_avg:97.65ms -step:1406/1695 train_time:137301ms step_avg:97.65ms -step:1407/1695 train_time:137398ms step_avg:97.65ms -step:1408/1695 train_time:137494ms step_avg:97.65ms -step:1409/1695 train_time:137591ms step_avg:97.65ms -step:1410/1695 train_time:137688ms step_avg:97.65ms -step:1411/1695 train_time:137787ms step_avg:97.65ms -step:1412/1695 train_time:137887ms step_avg:97.65ms -step:1413/1695 train_time:137987ms step_avg:97.66ms -step:1414/1695 train_time:138085ms step_avg:97.66ms -step:1415/1695 train_time:138183ms step_avg:97.66ms -step:1416/1695 train_time:138281ms step_avg:97.66ms -step:1417/1695 train_time:138378ms step_avg:97.66ms -step:1418/1695 train_time:138476ms step_avg:97.66ms -step:1419/1695 train_time:138574ms step_avg:97.66ms -step:1420/1695 train_time:138669ms step_avg:97.65ms -step:1421/1695 train_time:138766ms step_avg:97.65ms -step:1422/1695 train_time:138866ms step_avg:97.66ms -step:1423/1695 train_time:138965ms step_avg:97.66ms -step:1424/1695 train_time:139065ms step_avg:97.66ms -step:1425/1695 train_time:139163ms step_avg:97.66ms -step:1426/1695 train_time:139260ms step_avg:97.66ms -step:1427/1695 train_time:139359ms step_avg:97.66ms -step:1428/1695 train_time:139456ms step_avg:97.66ms -step:1429/1695 train_time:139553ms step_avg:97.66ms -step:1430/1695 train_time:139650ms step_avg:97.66ms -step:1431/1695 train_time:139748ms step_avg:97.66ms -step:1432/1695 train_time:139845ms step_avg:97.66ms -step:1433/1695 train_time:139945ms step_avg:97.66ms -step:1434/1695 train_time:140045ms step_avg:97.66ms -step:1435/1695 train_time:140144ms step_avg:97.66ms -step:1436/1695 train_time:140243ms step_avg:97.66ms -step:1437/1695 train_time:140341ms step_avg:97.66ms -step:1438/1695 train_time:140440ms step_avg:97.66ms -step:1439/1695 train_time:140539ms step_avg:97.66ms -step:1440/1695 train_time:140637ms step_avg:97.66ms -step:1441/1695 train_time:140733ms step_avg:97.66ms -step:1442/1695 train_time:140830ms step_avg:97.66ms -step:1443/1695 train_time:140927ms step_avg:97.66ms -step:1444/1695 train_time:141024ms step_avg:97.66ms -step:1445/1695 train_time:141124ms step_avg:97.66ms -step:1446/1695 train_time:141222ms step_avg:97.66ms -step:1447/1695 train_time:141321ms step_avg:97.66ms -step:1448/1695 train_time:141421ms step_avg:97.67ms -step:1449/1695 train_time:141519ms step_avg:97.67ms -step:1450/1695 train_time:141618ms step_avg:97.67ms -step:1451/1695 train_time:141716ms step_avg:97.67ms -step:1452/1695 train_time:141813ms step_avg:97.67ms -step:1453/1695 train_time:141909ms step_avg:97.67ms -step:1454/1695 train_time:142006ms step_avg:97.67ms -step:1455/1695 train_time:142104ms step_avg:97.67ms 
-step:1456/1695 train_time:142202ms step_avg:97.67ms -step:1457/1695 train_time:142302ms step_avg:97.67ms -step:1458/1695 train_time:142401ms step_avg:97.67ms -step:1459/1695 train_time:142500ms step_avg:97.67ms -step:1460/1695 train_time:142599ms step_avg:97.67ms -step:1461/1695 train_time:142697ms step_avg:97.67ms -step:1462/1695 train_time:142795ms step_avg:97.67ms -step:1463/1695 train_time:142892ms step_avg:97.67ms -step:1464/1695 train_time:142989ms step_avg:97.67ms -step:1465/1695 train_time:143087ms step_avg:97.67ms -step:1466/1695 train_time:143185ms step_avg:97.67ms -step:1467/1695 train_time:143283ms step_avg:97.67ms -step:1468/1695 train_time:143382ms step_avg:97.67ms -step:1469/1695 train_time:143481ms step_avg:97.67ms -step:1470/1695 train_time:143579ms step_avg:97.67ms -step:1471/1695 train_time:143678ms step_avg:97.67ms -step:1472/1695 train_time:143775ms step_avg:97.67ms -step:1473/1695 train_time:143872ms step_avg:97.67ms -step:1474/1695 train_time:143968ms step_avg:97.67ms -step:1475/1695 train_time:144065ms step_avg:97.67ms -step:1476/1695 train_time:144163ms step_avg:97.67ms -step:1477/1695 train_time:144261ms step_avg:97.67ms -step:1478/1695 train_time:144359ms step_avg:97.67ms -step:1479/1695 train_time:144458ms step_avg:97.67ms -step:1480/1695 train_time:144557ms step_avg:97.67ms -step:1481/1695 train_time:144654ms step_avg:97.67ms -step:1482/1695 train_time:144751ms step_avg:97.67ms -step:1483/1695 train_time:144849ms step_avg:97.67ms -step:1484/1695 train_time:144946ms step_avg:97.67ms -step:1485/1695 train_time:145044ms step_avg:97.67ms -step:1486/1695 train_time:145141ms step_avg:97.67ms -step:1487/1695 train_time:145238ms step_avg:97.67ms -step:1488/1695 train_time:145335ms step_avg:97.67ms -step:1489/1695 train_time:145433ms step_avg:97.67ms -step:1490/1695 train_time:145532ms step_avg:97.67ms -step:1491/1695 train_time:145630ms step_avg:97.67ms -step:1492/1695 train_time:145728ms step_avg:97.67ms -step:1493/1695 train_time:145826ms step_avg:97.67ms -step:1494/1695 train_time:145924ms step_avg:97.67ms -step:1495/1695 train_time:146023ms step_avg:97.67ms -step:1496/1695 train_time:146121ms step_avg:97.67ms -step:1497/1695 train_time:146218ms step_avg:97.67ms -step:1498/1695 train_time:146316ms step_avg:97.67ms -step:1499/1695 train_time:146414ms step_avg:97.67ms -step:1500/1695 train_time:146512ms step_avg:97.67ms -step:1500/1695 val_loss:3.3173 train_time:146608ms step_avg:97.74ms -step:1501/1695 train_time:146635ms step_avg:97.69ms -step:1502/1695 train_time:146715ms step_avg:97.68ms -step:1503/1695 train_time:146815ms step_avg:97.68ms -step:1504/1695 train_time:146912ms step_avg:97.68ms -step:1505/1695 train_time:147009ms step_avg:97.68ms -step:1506/1695 train_time:147105ms step_avg:97.68ms -step:1507/1695 train_time:147202ms step_avg:97.68ms -step:1508/1695 train_time:147299ms step_avg:97.68ms -step:1509/1695 train_time:147395ms step_avg:97.68ms -step:1510/1695 train_time:147491ms step_avg:97.68ms -step:1511/1695 train_time:147591ms step_avg:97.68ms -step:1512/1695 train_time:147692ms step_avg:97.68ms -step:1513/1695 train_time:147792ms step_avg:97.68ms -step:1514/1695 train_time:147890ms step_avg:97.68ms -step:1515/1695 train_time:147987ms step_avg:97.68ms -step:1516/1695 train_time:148084ms step_avg:97.68ms -step:1517/1695 train_time:148181ms step_avg:97.68ms -step:1518/1695 train_time:148278ms step_avg:97.68ms -step:1519/1695 train_time:148375ms step_avg:97.68ms -step:1520/1695 train_time:148472ms step_avg:97.68ms -step:1521/1695 train_time:148570ms 
step_avg:97.68ms -step:1522/1695 train_time:148669ms step_avg:97.68ms -step:1523/1695 train_time:148769ms step_avg:97.68ms -step:1524/1695 train_time:148867ms step_avg:97.68ms -step:1525/1695 train_time:148965ms step_avg:97.68ms -step:1526/1695 train_time:149062ms step_avg:97.68ms -step:1527/1695 train_time:149159ms step_avg:97.68ms -step:1528/1695 train_time:149256ms step_avg:97.68ms -step:1529/1695 train_time:149352ms step_avg:97.68ms -step:1530/1695 train_time:149448ms step_avg:97.68ms -step:1531/1695 train_time:149546ms step_avg:97.68ms -step:1532/1695 train_time:149645ms step_avg:97.68ms -step:1533/1695 train_time:149745ms step_avg:97.68ms -step:1534/1695 train_time:149844ms step_avg:97.68ms -step:1535/1695 train_time:149943ms step_avg:97.68ms -step:1536/1695 train_time:150041ms step_avg:97.68ms -step:1537/1695 train_time:150138ms step_avg:97.68ms -step:1538/1695 train_time:150235ms step_avg:97.68ms -step:1539/1695 train_time:150333ms step_avg:97.68ms -step:1540/1695 train_time:150429ms step_avg:97.68ms -step:1541/1695 train_time:150527ms step_avg:97.68ms -step:1542/1695 train_time:150625ms step_avg:97.68ms -step:1543/1695 train_time:150724ms step_avg:97.68ms -step:1544/1695 train_time:150823ms step_avg:97.68ms -step:1545/1695 train_time:150922ms step_avg:97.68ms -step:1546/1695 train_time:151020ms step_avg:97.68ms -step:1547/1695 train_time:151118ms step_avg:97.68ms -step:1548/1695 train_time:151214ms step_avg:97.68ms -step:1549/1695 train_time:151311ms step_avg:97.68ms -step:1550/1695 train_time:151407ms step_avg:97.68ms -step:1551/1695 train_time:151505ms step_avg:97.68ms -step:1552/1695 train_time:151855ms step_avg:97.84ms -step:1553/1695 train_time:152033ms step_avg:97.90ms -step:1554/1695 train_time:152130ms step_avg:97.90ms -step:1555/1695 train_time:152226ms step_avg:97.89ms -step:1556/1695 train_time:152323ms step_avg:97.89ms -step:1557/1695 train_time:152420ms step_avg:97.89ms -step:1558/1695 train_time:152516ms step_avg:97.89ms -step:1559/1695 train_time:152612ms step_avg:97.89ms -step:1560/1695 train_time:152708ms step_avg:97.89ms -step:1561/1695 train_time:152804ms step_avg:97.89ms -step:1562/1695 train_time:152908ms step_avg:97.89ms -step:1563/1695 train_time:153011ms step_avg:97.90ms -step:1564/1695 train_time:153110ms step_avg:97.90ms -step:1565/1695 train_time:153207ms step_avg:97.90ms -step:1566/1695 train_time:153305ms step_avg:97.90ms -step:1567/1695 train_time:153401ms step_avg:97.89ms -step:1568/1695 train_time:153499ms step_avg:97.89ms -step:1569/1695 train_time:153596ms step_avg:97.89ms -step:1570/1695 train_time:153694ms step_avg:97.89ms -step:1571/1695 train_time:153791ms step_avg:97.89ms -step:1572/1695 train_time:153889ms step_avg:97.89ms -step:1573/1695 train_time:153989ms step_avg:97.89ms -step:1574/1695 train_time:154087ms step_avg:97.90ms -step:1575/1695 train_time:154185ms step_avg:97.90ms -step:1576/1695 train_time:154283ms step_avg:97.90ms -step:1577/1695 train_time:154380ms step_avg:97.89ms -step:1578/1695 train_time:154477ms step_avg:97.89ms -step:1579/1695 train_time:154574ms step_avg:97.89ms -step:1580/1695 train_time:154672ms step_avg:97.89ms -step:1581/1695 train_time:154769ms step_avg:97.89ms -step:1582/1695 train_time:154867ms step_avg:97.89ms -step:1583/1695 train_time:154967ms step_avg:97.89ms -step:1584/1695 train_time:155066ms step_avg:97.90ms -step:1585/1695 train_time:155164ms step_avg:97.90ms -step:1586/1695 train_time:155262ms step_avg:97.90ms -step:1587/1695 train_time:155360ms step_avg:97.90ms -step:1588/1695 train_time:155457ms 
step_avg:97.89ms -step:1589/1695 train_time:155554ms step_avg:97.89ms -step:1590/1695 train_time:155651ms step_avg:97.89ms -step:1591/1695 train_time:155749ms step_avg:97.89ms -step:1592/1695 train_time:155846ms step_avg:97.89ms -step:1593/1695 train_time:155945ms step_avg:97.89ms -step:1594/1695 train_time:156044ms step_avg:97.89ms -step:1595/1695 train_time:156143ms step_avg:97.90ms -step:1596/1695 train_time:156240ms step_avg:97.89ms -step:1597/1695 train_time:156338ms step_avg:97.89ms -step:1598/1695 train_time:156435ms step_avg:97.89ms -step:1599/1695 train_time:156532ms step_avg:97.89ms -step:1600/1695 train_time:156629ms step_avg:97.89ms -step:1601/1695 train_time:156727ms step_avg:97.89ms -step:1602/1695 train_time:156825ms step_avg:97.89ms -step:1603/1695 train_time:156923ms step_avg:97.89ms -step:1604/1695 train_time:157022ms step_avg:97.89ms -step:1605/1695 train_time:157121ms step_avg:97.89ms -step:1606/1695 train_time:157220ms step_avg:97.90ms -step:1607/1695 train_time:157318ms step_avg:97.90ms -step:1608/1695 train_time:157416ms step_avg:97.90ms -step:1609/1695 train_time:157515ms step_avg:97.90ms -step:1610/1695 train_time:157611ms step_avg:97.90ms -step:1611/1695 train_time:157709ms step_avg:97.90ms -step:1612/1695 train_time:157806ms step_avg:97.89ms -step:1613/1695 train_time:157904ms step_avg:97.89ms -step:1614/1695 train_time:158003ms step_avg:97.90ms -step:1615/1695 train_time:158102ms step_avg:97.90ms -step:1616/1695 train_time:158200ms step_avg:97.90ms -step:1617/1695 train_time:158299ms step_avg:97.90ms -step:1618/1695 train_time:158398ms step_avg:97.90ms -step:1619/1695 train_time:158495ms step_avg:97.90ms -step:1620/1695 train_time:158593ms step_avg:97.90ms -step:1621/1695 train_time:158691ms step_avg:97.90ms -step:1622/1695 train_time:158787ms step_avg:97.90ms -step:1623/1695 train_time:158885ms step_avg:97.90ms -step:1624/1695 train_time:158983ms step_avg:97.90ms -step:1625/1695 train_time:159082ms step_avg:97.90ms -step:1625/1695 val_loss:3.2899 train_time:159178ms step_avg:97.96ms -step:1626/1695 train_time:159206ms step_avg:97.91ms -step:1627/1695 train_time:159288ms step_avg:97.90ms -step:1628/1695 train_time:159387ms step_avg:97.90ms -step:1629/1695 train_time:159485ms step_avg:97.90ms -step:1630/1695 train_time:159582ms step_avg:97.90ms -step:1631/1695 train_time:159679ms step_avg:97.90ms -step:1632/1695 train_time:159776ms step_avg:97.90ms -step:1633/1695 train_time:159872ms step_avg:97.90ms -step:1634/1695 train_time:159968ms step_avg:97.90ms -step:1635/1695 train_time:160065ms step_avg:97.90ms -step:1636/1695 train_time:160166ms step_avg:97.90ms -step:1637/1695 train_time:160267ms step_avg:97.90ms -step:1638/1695 train_time:160367ms step_avg:97.90ms -step:1639/1695 train_time:160466ms step_avg:97.90ms -step:1640/1695 train_time:160563ms step_avg:97.90ms -step:1641/1695 train_time:160661ms step_avg:97.90ms -step:1642/1695 train_time:160759ms step_avg:97.90ms -step:1643/1695 train_time:160856ms step_avg:97.90ms -step:1644/1695 train_time:160953ms step_avg:97.90ms -step:1645/1695 train_time:161050ms step_avg:97.90ms -step:1646/1695 train_time:161148ms step_avg:97.90ms -step:1647/1695 train_time:161247ms step_avg:97.90ms -step:1648/1695 train_time:161346ms step_avg:97.90ms -step:1649/1695 train_time:161445ms step_avg:97.90ms -step:1650/1695 train_time:161544ms step_avg:97.91ms -step:1651/1695 train_time:161641ms step_avg:97.90ms -step:1652/1695 train_time:161739ms step_avg:97.91ms -step:1653/1695 train_time:161838ms step_avg:97.91ms -step:1654/1695 
train_time:161935ms step_avg:97.91ms -step:1655/1695 train_time:162032ms step_avg:97.90ms -step:1656/1695 train_time:162128ms step_avg:97.90ms -step:1657/1695 train_time:162226ms step_avg:97.90ms -step:1658/1695 train_time:162326ms step_avg:97.90ms -step:1659/1695 train_time:162424ms step_avg:97.90ms -step:1660/1695 train_time:162522ms step_avg:97.90ms -step:1661/1695 train_time:162620ms step_avg:97.90ms -step:1662/1695 train_time:162717ms step_avg:97.90ms -step:1663/1695 train_time:162815ms step_avg:97.90ms -step:1664/1695 train_time:162913ms step_avg:97.90ms -step:1665/1695 train_time:163010ms step_avg:97.90ms -step:1666/1695 train_time:163107ms step_avg:97.90ms -step:1667/1695 train_time:163205ms step_avg:97.90ms -step:1668/1695 train_time:163305ms step_avg:97.90ms -step:1669/1695 train_time:163404ms step_avg:97.91ms -step:1670/1695 train_time:163503ms step_avg:97.91ms -step:1671/1695 train_time:163601ms step_avg:97.91ms -step:1672/1695 train_time:163699ms step_avg:97.91ms -step:1673/1695 train_time:163797ms step_avg:97.91ms -step:1674/1695 train_time:163896ms step_avg:97.91ms -step:1675/1695 train_time:163994ms step_avg:97.91ms -step:1676/1695 train_time:164090ms step_avg:97.91ms -step:1677/1695 train_time:164187ms step_avg:97.91ms -step:1678/1695 train_time:164285ms step_avg:97.91ms -step:1679/1695 train_time:164384ms step_avg:97.91ms -step:1680/1695 train_time:164482ms step_avg:97.91ms -step:1681/1695 train_time:164581ms step_avg:97.91ms -step:1682/1695 train_time:164679ms step_avg:97.91ms -step:1683/1695 train_time:164777ms step_avg:97.91ms -step:1684/1695 train_time:164875ms step_avg:97.91ms -step:1685/1695 train_time:164972ms step_avg:97.91ms -step:1686/1695 train_time:165068ms step_avg:97.91ms -step:1687/1695 train_time:165167ms step_avg:97.91ms -step:1688/1695 train_time:165265ms step_avg:97.91ms -step:1689/1695 train_time:165363ms step_avg:97.91ms -step:1690/1695 train_time:165460ms step_avg:97.91ms -step:1691/1695 train_time:165558ms step_avg:97.91ms -step:1692/1695 train_time:165656ms step_avg:97.91ms -step:1693/1695 train_time:165754ms step_avg:97.91ms -step:1694/1695 train_time:165851ms step_avg:97.91ms -step:1695/1695 train_time:165949ms step_avg:97.90ms -step:1695/1695 val_loss:3.2780 train_time:166044ms step_avg:97.96ms -peak memory allocated: 34000 MiB reserved: 49636 MiB diff --git a/records/082725_FA3/be1069a9-64f4-4316-bd26-4a7f5b697509.txt b/records/082725_FA3/be1069a9-64f4-4316-bd26-4a7f5b697509.txt deleted file mode 100644 index 789ccc0d1..000000000 --- a/records/082725_FA3/be1069a9-64f4-4316-bd26-4a7f5b697509.txt +++ /dev/null @@ -1,2808 +0,0 @@ -import os -import sys -with open(sys.argv[0]) as f: - code = f.read() # read the code of this file ASAP, for logging -import uuid -import time -import copy -import glob -from dataclasses import dataclass -from functools import lru_cache -from pathlib import Path - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import torch -torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems -from torch import Tensor, nn -import torch.nn.functional as F -import torch.distributed as dist -#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min -import numpy as np -import triton -import triton.language as tl -from flash_attn_interface import flash_attn_func -import torch._dynamo as dynamo -dynamo.config.recompile_limit = 64 - -# 
----------------------------------------------------------------------------- -# Custom operators: FP8 matmul by @YouJiacheng - -@torch.library.custom_op("nanogpt::mm", mutates_args=()) -def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: - @torch.compile - def impl(x: Tensor, w: Tensor): - assert x.is_contiguous() and w.is_contiguous() - x_f8 = x.div(x_s).to(torch.float8_e4m3fn) - w_f8 = w.div(w_s).to(torch.float8_e4m3fn) - out = torch._scaled_mm( - x_f8, - w_f8.T, - out_dtype=torch.bfloat16, - scale_a=x.new_tensor(x_s, dtype=torch.float32), - scale_b=x.new_tensor(w_s, dtype=torch.float32), - use_fast_accum=True, - ) - return out, x_f8, w_f8 - - return impl(x, w) - -@mm_op.register_fake -def _(x: Tensor, w: Tensor, *_): - assert x.ndim == w.ndim == 2 - assert x.shape[1] == w.shape[1] - assert x.device == w.device - assert x.is_contiguous() and w.is_contiguous() - return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) - -@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) -def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: - @torch.compile - def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): - assert grad.is_contiguous() - x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) - w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) - grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) - grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) - grad_x = torch._scaled_mm( - grad_f8, - w_f8.T.contiguous().T, - out_dtype=torch.bfloat16, - scale_a=grad_inv_s, - scale_b=w_inv_s, - use_fast_accum=False, - ) - # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) - grad_w = torch._scaled_mm( - x_f8.T.contiguous(), - grad_f8.T.contiguous().T, - out_dtype=torch.float32, - scale_a=x_inv_s, - scale_b=grad_inv_s, - use_fast_accum=False, - ).T - return grad_x, grad_w - - return impl(g, x_f8, w_f8) - -@mm_backward_op.register_fake -def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): - return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) - -def backward(ctx, grad_out: Tensor, *_): - x_f8, w_f8 = ctx.saved_tensors - x_s, w_s, grad_s = ctx.scales - grad_x, grad_w = torch.ops.nanogpt.mm_backward( - grad_out, x_f8, w_f8, x_s, w_s, grad_s - ) - return grad_x, grad_w, None, None, None - -def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): - *_, x_s, w_s, grad_s = inputs - _, x_f8, w_f8 = output - ctx.save_for_backward(x_f8, w_f8) - ctx.scales = x_s, w_s, grad_s - ctx.set_materialize_grads(False) - -mm_op.register_autograd(backward, setup_context=setup_context) - -# ----------------------------------------------------------------------------- -# Triton kernel for symmetric matrix multiplication by @byronxu99 - -def _get_autotune_configs(): - return [ - triton.Config( - { - "BLOCK_SIZE_M": bm, - "BLOCK_SIZE_N": bn, - "BLOCK_SIZE_K": bk, - "GROUP_SIZE_M": 8, - "LOWER_UPPER": 1, - }, - num_stages=stages, - num_warps=warps, - ) - for bm in [64, 128] - for bn in [64, 128, 256] - for bk in [64, 128] - for stages, warps in [(3, 4), (3, 8), (4, 4)] - if bm // bn <= 2 and bn // bm <= 2 - ] - -@triton.jit -def _pid_to_block( - pid, - M, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) - - # Map PID to a single matrix in batch - 
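# (worked example of the mapping below: with num_pid_m = num_pid_n = 4, pid = 21
#  gives batch_idx = 21 // 16 = 1, then pid_m = 5 // 4 = 1 and pid_n = 5 % 4 = 1,
#  i.e. tile (1, 1) of batch 1, before tl.swizzle2d reorders tiles for L2 reuse)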
batch_idx = pid // (num_pid_m * num_pid_n) - pid = pid % (num_pid_m * num_pid_n) - - # Map PID to 2D grid of blocks - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) - - m_idx = pid_m * BLOCK_SIZE_M - n_idx = pid_n * BLOCK_SIZE_N - return batch_idx, m_idx, n_idx - -@triton.autotune( - configs=_get_autotune_configs(), - key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_1_kernel( - A_ptr, C_ptr, - M, K, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_1(A: torch.Tensor, out: torch.Tensor): - """ - Launch Triton kernel to compute C = A @ A.T - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert out.size(-2) == M, "Output matrix has incorrect shape" - assert out.size(-1) == M, "Output matrix has incorrect shape" - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), - ) - ns_line_1_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - K=K, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - ) - return out - -@triton.autotune( - 
configs=_get_autotune_configs(), - key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], -) -@triton.jit -def ns_line_2_kernel( - A_ptr, C_ptr, - M, - a_stride_b, a_stride_r, a_stride_c, - c_stride_b, c_stride_r, c_stride_c, - alpha, beta, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - LOWER_UPPER: tl.constexpr, -): - # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A - # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels - pid = tl.program_id(axis=0) - batch_idx, m_idx, n_idx = _pid_to_block( - pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M - ) - - # Skip blocks that don't need to be computed - skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) - skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) - if skip_block_below_diag or skip_block_above_diag: - return - - # Index into one matrix of batch - A_ptr += batch_idx * a_stride_b - C_ptr += batch_idx * c_stride_b - - # Create pointer arrays for A and A.T - offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) - at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Accumulate over blocks of K - for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) - at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) - a_ptrs += BLOCK_SIZE_K * a_stride_c - at_ptrs += BLOCK_SIZE_K * a_stride_c - - # Load block of A to add (corresponds to the current block of C) - offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) - a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) - a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) - a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) - - # Apply alpha and beta - accumulator *= alpha - accumulator += a_add * beta - - out_dtype = C_ptr.dtype.element_ty - output = accumulator.to(out_dtype) - - # Store block of C - offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) - offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) - tl.store(c_ptrs, output, mask=c_mask) - - # Store block of C mirrored across the diagonal - c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) - c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) - tl.store(c_ptrs_t, output.T, mask=c_mask_t) - -def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): - """ - Launch Triton kernel to compute C = alpha * A @ A.T + beta * A - """ - assert A.ndim == 2 or A.ndim == 3 - M, K = A.shape[-2:] - assert M == K, "Input matrix must be square" - assert out.size(-2) == M - assert out.size(-1) == M - - batch_size = A.size(0) if A.ndim == 3 else 1 - input_batch_stride = A.stride(0) if A.ndim == 3 else 0 - output_batch_stride = out.stride(0) if out.ndim == 3 else 0 - - grid = lambda meta: ( - batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), - ) - ns_line_2_kernel[grid]( - A_ptr=A, - C_ptr=out, - M=M, - a_stride_b=input_batch_stride, - a_stride_r=A.stride(-2), - a_stride_c=A.stride(-1), - c_stride_b=output_batch_stride, - c_stride_r=out.stride(-2), - c_stride_c=out.stride(-1), - alpha=alpha, - beta=beta, - ) - return out - -@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower -def newton_schulz_triton(G: torch.Tensor): - a, b, c = (3.4445, -4.7750, 2.0315) - X = G.bfloat16() - if G.size(-2) > G.size(-1): - X = X.mT - - # Ensure spectral norm is at most 1 - X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) - - # Allocate buffers - X = X.contiguous() - A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) - B = torch.empty_like(A) - C = torch.empty_like(X) - - ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm - - # Perform the NS iterations - for _ in range(5): - ns_line_1(X, out=A) # A = X @ X.mT - ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A - ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X - X, C = C, X # Swap references to avoid unnecessary copies - - if G.size(-2) > G.size(-1): - X = X.mT - return X - -# ----------------------------------------------------------------------------- -# Muon optimizer - -class Muon(torch.optim.Optimizer): - """ - Muon - MomentUm Orthogonalized by Newton-schulz - - https://kellerjordan.github.io/posts/muon/ - - Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- - processing step, in which each 2D parameter's update is replaced with the nearest orthogonal - matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has - the advantage that it can be stably run in bfloat16 on the GPU. - - Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, - or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). - """ - def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): - defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - - @torch.no_grad() - def step(self): - # Efficient systems-wise implementation of step developed by @YouJiacheng, - # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, - # @ryanyang0, and @vagrawal. 
- rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - grad = torch.empty_like(params[-1]) - grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size - for base_i in range(0, len(params), world_size): - if base_i + rank < len(params): - grad = params[base_i + rank].grad - # This gives strange dynamo warnings - reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) - - idx = 0 - for group in self.param_groups: - params: list[Tensor] = group["params"] - params_pad = params + [torch.empty_like(params[-1])] * world_size - momentum = group["momentum"] - for base_i in range(0, len(params), world_size): - reduce_scatter_futures[idx].wait() - if base_i + rank < len(params): - p = params[base_i + rank] - grad = p.grad - eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) - eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) - state = self.state[p] - if len(state) == 0: - state["momentum_buffer"] = torch.zeros_like(grad) - momentum_buffer = state["momentum_buffer"] - p.mul_(1 - eff_weight_decay) - momentum_buffer.lerp_(grad, 1 - momentum) - grad = grad.lerp_(momentum_buffer, momentum) - v = newton_schulz_triton(grad) - p.add_(other=v, alpha=-eff_lr) - idx += 1 - all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -class DistAdam(torch.optim.Optimizer): - def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - params = list(params) - sizes = {p.shape for p in params} - # create one buffer per unique parameter-size - param_groups = [] - for size in sizes: - group_params = [p for p in params if p.shape == size] - param_groups.append(dict(params=group_params)) - super().__init__(param_groups, defaults) - # DistributedAdam implementation by @vagrawal - - @torch.compile - @torch.no_grad() - def step(self): - rank = dist.get_rank() - world_size = dist.get_world_size() - reduce_scatter_futures: list[torch.Future] = [] - all_gather_futures: list[torch.Future] = [] - grad_slices = [] - for group in self.param_groups: - params: list[Tensor] = group["params"] - for base_i in range(len(params)): - grad = params[base_i].grad - rank_size = grad.shape[0] // world_size - grad_slice = torch.empty_like(grad[:rank_size]) - reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) - grad_slices.append(grad_slice) - - idx = 0 - for group in self.param_groups: - beta1, beta2 = group['betas'] - eps = group['eps'] - wd = group['weight_decay'] - params = group['params'] - for base in range(len(params)): - reduce_scatter_futures[idx].wait() - p = params[base] - rank_size = p.shape[0] // world_size - p_slice = p[rank * rank_size:(rank + 1) * rank_size] - lr = group['lr'] * getattr(p, "lr_mul", 1.0) - state = self.state[p] - g_slice = grad_slices[idx] - # State init - if not state: - state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) - state['exp_avg'] = torch.zeros_like(p_slice) - 
state['exp_avg_sq'] = torch.zeros_like(p_slice) - exp_avg = state['exp_avg'] - exp_avg_sq = state['exp_avg_sq'] - state['step'] += 1 - t = state['step'] - # weight decay - if wd != 0: - eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) - p_slice.mul_(1 - eff_weight_decay) - # update running averages - exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) - # bias corrections - bias1 = 1 - beta1 ** t - bias2 = 1 - beta2 ** t - # compute step - denom = exp_avg_sq.sqrt().add_(eps) - step_size = lr * (torch.sqrt(bias2) / bias1) - update = exp_avg.div(denom).mul_(step_size) - p_slice.add_(other=update, alpha=-1.0) - idx += 1 - all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_gather_futures).wait() - -# ----------------------------------------------------------------------------- -# PyTorch nn.Module definitions for the model - -def norm(x: Tensor): - return F.rms_norm(x, (x.size(-1),)) - -class CastedLinear(nn.Linear): - def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): - super().__init__(in_features, out_features, bias=False) - self.use_fp8 = use_fp8 - self.x_s = x_s - self.w_s = w_s - self.grad_s = grad_s - - def reset_parameters(self) -> None: - std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) - bound = (3 ** 0.5) * std - with torch.no_grad(): - self.weight.uniform_(-bound, bound) - - def forward(self, x: Tensor): - if self.use_fp8 and self.training: - _x = x.flatten(0, -2) - out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] - return out.reshape(*x.shape[:-1], -1) - else: - return F.linear(x, self.weight.type_as(x)) - -class Rotary(nn.Module): - def __init__(self, dim: int, max_seq_len: int): - super().__init__() - # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) - angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) - angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) - t = torch.arange(max_seq_len, dtype=torch.float32) - theta = torch.einsum("i,j -> ij", t, angular_freq) - self.cos = nn.Buffer(theta.cos(), persistent=False) - self.sin = nn.Buffer(theta.sin(), persistent=False) - - def forward(self, x_BTHD: Tensor): - assert self.cos.size(0) >= x_BTHD.size(-3) - cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] - x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) - y1 = x1 * cos + x2 * sin - y2 = x1 * (-sin) + x2 * cos - return torch.cat((y1, y2), 3).type_as(x_BTHD) - -class CausalSelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - hdim = num_heads * head_dim - assert hdim == dim, "num_heads * head_dim must equal model_dim" - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng - # https://x.com/hi_tysam/status/1879699187107033311 - self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) - with torch.no_grad(): - self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights - self.qkvo_w[3].zero_() # init output weights to zero - self.rotary = Rotary(head_dim, max_seq_len) - # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun - # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 - self.attn_scale = 0.12 - - # sparse gated attention to enable context based no-op by @classiclarryd - self.attn_gate_dim = 12 - self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) - self.attn_gate.weight.detach().zero_() - - def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int): - B, T = x.size(0), x.size(1) # batch size, sequence length - - q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) - q, k = norm(q), norm(k) # QK norm @Grad62304977 - q, k = self.rotary(q), self.rotary(k) - if ve is not None: - v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 - else: # skip mid-layers token value embeddings by @YouJiacheng - v = lambdas[0] * v - - y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal - y = y.view(B, T, self.num_heads, self.head_dim) - y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) - y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = F.linear(y, self.qkvo_w[3].type_as(y)) - return y - -class MLP(nn.Module): - def __init__(self, dim: int): - super().__init__() - hdim = 4 * dim - # make both matrices have the same shape because optimizer sorts params by shape - # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size - self.c_fc = nn.Parameter(torch.empty(dim, hdim)) - self.c_proj = nn.Parameter(torch.empty(dim, hdim)) - std = 0.5 * (dim ** -0.5) - bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng - with torch.no_grad(): - self.c_fc.uniform_(-bound, bound) - self.c_proj.zero_() # zero init suggested by @Grad62304977 - - def forward(self, x: Tensor): - x = F.linear(x, self.c_fc.T.type_as(x)) - x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = F.linear(x, self.c_proj.type_as(x)) - return x - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): - super().__init__() - # skip attention of blocks.7 (the 8th layer) by @YouJiacheng - self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None - self.mlp = MLP(dim) - - def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int): - x = lambdas[0] * x + lambdas[1] * x0 - if self.attn is not None: - x = x + self.attn(norm(x), ve, sa_lambdas, bm_size) - x = x + self.mlp(norm(x)) - return x - -# ----------------------------------------------------------------------------- -# The main model - -def next_multiple_of_n(v: float | int, *, n: int): - return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) - -class GPT(nn.Module): - def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): - super().__init__() - vocab_size = next_multiple_of_n(vocab_size, n=128) - self.embed = nn.Embedding(vocab_size, model_dim) - # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 - # value embedding code simplification inspired by @ragulpr 
https://github.com/KellerJordan/modded-nanogpt/pull/78 - self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) - self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) - # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. - # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - use_fp8 = not os.environ.get("DISABLE_FP8", False) - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) - self.lm_head.weight.detach().zero_() # @Grad62304977 - # Add learnable skip connection weights for decoder layers - assert num_layers % 2 == 0 - pad = (-num_layers * 5) % dist.get_world_size() - self.scalars = nn.Parameter(torch.cat([ - torch.ones(num_layers), # skip_weights - *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas - *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas - torch.ones(pad), - ])) - # set learning rates - for param in self.embed.parameters(): - param.lr_mul = 75. - for param in self.value_embeds.parameters(): - param.lr_mul = 75. - self.lm_head.weight.lr_mul = 1.0 - self.scalars.lr_mul = 5.0 - - - def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int): - assert input_seq.ndim == 2 - - ve = [value_embed(input_seq) for value_embed in self.value_embeds] - # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure - ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] - assert len(ve) == len(self.blocks) - - long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth - bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] - assert len(bm_sizes) == len(self.blocks) - - x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977 - - # U-net design by @brendanh0gan - skip_connections = [] - skip_weights = self.scalars[:(len(self.blocks) // 2)] - lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) - sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) - - n = len(self.blocks) // 2 - - for i in range(len(self.blocks)): - if i >= n: - x = x + skip_weights[i - n] * skip_connections.pop() - x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i]) - if i < n: - skip_connections.append(x) - - x = norm(x) - logits = self.lm_head(x).float() - # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) - logits = 30 * torch.sigmoid(logits / 7.5) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1), - reduction="sum" if self.training else "mean") - return loss - -# ----------------------------------------------------------------------------- -# Distributed data loader - -def _load_data_shard(file: Path): - header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 - assert header[0] == 20240520, "magic number mismatch in the data .bin file" - assert header[1] == 1, "unsupported version" - num_tokens = int(header[2]) # number of tokens (claimed) - with file.open("rb", buffering=0) as f: - tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng - 
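# shard layout: 256 int32 header words (magic, version, token count; the rest
# are unused here) followed by num_tokens raw uint16 tokens, so seeking past
# 256 * 4 bytes lands readinto() directly on the token payload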
f.seek(256 * 4) - nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng - assert nbytes == 2 * num_tokens, "number of tokens read does not match header" - return tokens - -class EOSBatchFinder: - # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard - self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") - starts = [[] for _ in range(self.world_size)] - idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 - for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance - self.i = idx - return starts, advance - - -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap - rank = dist.get_rank() if dist.is_initialized() else 0 - world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" - - files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - if not files: - raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") - - file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None - if align_to_bos: finder.seek(pos) - - while True: - batch_size_local = batch_size // world_size - num_tokens_global = batch_size * seq_len - - if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - - if align_to_bos: - try: - batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len) - start_idxs = batch_starts[rank] - except StopIteration: - # This shard is exhausted, load the next one in the next loop iteration. 
- tokens, pos = _load_data_shard(next(file_iter)), 0 - finder = EOSBatchFinder(tokens, world_size=world_size) - continue - - bufs = [tokens[s: s + seq_len + 1] for s in start_idxs] - buf = torch.stack(bufs, dim=0) - _inputs = buf[:, :-1] - _targets = buf[:, 1:] - else: - batch_span = num_tokens_global - start_pos_local = pos + rank * (batch_size_local * seq_len) - end_pos_local = start_pos_local + (batch_size_local * seq_len) - - buf = tokens[start_pos_local: end_pos_local + 1] - - _inputs = buf[:-1].view(batch_size_local, seq_len) - _targets = buf[1:].view(batch_size_local, seq_len) - - new_params = yield ( - _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), - _targets.to(device="cuda", dtype=torch.int64, non_blocking=True) - ) - - pos += batch_span - - if new_params is not None: - # makes it possible for generator to recieve new (batch_size, seq_len) via .send() - new_batch_size, new_seq_len = new_params - assert new_batch_size % world_size == 0, "New batch size must be divisible by world size" - batch_size = new_batch_size - seq_len = new_seq_len - - -# ----------------------------------------------------------------------------- -# int main - -@dataclass -class Hyperparameters: - # data - train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on - val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on - val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons - train_seq_len: int = 1024 * 2 - train_batch_size: int = 24 * 8 - val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size. - # optimization - num_iterations: int = 1695 # number of iterations to run - cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate - # evaluation and logging - run_id: str = str(uuid.uuid4()) - val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end - save_checkpoint: bool = False - # attention masking - bandwidth: int = 128 - ws_schedule: tuple = (3, 7, 11) - -args = Hyperparameters() - -data_path = os.environ.get("DATA_PATH", ".") -args.train_files = os.path.join(data_path, args.train_files) -args.val_files = os.path.join(data_path, args.val_files) - -# torchrun sets these env variables -rank = int(os.environ["RANK"]) -world_size = int(os.environ["WORLD_SIZE"]) -assert 8 % world_size == 0, "world_size must be a divisor of 8" -grad_accum_steps = 8 // world_size -assert torch.cuda.is_available() -device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) -torch.cuda.set_device(device) -dist.init_process_group(backend="nccl", device_id=device) -dist.barrier() -master_process = (rank == 0) # this process will do logging, checkpointing etc. 
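Aside: distributed_data_generator is a plain Python generator, so the (batch_size, seq_len) update path above is just the standard generator .send() protocol. A minimal sketch of driving it, assuming a single process (no torch.distributed initialization), an available CUDA device, and existing shard files:

    loader = distributed_data_generator("data/fineweb10B/fineweb_train_*.bin", batch_size=8, seq_len=2048)
    inputs, targets = next(loader)             # ordinary batch
    inputs, targets = loader.send((16, 1024))  # resize, then receive the next batch

Because the generator reads new_params only after the yield resumes, the batch returned by .send() is already built with the new shape.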
-
-# begin logging
-logfile = None
-if master_process:
-    run_id = args.run_id
-    os.makedirs("logs", exist_ok=True)
-    logfile = f"logs/{run_id}.txt"
-    print(logfile)
-def print0(s, console=False):
-    if master_process:
-        with open(logfile, "a") as f:
-            if console:
-                print(s)
-            print(s, file=f)
-
-# begin by printing this file (the Python code)
-print0(code)
-print0("="*100)
-# log information about the hardware/software environment this is running on
-print0(f"Running Python {sys.version}")
-print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
-print0(f"Running Triton version {triton.__version__}")
-
-def nvidia_smi():
-    import subprocess # avoid top level import
-    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
-print0(nvidia_smi())
-print0("="*100)
-
-model: nn.Module = GPT(
-    vocab_size=50257,
-    num_layers=12,
-    num_heads=6,
-    model_dim=768,
-    max_seq_len=max(args.train_seq_len, args.val_seq_len)
-).cuda()
-for m in model.modules():
-    if isinstance(m, nn.Embedding):
-        m.bfloat16()
-for param in model.parameters():
-    dist.broadcast(param.detach(), 0)
-
-# collect the parameters to optimize
-hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
-embed_params = [p for n, p in model.named_parameters() if "embed" in n]
-scalar_params = [p for p in model.parameters() if p.ndim < 2]
-head_params = [model.lm_head.weight]
-
-# init the optimizer(s)
-# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
-# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
-optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
-optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
-optimizers = [optimizer1, optimizer2]
-for opt in optimizers:
-    for group in opt.param_groups:
-        group["initial_lr"] = group["lr"]
-
-# learning rate schedule: stable then decay
-def get_lr_and_ws(step: int):
-    x = step / (1 + args.num_iterations) # progress in training
-    assert 0 <= x < 1
-    lr = 1.0
-    if x >= 1 - args.cooldown_frac:
-        w = (1 - x) / args.cooldown_frac
-        lr = w * 1.0 + (1 - w) * 0.1
-    ws_idx = int(len(args.ws_schedule) * x)
-    return lr, args.ws_schedule[ws_idx]
-
-model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
-
-########################################
-#            Warmup kernels            #
-########################################
-
-# Warmup the training kernels, then re-initialize the state so we aren't cheating
-warmup_steps = 60
-initial_state = dict(model=copy.deepcopy(model.state_dict()),
-                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
-for step in range(warmup_steps):
-    inputs, targets = next(train_loader)
-    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
-    model(inputs, targets, ws, ws // 2).backward()
-    for opt in optimizers:
-        opt.step()
-    model.zero_grad(set_to_none=True)
-model.load_state_dict(initial_state["model"])
-for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
-    opt.load_state_dict(opt_state)
-del train_loader, initial_state
-
-########################################
-#       Training and validation        #
-########################################
-
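Aside: a self-contained restatement of the stable-then-decay schedule above, with the constants inlined from Hyperparameters (the record's code reads them from args):

    NUM_ITERS, COOLDOWN_FRAC, WS_SCHEDULE = 1695, 0.45, (3, 7, 11)

    def lr_and_ws(step: int):
        x = step / (1 + NUM_ITERS)                           # progress in [0, 1)
        lr = 1.0
        if x >= 1 - COOLDOWN_FRAC:                           # final 45%: linear decay 1.0 -> 0.1
            w = (1 - x) / COOLDOWN_FRAC
            lr = w * 1.0 + (1 - w) * 0.1
        return lr, WS_SCHEDULE[int(len(WS_SCHEDULE) * x)]    # window index steps up by thirds

    assert lr_and_ws(0) == (1.0, 3)      # constant-lr phase, shortest window
    assert lr_and_ws(1200)[1] == 11      # final third of training, longest window

So the learning-rate multiplier holds at 1.0 until roughly step 933 (x = 0.55) and then decays linearly toward 0.1, while the window-size schedule widens from 3 to 7 to 11 bands of args.bandwidth = 128 tokens at x = 1/3 and x = 2/3.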
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
-training_time_ms = 0
-# start the clock
-torch.cuda.synchronize()
-t0 = time.perf_counter()
-# begin training
-train_steps = args.num_iterations
-for step in range(train_steps + 1):
-    last_step = (step == train_steps)
-    lr, ws = get_lr_and_ws(step)
-
-    # --------------- VALIDATION SECTION -----------------
-    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
-        # stop the clock
-        torch.cuda.synchronize()
-        training_time_ms += 1000 * (time.perf_counter() - t0)
-        model.eval()
-        assert args.val_tokens % (world_size * args.val_seq_len) == 0
-        val_steps = args.val_tokens // (world_size * args.val_seq_len)
-        val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False)
-        val_loss = 0
-        with torch.no_grad():
-            for _ in range(val_steps):
-                inputs, targets = next(val_loader)
-                val_loss += model(inputs, targets, ws, ws // 2)
-        val_loss /= val_steps
-        del val_loader
-        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
-        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
-        model.train()
-        # start the clock again
-        torch.cuda.synchronize()
-        t0 = time.perf_counter()
-
-    if last_step:
-        if master_process and args.save_checkpoint:
-            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
-            os.makedirs(f"logs/{run_id}", exist_ok=True)
-            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
-        # the last step only has the validation loop, so break to avoid training
-        break
-
-    # --------------- TRAINING SECTION -----------------
-    for _ in range(grad_accum_steps):
-        inputs, targets = next(train_loader)
-        model(inputs, targets, ws, ws // 2).backward()
-    # set optimization hyperparameters
-    for opt in optimizers:
-        for group in opt.param_groups:
-            group["lr"] = group["initial_lr"] * lr
-    for group in optimizer2.param_groups:
-        frac = min(step / 300, 1) # momentum warmup for muon
-        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
-    # step the optimizers
-    for opt in optimizers:
-        opt.step()
-    # null the gradients
-    model.zero_grad(set_to_none=True)
-    # logging
-    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
-    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
-
-print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
-       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
-dist.destroy_process_group()
-====================================================================================================
-Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0]
-Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6
-Running Triton version 3.4.0
-Wed Aug 27 03:47:47 2025
-+---------------------------------------------------------------------------------------+
-| NVIDIA-SMI 535.183.06             Driver Version: 535.183.06   CUDA Version: 12.6     |
-|-----------------------------------------+----------------------+----------------------+
-| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
-| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
-|                                         |                      |               MIG M. |
-|=========================================+======================+======================|
-|   0  NVIDIA H100 80GB HBM3          On  | 00000000:00:0B.0 Off |                  Off |
-| N/A   32C    P0             116W / 700W |   5858MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   1  NVIDIA H100 80GB HBM3          On  | 00000000:00:0C.0 Off |                  Off |
-| N/A   36C    P0             115W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   2  NVIDIA H100 80GB HBM3          On  | 00000000:00:0D.0 Off |                  Off |
-| N/A   37C    P0             116W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   3  NVIDIA H100 80GB HBM3          On  | 00000000:00:0E.0 Off |                  Off |
-| N/A   32C    P0             114W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   4  NVIDIA H100 80GB HBM3          On  | 00000000:00:0F.0 Off |                  Off |
-| N/A   32C    P0             112W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   5  NVIDIA H100 80GB HBM3          On  | 00000000:00:10.0 Off |                  Off |
-| N/A   38C    P0             118W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   6  NVIDIA H100 80GB HBM3          On  | 00000000:00:11.0 Off |                  Off |
-| N/A   36C    P0             113W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-|   7  NVIDIA H100 80GB HBM3          On  | 00000000:00:12.0 Off |                  Off |
-| N/A   33C    P0             115W / 700W |   1519MiB / 81559MiB |      0%      Default |
-|                                         |                      |             Disabled |
-+-----------------------------------------+----------------------+----------------------+
-
-+---------------------------------------------------------------------------------------+
-| Processes:                                                                            |
-|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
-|        ID   ID                                                             Usage      |
-|=======================================================================================|
-+---------------------------------------------------------------------------------------+
-
-====================================================================================================
-step:0/1695 val_loss:10.8258 train_time:0ms step_avg:0.03ms
-step:1/1695 train_time:520ms step_avg:519.82ms
-step:2/1695 train_time:546ms step_avg:272.89ms
-step:3/1695 train_time:615ms step_avg:204.90ms
-step:4/1695 train_time:707ms step_avg:176.66ms
-step:5/1695 train_time:800ms step_avg:159.97ms
-step:6/1695 train_time:892ms step_avg:148.74ms
-step:7/1695 train_time:986ms step_avg:140.81ms
-step:8/1695 train_time:1079ms step_avg:134.91ms
-step:9/1695 train_time:1173ms step_avg:130.29ms
-step:10/1695 train_time:1266ms step_avg:126.65ms
-step:11/1695 train_time:1360ms step_avg:123.62ms
-step:12/1695 train_time:1455ms step_avg:121.29ms
-step:13/1695 train_time:1553ms step_avg:119.48ms
-step:14/1695 train_time:1649ms step_avg:117.78ms
-step:15/1695 train_time:1745ms step_avg:116.36ms
-step:16/1695 train_time:1840ms step_avg:115.01ms
-step:17/1695 train_time:1933ms step_avg:113.72ms
-step:18/1695 train_time:2027ms step_avg:112.62ms
-step:19/1695 train_time:2121ms step_avg:111.64ms
-step:20/1695 train_time:2215ms step_avg:110.74ms
-step:21/1695 train_time:2309ms step_avg:109.95ms
-step:22/1695 train_time:2405ms step_avg:109.30ms
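Aside: in the step lines above and below, train_time is cumulative wall-clock training time with validation excluded (the clock is stopped around the validation section), and step_avg is train_time divided by the completed-step count: at step 22, 2405 ms / 22 ≈ 109.3 ms. The last digit can differ from this back-of-envelope value because train_time is printed rounded to whole milliseconds.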
-step:23/1695 train_time:2501ms step_avg:108.76ms -step:24/1695 train_time:2596ms step_avg:108.18ms -step:25/1695 train_time:2692ms step_avg:107.66ms -step:26/1695 train_time:2787ms step_avg:107.19ms -step:27/1695 train_time:2882ms step_avg:106.75ms -step:28/1695 train_time:2976ms step_avg:106.30ms -step:29/1695 train_time:3070ms step_avg:105.87ms -step:30/1695 train_time:3164ms step_avg:105.48ms -step:31/1695 train_time:3258ms step_avg:105.10ms -step:32/1695 train_time:3352ms step_avg:104.76ms -step:33/1695 train_time:3448ms step_avg:104.49ms -step:34/1695 train_time:3546ms step_avg:104.28ms -step:35/1695 train_time:3641ms step_avg:104.04ms -step:36/1695 train_time:3735ms step_avg:103.76ms -step:37/1695 train_time:3830ms step_avg:103.51ms -step:38/1695 train_time:3925ms step_avg:103.28ms -step:39/1695 train_time:4019ms step_avg:103.05ms -step:40/1695 train_time:4113ms step_avg:102.83ms -step:41/1695 train_time:4208ms step_avg:102.63ms -step:42/1695 train_time:4302ms step_avg:102.43ms -step:43/1695 train_time:4396ms step_avg:102.22ms -step:44/1695 train_time:4490ms step_avg:102.06ms -step:45/1695 train_time:4587ms step_avg:101.92ms -step:46/1695 train_time:4682ms step_avg:101.78ms -step:47/1695 train_time:4776ms step_avg:101.62ms -step:48/1695 train_time:4871ms step_avg:101.47ms -step:49/1695 train_time:4965ms step_avg:101.33ms -step:50/1695 train_time:5061ms step_avg:101.22ms -step:51/1695 train_time:5154ms step_avg:101.05ms -step:52/1695 train_time:5248ms step_avg:100.93ms -step:53/1695 train_time:5343ms step_avg:100.82ms -step:54/1695 train_time:5439ms step_avg:100.73ms -step:55/1695 train_time:5533ms step_avg:100.59ms -step:56/1695 train_time:5628ms step_avg:100.50ms -step:57/1695 train_time:5723ms step_avg:100.41ms -step:58/1695 train_time:5819ms step_avg:100.33ms -step:59/1695 train_time:5913ms step_avg:100.22ms -step:60/1695 train_time:6008ms step_avg:100.13ms -step:61/1695 train_time:6101ms step_avg:100.02ms -step:62/1695 train_time:6195ms step_avg:99.91ms -step:63/1695 train_time:6289ms step_avg:99.82ms -step:64/1695 train_time:6384ms step_avg:99.75ms -step:65/1695 train_time:6479ms step_avg:99.68ms -step:66/1695 train_time:6573ms step_avg:99.59ms -step:67/1695 train_time:6668ms step_avg:99.52ms -step:68/1695 train_time:6762ms step_avg:99.45ms -step:69/1695 train_time:6856ms step_avg:99.37ms -step:70/1695 train_time:6950ms step_avg:99.29ms -step:71/1695 train_time:7044ms step_avg:99.21ms -step:72/1695 train_time:7138ms step_avg:99.14ms -step:73/1695 train_time:7232ms step_avg:99.07ms -step:74/1695 train_time:7327ms step_avg:99.02ms -step:75/1695 train_time:7423ms step_avg:98.97ms -step:76/1695 train_time:7518ms step_avg:98.92ms -step:77/1695 train_time:7613ms step_avg:98.87ms -step:78/1695 train_time:7709ms step_avg:98.83ms -step:79/1695 train_time:7803ms step_avg:98.78ms -step:80/1695 train_time:7897ms step_avg:98.72ms -step:81/1695 train_time:7991ms step_avg:98.65ms -step:82/1695 train_time:8085ms step_avg:98.60ms -step:83/1695 train_time:8180ms step_avg:98.56ms -step:84/1695 train_time:8274ms step_avg:98.50ms -step:85/1695 train_time:8368ms step_avg:98.45ms -step:86/1695 train_time:8463ms step_avg:98.40ms -step:87/1695 train_time:8556ms step_avg:98.35ms -step:88/1695 train_time:8651ms step_avg:98.31ms -step:89/1695 train_time:8747ms step_avg:98.29ms -step:90/1695 train_time:8843ms step_avg:98.26ms -step:91/1695 train_time:8938ms step_avg:98.22ms -step:92/1695 train_time:9031ms step_avg:98.16ms -step:93/1695 train_time:9125ms step_avg:98.12ms -step:94/1695 train_time:9220ms 
step_avg:98.08ms -step:95/1695 train_time:9314ms step_avg:98.04ms -step:96/1695 train_time:9408ms step_avg:98.00ms -step:97/1695 train_time:9503ms step_avg:97.97ms -step:98/1695 train_time:9597ms step_avg:97.93ms -step:99/1695 train_time:9691ms step_avg:97.89ms -step:100/1695 train_time:9787ms step_avg:97.87ms -step:101/1695 train_time:9881ms step_avg:97.84ms -step:102/1695 train_time:9975ms step_avg:97.80ms -step:103/1695 train_time:10069ms step_avg:97.76ms -step:104/1695 train_time:10164ms step_avg:97.74ms -step:105/1695 train_time:10258ms step_avg:97.70ms -step:106/1695 train_time:10352ms step_avg:97.66ms -step:107/1695 train_time:10447ms step_avg:97.64ms -step:108/1695 train_time:10543ms step_avg:97.62ms -step:109/1695 train_time:10637ms step_avg:97.59ms -step:110/1695 train_time:10732ms step_avg:97.56ms -step:111/1695 train_time:10827ms step_avg:97.54ms -step:112/1695 train_time:10922ms step_avg:97.52ms -step:113/1695 train_time:11016ms step_avg:97.49ms -step:114/1695 train_time:11110ms step_avg:97.46ms -step:115/1695 train_time:11205ms step_avg:97.44ms -step:116/1695 train_time:11299ms step_avg:97.41ms -step:117/1695 train_time:11393ms step_avg:97.38ms -step:118/1695 train_time:11488ms step_avg:97.35ms -step:119/1695 train_time:11583ms step_avg:97.33ms -step:120/1695 train_time:11677ms step_avg:97.31ms -step:121/1695 train_time:11771ms step_avg:97.28ms -step:122/1695 train_time:11867ms step_avg:97.27ms -step:123/1695 train_time:11961ms step_avg:97.24ms -step:124/1695 train_time:12054ms step_avg:97.21ms -step:125/1695 train_time:12149ms step_avg:97.19ms -step:125/1695 val_loss:4.3107 train_time:12241ms step_avg:97.93ms -step:126/1695 train_time:12268ms step_avg:97.36ms -step:127/1695 train_time:12344ms step_avg:97.20ms -step:128/1695 train_time:12448ms step_avg:97.25ms -step:129/1695 train_time:12544ms step_avg:97.24ms -step:130/1695 train_time:12638ms step_avg:97.21ms -step:131/1695 train_time:12730ms step_avg:97.18ms -step:132/1695 train_time:12824ms step_avg:97.15ms -step:133/1695 train_time:12918ms step_avg:97.13ms -step:134/1695 train_time:13011ms step_avg:97.10ms -step:135/1695 train_time:13105ms step_avg:97.07ms -step:136/1695 train_time:13198ms step_avg:97.04ms -step:137/1695 train_time:13292ms step_avg:97.02ms -step:138/1695 train_time:13389ms step_avg:97.02ms -step:139/1695 train_time:13485ms step_avg:97.01ms -step:140/1695 train_time:13580ms step_avg:97.00ms -step:141/1695 train_time:13675ms step_avg:96.98ms -step:142/1695 train_time:13770ms step_avg:96.97ms -step:143/1695 train_time:13862ms step_avg:96.94ms -step:144/1695 train_time:13955ms step_avg:96.91ms -step:145/1695 train_time:14049ms step_avg:96.89ms -step:146/1695 train_time:14144ms step_avg:96.87ms -step:147/1695 train_time:14238ms step_avg:96.86ms -step:148/1695 train_time:14333ms step_avg:96.84ms -step:149/1695 train_time:14427ms step_avg:96.83ms -step:150/1695 train_time:14524ms step_avg:96.82ms -step:151/1695 train_time:14617ms step_avg:96.80ms -step:152/1695 train_time:14711ms step_avg:96.78ms -step:153/1695 train_time:14805ms step_avg:96.76ms -step:154/1695 train_time:14899ms step_avg:96.75ms -step:155/1695 train_time:14992ms step_avg:96.72ms -step:156/1695 train_time:15085ms step_avg:96.70ms -step:157/1695 train_time:15180ms step_avg:96.69ms -step:158/1695 train_time:15274ms step_avg:96.67ms -step:159/1695 train_time:15368ms step_avg:96.66ms -step:160/1695 train_time:15463ms step_avg:96.65ms -step:161/1695 train_time:15558ms step_avg:96.63ms -step:162/1695 train_time:15652ms step_avg:96.62ms -step:163/1695 
train_time:15746ms step_avg:96.60ms -step:164/1695 train_time:15841ms step_avg:96.59ms -step:165/1695 train_time:15935ms step_avg:96.57ms -step:166/1695 train_time:16029ms step_avg:96.56ms -step:167/1695 train_time:16124ms step_avg:96.55ms -step:168/1695 train_time:16217ms step_avg:96.53ms -step:169/1695 train_time:16313ms step_avg:96.52ms -step:170/1695 train_time:16406ms step_avg:96.51ms -step:171/1695 train_time:16501ms step_avg:96.50ms -step:172/1695 train_time:16595ms step_avg:96.48ms -step:173/1695 train_time:16938ms step_avg:97.91ms -step:174/1695 train_time:17053ms step_avg:98.00ms -step:175/1695 train_time:17146ms step_avg:97.98ms -step:176/1695 train_time:17239ms step_avg:97.95ms -step:177/1695 train_time:17332ms step_avg:97.92ms -step:178/1695 train_time:17426ms step_avg:97.90ms -step:179/1695 train_time:17518ms step_avg:97.87ms -step:180/1695 train_time:17611ms step_avg:97.84ms -step:181/1695 train_time:17705ms step_avg:97.82ms -step:182/1695 train_time:17798ms step_avg:97.79ms -step:183/1695 train_time:17897ms step_avg:97.80ms -step:184/1695 train_time:17993ms step_avg:97.79ms -step:185/1695 train_time:18087ms step_avg:97.77ms -step:186/1695 train_time:18181ms step_avg:97.75ms -step:187/1695 train_time:18275ms step_avg:97.73ms -step:188/1695 train_time:18369ms step_avg:97.71ms -step:189/1695 train_time:18462ms step_avg:97.68ms -step:190/1695 train_time:18555ms step_avg:97.66ms -step:191/1695 train_time:18648ms step_avg:97.63ms -step:192/1695 train_time:18743ms step_avg:97.62ms -step:193/1695 train_time:18839ms step_avg:97.61ms -step:194/1695 train_time:18934ms step_avg:97.60ms -step:195/1695 train_time:19028ms step_avg:97.58ms -step:196/1695 train_time:19122ms step_avg:97.56ms -step:197/1695 train_time:19216ms step_avg:97.55ms -step:198/1695 train_time:19310ms step_avg:97.53ms -step:199/1695 train_time:19404ms step_avg:97.51ms -step:200/1695 train_time:19498ms step_avg:97.49ms -step:201/1695 train_time:19591ms step_avg:97.47ms -step:202/1695 train_time:19684ms step_avg:97.45ms -step:203/1695 train_time:19779ms step_avg:97.43ms -step:204/1695 train_time:19873ms step_avg:97.42ms -step:205/1695 train_time:19967ms step_avg:97.40ms -step:206/1695 train_time:20062ms step_avg:97.39ms -step:207/1695 train_time:20157ms step_avg:97.38ms -step:208/1695 train_time:20251ms step_avg:97.36ms -step:209/1695 train_time:20345ms step_avg:97.34ms -step:210/1695 train_time:20439ms step_avg:97.33ms -step:211/1695 train_time:20532ms step_avg:97.31ms -step:212/1695 train_time:20625ms step_avg:97.29ms -step:213/1695 train_time:20719ms step_avg:97.27ms -step:214/1695 train_time:20814ms step_avg:97.26ms -step:215/1695 train_time:20908ms step_avg:97.25ms -step:216/1695 train_time:21002ms step_avg:97.23ms -step:217/1695 train_time:21097ms step_avg:97.22ms -step:218/1695 train_time:21191ms step_avg:97.21ms -step:219/1695 train_time:21285ms step_avg:97.19ms -step:220/1695 train_time:21380ms step_avg:97.18ms -step:221/1695 train_time:21474ms step_avg:97.17ms -step:222/1695 train_time:21568ms step_avg:97.15ms -step:223/1695 train_time:21664ms step_avg:97.15ms -step:224/1695 train_time:21756ms step_avg:97.13ms -step:225/1695 train_time:21850ms step_avg:97.11ms -step:226/1695 train_time:21945ms step_avg:97.10ms -step:227/1695 train_time:22040ms step_avg:97.09ms -step:228/1695 train_time:22134ms step_avg:97.08ms -step:229/1695 train_time:22228ms step_avg:97.06ms -step:230/1695 train_time:22322ms step_avg:97.05ms -step:231/1695 train_time:22417ms step_avg:97.04ms -step:232/1695 train_time:22510ms step_avg:97.02ms 
-step:233/1695 train_time:22604ms step_avg:97.01ms -step:234/1695 train_time:22699ms step_avg:97.00ms -step:235/1695 train_time:22792ms step_avg:96.99ms -step:236/1695 train_time:22886ms step_avg:96.97ms -step:237/1695 train_time:22981ms step_avg:96.96ms -step:238/1695 train_time:23075ms step_avg:96.95ms -step:239/1695 train_time:23170ms step_avg:96.94ms -step:240/1695 train_time:23264ms step_avg:96.93ms -step:241/1695 train_time:23358ms step_avg:96.92ms -step:242/1695 train_time:23451ms step_avg:96.91ms -step:243/1695 train_time:23546ms step_avg:96.90ms -step:244/1695 train_time:23641ms step_avg:96.89ms -step:245/1695 train_time:23735ms step_avg:96.88ms -step:246/1695 train_time:23829ms step_avg:96.87ms -step:247/1695 train_time:23923ms step_avg:96.85ms -step:248/1695 train_time:24018ms step_avg:96.85ms -step:249/1695 train_time:24112ms step_avg:96.83ms -step:250/1695 train_time:24206ms step_avg:96.82ms -step:250/1695 val_loss:3.9776 train_time:24298ms step_avg:97.19ms -step:251/1695 train_time:24328ms step_avg:96.92ms -step:252/1695 train_time:24399ms step_avg:96.82ms -step:253/1695 train_time:24498ms step_avg:96.83ms -step:254/1695 train_time:24592ms step_avg:96.82ms -step:255/1695 train_time:24685ms step_avg:96.81ms -step:256/1695 train_time:24778ms step_avg:96.79ms -step:257/1695 train_time:24870ms step_avg:96.77ms -step:258/1695 train_time:24964ms step_avg:96.76ms -step:259/1695 train_time:25057ms step_avg:96.74ms -step:260/1695 train_time:25150ms step_avg:96.73ms -step:261/1695 train_time:25243ms step_avg:96.72ms -step:262/1695 train_time:25338ms step_avg:96.71ms -step:263/1695 train_time:25434ms step_avg:96.71ms -step:264/1695 train_time:25529ms step_avg:96.70ms -step:265/1695 train_time:25625ms step_avg:96.70ms -step:266/1695 train_time:25719ms step_avg:96.69ms -step:267/1695 train_time:25812ms step_avg:96.68ms -step:268/1695 train_time:25905ms step_avg:96.66ms -step:269/1695 train_time:25998ms step_avg:96.65ms -step:270/1695 train_time:26091ms step_avg:96.63ms -step:271/1695 train_time:26185ms step_avg:96.62ms -step:272/1695 train_time:26279ms step_avg:96.61ms -step:273/1695 train_time:26373ms step_avg:96.60ms -step:274/1695 train_time:26468ms step_avg:96.60ms -step:275/1695 train_time:26563ms step_avg:96.59ms -step:276/1695 train_time:26658ms step_avg:96.59ms -step:277/1695 train_time:26751ms step_avg:96.57ms -step:278/1695 train_time:26845ms step_avg:96.56ms -step:279/1695 train_time:26939ms step_avg:96.56ms -step:280/1695 train_time:27032ms step_avg:96.54ms -step:281/1695 train_time:27125ms step_avg:96.53ms -step:282/1695 train_time:27219ms step_avg:96.52ms -step:283/1695 train_time:27313ms step_avg:96.51ms -step:284/1695 train_time:27407ms step_avg:96.50ms -step:285/1695 train_time:27502ms step_avg:96.50ms -step:286/1695 train_time:27597ms step_avg:96.49ms -step:287/1695 train_time:27691ms step_avg:96.49ms -step:288/1695 train_time:27785ms step_avg:96.48ms -step:289/1695 train_time:27879ms step_avg:96.47ms -step:290/1695 train_time:27973ms step_avg:96.46ms -step:291/1695 train_time:28067ms step_avg:96.45ms -step:292/1695 train_time:28160ms step_avg:96.44ms -step:293/1695 train_time:28254ms step_avg:96.43ms -step:294/1695 train_time:28347ms step_avg:96.42ms -step:295/1695 train_time:28442ms step_avg:96.41ms -step:296/1695 train_time:28537ms step_avg:96.41ms -step:297/1695 train_time:28632ms step_avg:96.40ms -step:298/1695 train_time:28726ms step_avg:96.40ms -step:299/1695 train_time:28821ms step_avg:96.39ms -step:300/1695 train_time:28915ms step_avg:96.38ms -step:301/1695 
train_time:29008ms step_avg:96.37ms -step:302/1695 train_time:29101ms step_avg:96.36ms -step:303/1695 train_time:29195ms step_avg:96.35ms -step:304/1695 train_time:29289ms step_avg:96.34ms -step:305/1695 train_time:29383ms step_avg:96.34ms -step:306/1695 train_time:29478ms step_avg:96.33ms -step:307/1695 train_time:29572ms step_avg:96.32ms -step:308/1695 train_time:29667ms step_avg:96.32ms -step:309/1695 train_time:29761ms step_avg:96.31ms -step:310/1695 train_time:29855ms step_avg:96.31ms -step:311/1695 train_time:29949ms step_avg:96.30ms -step:312/1695 train_time:30044ms step_avg:96.29ms -step:313/1695 train_time:30138ms step_avg:96.29ms -step:314/1695 train_time:30231ms step_avg:96.28ms -step:315/1695 train_time:30325ms step_avg:96.27ms -step:316/1695 train_time:30419ms step_avg:96.26ms -step:317/1695 train_time:30513ms step_avg:96.26ms -step:318/1695 train_time:30607ms step_avg:96.25ms -step:319/1695 train_time:30702ms step_avg:96.24ms -step:320/1695 train_time:30796ms step_avg:96.24ms -step:321/1695 train_time:30890ms step_avg:96.23ms -step:322/1695 train_time:30984ms step_avg:96.22ms -step:323/1695 train_time:31078ms step_avg:96.22ms -step:324/1695 train_time:31171ms step_avg:96.21ms -step:325/1695 train_time:31265ms step_avg:96.20ms -step:326/1695 train_time:31359ms step_avg:96.19ms -step:327/1695 train_time:31452ms step_avg:96.18ms -step:328/1695 train_time:31545ms step_avg:96.17ms -step:329/1695 train_time:31640ms step_avg:96.17ms -step:330/1695 train_time:31735ms step_avg:96.17ms -step:331/1695 train_time:31829ms step_avg:96.16ms -step:332/1695 train_time:31923ms step_avg:96.15ms -step:333/1695 train_time:32016ms step_avg:96.14ms -step:334/1695 train_time:32109ms step_avg:96.14ms -step:335/1695 train_time:32203ms step_avg:96.13ms -step:336/1695 train_time:32298ms step_avg:96.12ms -step:337/1695 train_time:32392ms step_avg:96.12ms -step:338/1695 train_time:32486ms step_avg:96.11ms -step:339/1695 train_time:32580ms step_avg:96.11ms -step:340/1695 train_time:32674ms step_avg:96.10ms -step:341/1695 train_time:32767ms step_avg:96.09ms -step:342/1695 train_time:32862ms step_avg:96.09ms -step:343/1695 train_time:32957ms step_avg:96.08ms -step:344/1695 train_time:33050ms step_avg:96.08ms -step:345/1695 train_time:33394ms step_avg:96.79ms -step:346/1695 train_time:33465ms step_avg:96.72ms -step:347/1695 train_time:33557ms step_avg:96.71ms -step:348/1695 train_time:33650ms step_avg:96.70ms -step:349/1695 train_time:33744ms step_avg:96.69ms -step:350/1695 train_time:33837ms step_avg:96.68ms -step:351/1695 train_time:33930ms step_avg:96.67ms -step:352/1695 train_time:34023ms step_avg:96.66ms -step:353/1695 train_time:34116ms step_avg:96.65ms -step:354/1695 train_time:34209ms step_avg:96.64ms -step:355/1695 train_time:34305ms step_avg:96.63ms -step:356/1695 train_time:34401ms step_avg:96.63ms -step:357/1695 train_time:34497ms step_avg:96.63ms -step:358/1695 train_time:34591ms step_avg:96.62ms -step:359/1695 train_time:34685ms step_avg:96.61ms -step:360/1695 train_time:34779ms step_avg:96.61ms -step:361/1695 train_time:34871ms step_avg:96.60ms -step:362/1695 train_time:34965ms step_avg:96.59ms -step:363/1695 train_time:35060ms step_avg:96.58ms -step:364/1695 train_time:35153ms step_avg:96.57ms -step:365/1695 train_time:35246ms step_avg:96.56ms -step:366/1695 train_time:35341ms step_avg:96.56ms -step:367/1695 train_time:35436ms step_avg:96.56ms -step:368/1695 train_time:35530ms step_avg:96.55ms -step:369/1695 train_time:35625ms step_avg:96.54ms -step:370/1695 train_time:35719ms step_avg:96.54ms 
-step:371/1695 train_time:35812ms step_avg:96.53ms -step:372/1695 train_time:35905ms step_avg:96.52ms -step:373/1695 train_time:35999ms step_avg:96.51ms -step:374/1695 train_time:36092ms step_avg:96.50ms -step:375/1695 train_time:36185ms step_avg:96.49ms -step:375/1695 val_loss:3.8206 train_time:36277ms step_avg:96.74ms -step:376/1695 train_time:36303ms step_avg:96.55ms -step:377/1695 train_time:36376ms step_avg:96.49ms -step:378/1695 train_time:36477ms step_avg:96.50ms -step:379/1695 train_time:36572ms step_avg:96.50ms -step:380/1695 train_time:36667ms step_avg:96.49ms -step:381/1695 train_time:36760ms step_avg:96.48ms -step:382/1695 train_time:36852ms step_avg:96.47ms -step:383/1695 train_time:36946ms step_avg:96.46ms -step:384/1695 train_time:37039ms step_avg:96.45ms -step:385/1695 train_time:37131ms step_avg:96.44ms -step:386/1695 train_time:37224ms step_avg:96.44ms -step:387/1695 train_time:37319ms step_avg:96.43ms -step:388/1695 train_time:37415ms step_avg:96.43ms -step:389/1695 train_time:37511ms step_avg:96.43ms -step:390/1695 train_time:37607ms step_avg:96.43ms -step:391/1695 train_time:37701ms step_avg:96.42ms -step:392/1695 train_time:37794ms step_avg:96.41ms -step:393/1695 train_time:37888ms step_avg:96.41ms -step:394/1695 train_time:37982ms step_avg:96.40ms -step:395/1695 train_time:38075ms step_avg:96.39ms -step:396/1695 train_time:38168ms step_avg:96.39ms -step:397/1695 train_time:38263ms step_avg:96.38ms -step:398/1695 train_time:38358ms step_avg:96.38ms -step:399/1695 train_time:38452ms step_avg:96.37ms -step:400/1695 train_time:38548ms step_avg:96.37ms -step:401/1695 train_time:38643ms step_avg:96.37ms -step:402/1695 train_time:38737ms step_avg:96.36ms -step:403/1695 train_time:38830ms step_avg:96.35ms -step:404/1695 train_time:38924ms step_avg:96.35ms -step:405/1695 train_time:39017ms step_avg:96.34ms -step:406/1695 train_time:39110ms step_avg:96.33ms -step:407/1695 train_time:39203ms step_avg:96.32ms -step:408/1695 train_time:39296ms step_avg:96.31ms -step:409/1695 train_time:39391ms step_avg:96.31ms -step:410/1695 train_time:39486ms step_avg:96.31ms -step:411/1695 train_time:39580ms step_avg:96.30ms -step:412/1695 train_time:39673ms step_avg:96.29ms -step:413/1695 train_time:39768ms step_avg:96.29ms -step:414/1695 train_time:39863ms step_avg:96.29ms -step:415/1695 train_time:39957ms step_avg:96.28ms -step:416/1695 train_time:40050ms step_avg:96.27ms -step:417/1695 train_time:40144ms step_avg:96.27ms -step:418/1695 train_time:40237ms step_avg:96.26ms -step:419/1695 train_time:40331ms step_avg:96.25ms -step:420/1695 train_time:40425ms step_avg:96.25ms -step:421/1695 train_time:40519ms step_avg:96.25ms -step:422/1695 train_time:40613ms step_avg:96.24ms -step:423/1695 train_time:40707ms step_avg:96.23ms -step:424/1695 train_time:40802ms step_avg:96.23ms -step:425/1695 train_time:40895ms step_avg:96.22ms -step:426/1695 train_time:40990ms step_avg:96.22ms -step:427/1695 train_time:41085ms step_avg:96.22ms -step:428/1695 train_time:41178ms step_avg:96.21ms -step:429/1695 train_time:41272ms step_avg:96.20ms -step:430/1695 train_time:41365ms step_avg:96.20ms -step:431/1695 train_time:41459ms step_avg:96.19ms -step:432/1695 train_time:41553ms step_avg:96.19ms -step:433/1695 train_time:41648ms step_avg:96.18ms -step:434/1695 train_time:41743ms step_avg:96.18ms -step:435/1695 train_time:41837ms step_avg:96.18ms -step:436/1695 train_time:41930ms step_avg:96.17ms -step:437/1695 train_time:42025ms step_avg:96.17ms -step:438/1695 train_time:42119ms step_avg:96.16ms -step:439/1695 
train_time:42212ms step_avg:96.15ms -step:440/1695 train_time:42305ms step_avg:96.15ms -step:441/1695 train_time:42399ms step_avg:96.14ms -step:442/1695 train_time:42492ms step_avg:96.14ms -step:443/1695 train_time:42587ms step_avg:96.13ms -step:444/1695 train_time:42680ms step_avg:96.13ms -step:445/1695 train_time:42774ms step_avg:96.12ms -step:446/1695 train_time:42868ms step_avg:96.12ms -step:447/1695 train_time:42963ms step_avg:96.11ms -step:448/1695 train_time:43057ms step_avg:96.11ms -step:449/1695 train_time:43150ms step_avg:96.10ms -step:450/1695 train_time:43245ms step_avg:96.10ms -step:451/1695 train_time:43338ms step_avg:96.09ms -step:452/1695 train_time:43432ms step_avg:96.09ms -step:453/1695 train_time:43526ms step_avg:96.08ms -step:454/1695 train_time:43621ms step_avg:96.08ms -step:455/1695 train_time:43715ms step_avg:96.08ms -step:456/1695 train_time:43809ms step_avg:96.07ms -step:457/1695 train_time:43903ms step_avg:96.07ms -step:458/1695 train_time:43997ms step_avg:96.06ms -step:459/1695 train_time:44091ms step_avg:96.06ms -step:460/1695 train_time:44185ms step_avg:96.05ms -step:461/1695 train_time:44278ms step_avg:96.05ms -step:462/1695 train_time:44372ms step_avg:96.04ms -step:463/1695 train_time:44466ms step_avg:96.04ms -step:464/1695 train_time:44560ms step_avg:96.03ms -step:465/1695 train_time:44654ms step_avg:96.03ms -step:466/1695 train_time:44748ms step_avg:96.03ms -step:467/1695 train_time:44842ms step_avg:96.02ms -step:468/1695 train_time:44935ms step_avg:96.02ms -step:469/1695 train_time:45031ms step_avg:96.01ms -step:470/1695 train_time:45125ms step_avg:96.01ms -step:471/1695 train_time:45218ms step_avg:96.00ms -step:472/1695 train_time:45312ms step_avg:96.00ms -step:473/1695 train_time:45406ms step_avg:96.00ms -step:474/1695 train_time:45499ms step_avg:95.99ms -step:475/1695 train_time:45594ms step_avg:95.99ms -step:476/1695 train_time:45689ms step_avg:95.99ms -step:477/1695 train_time:45784ms step_avg:95.98ms -step:478/1695 train_time:45877ms step_avg:95.98ms -step:479/1695 train_time:45971ms step_avg:95.97ms -step:480/1695 train_time:46065ms step_avg:95.97ms -step:481/1695 train_time:46159ms step_avg:95.96ms -step:482/1695 train_time:46252ms step_avg:95.96ms -step:483/1695 train_time:46347ms step_avg:95.96ms -step:484/1695 train_time:46442ms step_avg:95.95ms -step:485/1695 train_time:46535ms step_avg:95.95ms -step:486/1695 train_time:46629ms step_avg:95.95ms -step:487/1695 train_time:46723ms step_avg:95.94ms -step:488/1695 train_time:46818ms step_avg:95.94ms -step:489/1695 train_time:46911ms step_avg:95.93ms -step:490/1695 train_time:47004ms step_avg:95.93ms -step:491/1695 train_time:47098ms step_avg:95.92ms -step:492/1695 train_time:47191ms step_avg:95.92ms -step:493/1695 train_time:47286ms step_avg:95.91ms -step:494/1695 train_time:47381ms step_avg:95.91ms -step:495/1695 train_time:47474ms step_avg:95.91ms -step:496/1695 train_time:47568ms step_avg:95.90ms -step:497/1695 train_time:47663ms step_avg:95.90ms -step:498/1695 train_time:47756ms step_avg:95.90ms -step:499/1695 train_time:47851ms step_avg:95.89ms -step:500/1695 train_time:47944ms step_avg:95.89ms -step:500/1695 val_loss:3.7169 train_time:48035ms step_avg:96.07ms -step:501/1695 train_time:48064ms step_avg:95.94ms -step:502/1695 train_time:48136ms step_avg:95.89ms -step:503/1695 train_time:48238ms step_avg:95.90ms -step:504/1695 train_time:48335ms step_avg:95.90ms -step:505/1695 train_time:48428ms step_avg:95.90ms -step:506/1695 train_time:48521ms step_avg:95.89ms -step:507/1695 train_time:48614ms 
step_avg:95.89ms -step:508/1695 train_time:48707ms step_avg:95.88ms -step:509/1695 train_time:48800ms step_avg:95.87ms -step:510/1695 train_time:48893ms step_avg:95.87ms -step:511/1695 train_time:48986ms step_avg:95.86ms -step:512/1695 train_time:49080ms step_avg:95.86ms -step:513/1695 train_time:49177ms step_avg:95.86ms -step:514/1695 train_time:49275ms step_avg:95.87ms -step:515/1695 train_time:49371ms step_avg:95.87ms -step:516/1695 train_time:49465ms step_avg:95.86ms -step:517/1695 train_time:49558ms step_avg:95.86ms -step:518/1695 train_time:49652ms step_avg:95.85ms -step:519/1695 train_time:49984ms step_avg:96.31ms -step:520/1695 train_time:50159ms step_avg:96.46ms -step:521/1695 train_time:50250ms step_avg:96.45ms -step:522/1695 train_time:50343ms step_avg:96.44ms -step:523/1695 train_time:50436ms step_avg:96.44ms -step:524/1695 train_time:50529ms step_avg:96.43ms -step:525/1695 train_time:50621ms step_avg:96.42ms -step:526/1695 train_time:50714ms step_avg:96.41ms -step:527/1695 train_time:50806ms step_avg:96.41ms -step:528/1695 train_time:50898ms step_avg:96.40ms -step:529/1695 train_time:50992ms step_avg:96.39ms -step:530/1695 train_time:51091ms step_avg:96.40ms -step:531/1695 train_time:51188ms step_avg:96.40ms -step:532/1695 train_time:51282ms step_avg:96.40ms -step:533/1695 train_time:51376ms step_avg:96.39ms -step:534/1695 train_time:51470ms step_avg:96.39ms -step:535/1695 train_time:51564ms step_avg:96.38ms -step:536/1695 train_time:51656ms step_avg:96.37ms -step:537/1695 train_time:51750ms step_avg:96.37ms -step:538/1695 train_time:51842ms step_avg:96.36ms -step:539/1695 train_time:51935ms step_avg:96.35ms -step:540/1695 train_time:52032ms step_avg:96.36ms -step:541/1695 train_time:52126ms step_avg:96.35ms -step:542/1695 train_time:52220ms step_avg:96.35ms -step:543/1695 train_time:52316ms step_avg:96.35ms -step:544/1695 train_time:52411ms step_avg:96.34ms -step:545/1695 train_time:52505ms step_avg:96.34ms -step:546/1695 train_time:52598ms step_avg:96.33ms -step:547/1695 train_time:52691ms step_avg:96.33ms -step:548/1695 train_time:52784ms step_avg:96.32ms -step:549/1695 train_time:52877ms step_avg:96.32ms -step:550/1695 train_time:52971ms step_avg:96.31ms -step:551/1695 train_time:53066ms step_avg:96.31ms -step:552/1695 train_time:53161ms step_avg:96.31ms -step:553/1695 train_time:53255ms step_avg:96.30ms -step:554/1695 train_time:53352ms step_avg:96.30ms -step:555/1695 train_time:53446ms step_avg:96.30ms -step:556/1695 train_time:53539ms step_avg:96.29ms -step:557/1695 train_time:53633ms step_avg:96.29ms -step:558/1695 train_time:53727ms step_avg:96.29ms -step:559/1695 train_time:53820ms step_avg:96.28ms -step:560/1695 train_time:53914ms step_avg:96.27ms -step:561/1695 train_time:54009ms step_avg:96.27ms -step:562/1695 train_time:54102ms step_avg:96.27ms -step:563/1695 train_time:54196ms step_avg:96.26ms -step:564/1695 train_time:54291ms step_avg:96.26ms -step:565/1695 train_time:54385ms step_avg:96.26ms -step:566/1695 train_time:54479ms step_avg:96.25ms -step:567/1695 train_time:54573ms step_avg:96.25ms -step:568/1695 train_time:54670ms step_avg:96.25ms -step:569/1695 train_time:54765ms step_avg:96.25ms -step:570/1695 train_time:54859ms step_avg:96.24ms -step:571/1695 train_time:54955ms step_avg:96.24ms -step:572/1695 train_time:55052ms step_avg:96.24ms -step:573/1695 train_time:55149ms step_avg:96.25ms -step:574/1695 train_time:55245ms step_avg:96.25ms -step:575/1695 train_time:55344ms step_avg:96.25ms -step:576/1695 train_time:55437ms step_avg:96.25ms -step:577/1695 
train_time:55534ms step_avg:96.25ms -step:578/1695 train_time:55631ms step_avg:96.25ms -step:579/1695 train_time:55727ms step_avg:96.25ms -step:580/1695 train_time:55823ms step_avg:96.25ms -step:581/1695 train_time:55918ms step_avg:96.24ms -step:582/1695 train_time:56015ms step_avg:96.25ms -step:583/1695 train_time:56111ms step_avg:96.25ms -step:584/1695 train_time:56209ms step_avg:96.25ms -step:585/1695 train_time:56306ms step_avg:96.25ms -step:586/1695 train_time:56402ms step_avg:96.25ms -step:587/1695 train_time:56497ms step_avg:96.25ms -step:588/1695 train_time:56593ms step_avg:96.25ms -step:589/1695 train_time:56689ms step_avg:96.25ms -step:590/1695 train_time:56786ms step_avg:96.25ms -step:591/1695 train_time:56882ms step_avg:96.25ms -step:592/1695 train_time:56978ms step_avg:96.25ms -step:593/1695 train_time:57074ms step_avg:96.25ms -step:594/1695 train_time:57172ms step_avg:96.25ms -step:595/1695 train_time:57269ms step_avg:96.25ms -step:596/1695 train_time:57365ms step_avg:96.25ms -step:597/1695 train_time:57461ms step_avg:96.25ms -step:598/1695 train_time:57555ms step_avg:96.25ms -step:599/1695 train_time:57652ms step_avg:96.25ms -step:600/1695 train_time:57749ms step_avg:96.25ms -step:601/1695 train_time:57846ms step_avg:96.25ms -step:602/1695 train_time:57942ms step_avg:96.25ms -step:603/1695 train_time:58037ms step_avg:96.25ms -step:604/1695 train_time:58133ms step_avg:96.25ms -step:605/1695 train_time:58229ms step_avg:96.25ms -step:606/1695 train_time:58326ms step_avg:96.25ms -step:607/1695 train_time:58421ms step_avg:96.25ms -step:608/1695 train_time:58516ms step_avg:96.24ms -step:609/1695 train_time:58613ms step_avg:96.25ms -step:610/1695 train_time:58710ms step_avg:96.25ms -step:611/1695 train_time:58807ms step_avg:96.25ms -step:612/1695 train_time:58903ms step_avg:96.25ms -step:613/1695 train_time:58998ms step_avg:96.24ms -step:614/1695 train_time:59095ms step_avg:96.25ms -step:615/1695 train_time:59191ms step_avg:96.25ms -step:616/1695 train_time:59288ms step_avg:96.25ms -step:617/1695 train_time:59384ms step_avg:96.25ms -step:618/1695 train_time:59479ms step_avg:96.24ms -step:619/1695 train_time:59577ms step_avg:96.25ms -step:620/1695 train_time:59674ms step_avg:96.25ms -step:621/1695 train_time:59772ms step_avg:96.25ms -step:622/1695 train_time:59868ms step_avg:96.25ms -step:623/1695 train_time:59964ms step_avg:96.25ms -step:624/1695 train_time:60059ms step_avg:96.25ms -step:625/1695 train_time:60155ms step_avg:96.25ms -step:625/1695 val_loss:3.6208 train_time:60249ms step_avg:96.40ms -step:626/1695 train_time:60275ms step_avg:96.29ms -step:627/1695 train_time:60358ms step_avg:96.26ms -step:628/1695 train_time:60456ms step_avg:96.27ms -step:629/1695 train_time:60551ms step_avg:96.27ms -step:630/1695 train_time:60646ms step_avg:96.26ms -step:631/1695 train_time:60741ms step_avg:96.26ms -step:632/1695 train_time:60836ms step_avg:96.26ms -step:633/1695 train_time:60930ms step_avg:96.26ms -step:634/1695 train_time:61025ms step_avg:96.25ms -step:635/1695 train_time:61122ms step_avg:96.25ms -step:636/1695 train_time:61217ms step_avg:96.25ms -step:637/1695 train_time:61315ms step_avg:96.26ms -step:638/1695 train_time:61413ms step_avg:96.26ms -step:639/1695 train_time:61509ms step_avg:96.26ms -step:640/1695 train_time:61606ms step_avg:96.26ms -step:641/1695 train_time:61702ms step_avg:96.26ms -step:642/1695 train_time:61797ms step_avg:96.26ms -step:643/1695 train_time:61892ms step_avg:96.26ms -step:644/1695 train_time:61988ms step_avg:96.25ms -step:645/1695 train_time:62084ms 
step_avg:96.25ms -step:646/1695 train_time:62181ms step_avg:96.26ms -step:647/1695 train_time:62278ms step_avg:96.26ms -step:648/1695 train_time:62376ms step_avg:96.26ms -step:649/1695 train_time:62472ms step_avg:96.26ms -step:650/1695 train_time:62568ms step_avg:96.26ms -step:651/1695 train_time:62664ms step_avg:96.26ms -step:652/1695 train_time:62760ms step_avg:96.26ms -step:653/1695 train_time:62855ms step_avg:96.26ms -step:654/1695 train_time:62950ms step_avg:96.25ms -step:655/1695 train_time:63045ms step_avg:96.25ms -step:656/1695 train_time:63141ms step_avg:96.25ms -step:657/1695 train_time:63238ms step_avg:96.25ms -step:658/1695 train_time:63335ms step_avg:96.25ms -step:659/1695 train_time:63431ms step_avg:96.25ms -step:660/1695 train_time:63529ms step_avg:96.26ms -step:661/1695 train_time:63626ms step_avg:96.26ms -step:662/1695 train_time:63721ms step_avg:96.26ms -step:663/1695 train_time:63816ms step_avg:96.25ms -step:664/1695 train_time:63911ms step_avg:96.25ms -step:665/1695 train_time:64007ms step_avg:96.25ms -step:666/1695 train_time:64103ms step_avg:96.25ms -step:667/1695 train_time:64199ms step_avg:96.25ms -step:668/1695 train_time:64295ms step_avg:96.25ms -step:669/1695 train_time:64390ms step_avg:96.25ms -step:670/1695 train_time:64487ms step_avg:96.25ms -step:671/1695 train_time:64585ms step_avg:96.25ms -step:672/1695 train_time:64682ms step_avg:96.25ms -step:673/1695 train_time:64778ms step_avg:96.25ms -step:674/1695 train_time:64873ms step_avg:96.25ms -step:675/1695 train_time:64969ms step_avg:96.25ms -step:676/1695 train_time:65065ms step_avg:96.25ms -step:677/1695 train_time:65161ms step_avg:96.25ms -step:678/1695 train_time:65257ms step_avg:96.25ms -step:679/1695 train_time:65353ms step_avg:96.25ms -step:680/1695 train_time:65449ms step_avg:96.25ms -step:681/1695 train_time:65545ms step_avg:96.25ms -step:682/1695 train_time:65642ms step_avg:96.25ms -step:683/1695 train_time:65738ms step_avg:96.25ms -step:684/1695 train_time:65833ms step_avg:96.25ms -step:685/1695 train_time:65929ms step_avg:96.25ms -step:686/1695 train_time:66024ms step_avg:96.25ms -step:687/1695 train_time:66120ms step_avg:96.25ms -step:688/1695 train_time:66216ms step_avg:96.24ms -step:689/1695 train_time:66311ms step_avg:96.24ms -step:690/1695 train_time:66408ms step_avg:96.24ms -step:691/1695 train_time:66766ms step_avg:96.62ms -step:692/1695 train_time:66929ms step_avg:96.72ms -step:693/1695 train_time:67024ms step_avg:96.72ms -step:694/1695 train_time:67118ms step_avg:96.71ms -step:695/1695 train_time:67212ms step_avg:96.71ms -step:696/1695 train_time:67307ms step_avg:96.71ms -step:697/1695 train_time:67402ms step_avg:96.70ms -step:698/1695 train_time:67496ms step_avg:96.70ms -step:699/1695 train_time:67590ms step_avg:96.70ms -step:700/1695 train_time:67686ms step_avg:96.69ms -step:701/1695 train_time:67789ms step_avg:96.70ms -step:702/1695 train_time:67889ms step_avg:96.71ms -step:703/1695 train_time:67986ms step_avg:96.71ms -step:704/1695 train_time:68082ms step_avg:96.71ms -step:705/1695 train_time:68176ms step_avg:96.70ms -step:706/1695 train_time:68272ms step_avg:96.70ms -step:707/1695 train_time:68368ms step_avg:96.70ms -step:708/1695 train_time:68463ms step_avg:96.70ms -step:709/1695 train_time:68558ms step_avg:96.70ms -step:710/1695 train_time:68653ms step_avg:96.69ms -step:711/1695 train_time:68750ms step_avg:96.70ms -step:712/1695 train_time:68848ms step_avg:96.70ms -step:713/1695 train_time:68946ms step_avg:96.70ms -step:714/1695 train_time:69043ms step_avg:96.70ms -step:715/1695 
train_time:69140ms step_avg:96.70ms -step:716/1695 train_time:69235ms step_avg:96.70ms -step:717/1695 train_time:69330ms step_avg:96.69ms -step:718/1695 train_time:69426ms step_avg:96.69ms -step:719/1695 train_time:69522ms step_avg:96.69ms -step:720/1695 train_time:69617ms step_avg:96.69ms -step:721/1695 train_time:69713ms step_avg:96.69ms -step:722/1695 train_time:69810ms step_avg:96.69ms -step:723/1695 train_time:69907ms step_avg:96.69ms -step:724/1695 train_time:70003ms step_avg:96.69ms -step:725/1695 train_time:70100ms step_avg:96.69ms -step:726/1695 train_time:70195ms step_avg:96.69ms -step:727/1695 train_time:70291ms step_avg:96.69ms -step:728/1695 train_time:70387ms step_avg:96.69ms -step:729/1695 train_time:70482ms step_avg:96.68ms -step:730/1695 train_time:70578ms step_avg:96.68ms -step:731/1695 train_time:70673ms step_avg:96.68ms -step:732/1695 train_time:70769ms step_avg:96.68ms -step:733/1695 train_time:70865ms step_avg:96.68ms -step:734/1695 train_time:70963ms step_avg:96.68ms -step:735/1695 train_time:71059ms step_avg:96.68ms -step:736/1695 train_time:71155ms step_avg:96.68ms -step:737/1695 train_time:71250ms step_avg:96.68ms -step:738/1695 train_time:71345ms step_avg:96.67ms -step:739/1695 train_time:71441ms step_avg:96.67ms -step:740/1695 train_time:71537ms step_avg:96.67ms -step:741/1695 train_time:71632ms step_avg:96.67ms -step:742/1695 train_time:71728ms step_avg:96.67ms -step:743/1695 train_time:71824ms step_avg:96.67ms -step:744/1695 train_time:71920ms step_avg:96.67ms -step:745/1695 train_time:72016ms step_avg:96.67ms -step:746/1695 train_time:72111ms step_avg:96.66ms -step:747/1695 train_time:72208ms step_avg:96.66ms -step:748/1695 train_time:72303ms step_avg:96.66ms -step:749/1695 train_time:72399ms step_avg:96.66ms -step:750/1695 train_time:72495ms step_avg:96.66ms -step:750/1695 val_loss:3.5658 train_time:72587ms step_avg:96.78ms -step:751/1695 train_time:72614ms step_avg:96.69ms -step:752/1695 train_time:72692ms step_avg:96.67ms -step:753/1695 train_time:72794ms step_avg:96.67ms -step:754/1695 train_time:72890ms step_avg:96.67ms -step:755/1695 train_time:72986ms step_avg:96.67ms -step:756/1695 train_time:73081ms step_avg:96.67ms -step:757/1695 train_time:73176ms step_avg:96.67ms -step:758/1695 train_time:73271ms step_avg:96.66ms -step:759/1695 train_time:73365ms step_avg:96.66ms -step:760/1695 train_time:73460ms step_avg:96.66ms -step:761/1695 train_time:73556ms step_avg:96.66ms -step:762/1695 train_time:73653ms step_avg:96.66ms -step:763/1695 train_time:73751ms step_avg:96.66ms -step:764/1695 train_time:73849ms step_avg:96.66ms -step:765/1695 train_time:73945ms step_avg:96.66ms -step:766/1695 train_time:74040ms step_avg:96.66ms -step:767/1695 train_time:74135ms step_avg:96.66ms -step:768/1695 train_time:74231ms step_avg:96.66ms -step:769/1695 train_time:74327ms step_avg:96.65ms -step:770/1695 train_time:74421ms step_avg:96.65ms -step:771/1695 train_time:74516ms step_avg:96.65ms -step:772/1695 train_time:74614ms step_avg:96.65ms -step:773/1695 train_time:74712ms step_avg:96.65ms -step:774/1695 train_time:74809ms step_avg:96.65ms -step:775/1695 train_time:74908ms step_avg:96.65ms -step:776/1695 train_time:75004ms step_avg:96.65ms -step:777/1695 train_time:75099ms step_avg:96.65ms -step:778/1695 train_time:75194ms step_avg:96.65ms -step:779/1695 train_time:75290ms step_avg:96.65ms -step:780/1695 train_time:75386ms step_avg:96.65ms -step:781/1695 train_time:75482ms step_avg:96.65ms -step:782/1695 train_time:75578ms step_avg:96.65ms -step:783/1695 train_time:75675ms 
step_avg:96.65ms -step:784/1695 train_time:75772ms step_avg:96.65ms -step:785/1695 train_time:75870ms step_avg:96.65ms -step:786/1695 train_time:75967ms step_avg:96.65ms -step:787/1695 train_time:76062ms step_avg:96.65ms -step:788/1695 train_time:76158ms step_avg:96.65ms -step:789/1695 train_time:76253ms step_avg:96.65ms -step:790/1695 train_time:76350ms step_avg:96.65ms -step:791/1695 train_time:76445ms step_avg:96.64ms -step:792/1695 train_time:76541ms step_avg:96.64ms -step:793/1695 train_time:76638ms step_avg:96.64ms -step:794/1695 train_time:76735ms step_avg:96.64ms -step:795/1695 train_time:76832ms step_avg:96.64ms -step:796/1695 train_time:76930ms step_avg:96.65ms -step:797/1695 train_time:77026ms step_avg:96.65ms -step:798/1695 train_time:77121ms step_avg:96.64ms -step:799/1695 train_time:77216ms step_avg:96.64ms -step:800/1695 train_time:77312ms step_avg:96.64ms -step:801/1695 train_time:77411ms step_avg:96.64ms -step:802/1695 train_time:77506ms step_avg:96.64ms -step:803/1695 train_time:77602ms step_avg:96.64ms -step:804/1695 train_time:77697ms step_avg:96.64ms -step:805/1695 train_time:77795ms step_avg:96.64ms -step:806/1695 train_time:77893ms step_avg:96.64ms -step:807/1695 train_time:77990ms step_avg:96.64ms -step:808/1695 train_time:78087ms step_avg:96.64ms -step:809/1695 train_time:78183ms step_avg:96.64ms -step:810/1695 train_time:78281ms step_avg:96.64ms -step:811/1695 train_time:78374ms step_avg:96.64ms -step:812/1695 train_time:78471ms step_avg:96.64ms -step:813/1695 train_time:78568ms step_avg:96.64ms -step:814/1695 train_time:78663ms step_avg:96.64ms -step:815/1695 train_time:78759ms step_avg:96.64ms -step:816/1695 train_time:78855ms step_avg:96.64ms -step:817/1695 train_time:78952ms step_avg:96.64ms -step:818/1695 train_time:79048ms step_avg:96.64ms -step:819/1695 train_time:79144ms step_avg:96.64ms -step:820/1695 train_time:79239ms step_avg:96.63ms -step:821/1695 train_time:79336ms step_avg:96.63ms -step:822/1695 train_time:79431ms step_avg:96.63ms -step:823/1695 train_time:79527ms step_avg:96.63ms -step:824/1695 train_time:79623ms step_avg:96.63ms -step:825/1695 train_time:79719ms step_avg:96.63ms -step:826/1695 train_time:79816ms step_avg:96.63ms -step:827/1695 train_time:79912ms step_avg:96.63ms -step:828/1695 train_time:80008ms step_avg:96.63ms -step:829/1695 train_time:80103ms step_avg:96.63ms -step:830/1695 train_time:80199ms step_avg:96.62ms -step:831/1695 train_time:80295ms step_avg:96.62ms -step:832/1695 train_time:80391ms step_avg:96.62ms -step:833/1695 train_time:80487ms step_avg:96.62ms -step:834/1695 train_time:80584ms step_avg:96.62ms -step:835/1695 train_time:80680ms step_avg:96.62ms -step:836/1695 train_time:80775ms step_avg:96.62ms -step:837/1695 train_time:80872ms step_avg:96.62ms -step:838/1695 train_time:80968ms step_avg:96.62ms -step:839/1695 train_time:81064ms step_avg:96.62ms -step:840/1695 train_time:81160ms step_avg:96.62ms -step:841/1695 train_time:81257ms step_avg:96.62ms -step:842/1695 train_time:81352ms step_avg:96.62ms -step:843/1695 train_time:81450ms step_avg:96.62ms -step:844/1695 train_time:81546ms step_avg:96.62ms -step:845/1695 train_time:81643ms step_avg:96.62ms -step:846/1695 train_time:81738ms step_avg:96.62ms -step:847/1695 train_time:81834ms step_avg:96.62ms -step:848/1695 train_time:81931ms step_avg:96.62ms -step:849/1695 train_time:82026ms step_avg:96.62ms -step:850/1695 train_time:82122ms step_avg:96.61ms -step:851/1695 train_time:82218ms step_avg:96.61ms -step:852/1695 train_time:82313ms step_avg:96.61ms -step:853/1695 
train_time:82409ms step_avg:96.61ms -step:854/1695 train_time:82506ms step_avg:96.61ms -step:855/1695 train_time:82602ms step_avg:96.61ms -step:856/1695 train_time:82698ms step_avg:96.61ms -step:857/1695 train_time:82795ms step_avg:96.61ms -step:858/1695 train_time:82891ms step_avg:96.61ms -step:859/1695 train_time:82987ms step_avg:96.61ms -step:860/1695 train_time:83083ms step_avg:96.61ms -step:861/1695 train_time:83180ms step_avg:96.61ms -step:862/1695 train_time:83275ms step_avg:96.61ms -step:863/1695 train_time:83625ms step_avg:96.90ms -step:864/1695 train_time:83802ms step_avg:96.99ms -step:865/1695 train_time:83896ms step_avg:96.99ms -step:866/1695 train_time:83991ms step_avg:96.99ms -step:867/1695 train_time:84086ms step_avg:96.99ms -step:868/1695 train_time:84180ms step_avg:96.98ms -step:869/1695 train_time:84275ms step_avg:96.98ms -step:870/1695 train_time:84370ms step_avg:96.98ms -step:871/1695 train_time:84465ms step_avg:96.97ms -step:872/1695 train_time:84559ms step_avg:96.97ms -step:873/1695 train_time:84660ms step_avg:96.98ms -step:874/1695 train_time:84758ms step_avg:96.98ms -step:875/1695 train_time:84857ms step_avg:96.98ms -step:875/1695 val_loss:3.5251 train_time:84952ms step_avg:97.09ms -step:876/1695 train_time:84978ms step_avg:97.01ms -step:877/1695 train_time:85058ms step_avg:96.99ms -step:878/1695 train_time:85158ms step_avg:96.99ms -step:879/1695 train_time:85256ms step_avg:96.99ms -step:880/1695 train_time:85352ms step_avg:96.99ms -step:881/1695 train_time:85447ms step_avg:96.99ms -step:882/1695 train_time:85541ms step_avg:96.99ms -step:883/1695 train_time:85636ms step_avg:96.98ms -step:884/1695 train_time:85731ms step_avg:96.98ms -step:885/1695 train_time:85825ms step_avg:96.98ms -step:886/1695 train_time:85920ms step_avg:96.98ms -step:887/1695 train_time:86020ms step_avg:96.98ms -step:888/1695 train_time:86118ms step_avg:96.98ms -step:889/1695 train_time:86217ms step_avg:96.98ms -step:890/1695 train_time:86314ms step_avg:96.98ms -step:891/1695 train_time:86411ms step_avg:96.98ms -step:892/1695 train_time:86505ms step_avg:96.98ms -step:893/1695 train_time:86600ms step_avg:96.98ms -step:894/1695 train_time:86695ms step_avg:96.97ms -step:895/1695 train_time:86790ms step_avg:96.97ms -step:896/1695 train_time:86886ms step_avg:96.97ms -step:897/1695 train_time:86982ms step_avg:96.97ms -step:898/1695 train_time:87080ms step_avg:96.97ms -step:899/1695 train_time:87177ms step_avg:96.97ms -step:900/1695 train_time:87276ms step_avg:96.97ms -step:901/1695 train_time:87374ms step_avg:96.97ms -step:902/1695 train_time:87471ms step_avg:96.97ms -step:903/1695 train_time:87567ms step_avg:96.97ms -step:904/1695 train_time:87661ms step_avg:96.97ms -step:905/1695 train_time:87756ms step_avg:96.97ms -step:906/1695 train_time:87852ms step_avg:96.97ms -step:907/1695 train_time:87948ms step_avg:96.97ms -step:908/1695 train_time:88044ms step_avg:96.96ms -step:909/1695 train_time:88140ms step_avg:96.96ms -step:910/1695 train_time:88237ms step_avg:96.96ms -step:911/1695 train_time:88336ms step_avg:96.97ms -step:912/1695 train_time:88433ms step_avg:96.97ms -step:913/1695 train_time:88528ms step_avg:96.96ms -step:914/1695 train_time:88624ms step_avg:96.96ms -step:915/1695 train_time:88718ms step_avg:96.96ms -step:916/1695 train_time:88816ms step_avg:96.96ms -step:917/1695 train_time:88914ms step_avg:96.96ms -step:918/1695 train_time:89010ms step_avg:96.96ms -step:919/1695 train_time:89107ms step_avg:96.96ms -step:920/1695 train_time:89203ms step_avg:96.96ms -step:921/1695 train_time:89299ms 
step_avg:96.96ms -step:922/1695 train_time:89396ms step_avg:96.96ms -step:923/1695 train_time:89493ms step_avg:96.96ms -step:924/1695 train_time:89590ms step_avg:96.96ms -step:925/1695 train_time:89686ms step_avg:96.96ms -step:926/1695 train_time:89782ms step_avg:96.96ms -step:927/1695 train_time:89878ms step_avg:96.96ms -step:928/1695 train_time:89974ms step_avg:96.95ms -step:929/1695 train_time:90071ms step_avg:96.95ms -step:930/1695 train_time:90167ms step_avg:96.95ms -step:931/1695 train_time:90262ms step_avg:96.95ms -step:932/1695 train_time:90358ms step_avg:96.95ms -step:933/1695 train_time:90455ms step_avg:96.95ms -step:934/1695 train_time:90551ms step_avg:96.95ms -step:935/1695 train_time:90648ms step_avg:96.95ms -step:936/1695 train_time:90743ms step_avg:96.95ms -step:937/1695 train_time:90839ms step_avg:96.95ms -step:938/1695 train_time:90935ms step_avg:96.95ms -step:939/1695 train_time:91030ms step_avg:96.94ms -step:940/1695 train_time:91126ms step_avg:96.94ms -step:941/1695 train_time:91221ms step_avg:96.94ms -step:942/1695 train_time:91317ms step_avg:96.94ms -step:943/1695 train_time:91413ms step_avg:96.94ms -step:944/1695 train_time:91510ms step_avg:96.94ms -step:945/1695 train_time:91607ms step_avg:96.94ms -step:946/1695 train_time:91703ms step_avg:96.94ms -step:947/1695 train_time:91798ms step_avg:96.94ms -step:948/1695 train_time:91894ms step_avg:96.93ms -step:949/1695 train_time:91989ms step_avg:96.93ms -step:950/1695 train_time:92085ms step_avg:96.93ms -step:951/1695 train_time:92181ms step_avg:96.93ms -step:952/1695 train_time:92277ms step_avg:96.93ms -step:953/1695 train_time:92373ms step_avg:96.93ms -step:954/1695 train_time:92469ms step_avg:96.93ms -step:955/1695 train_time:92566ms step_avg:96.93ms -step:956/1695 train_time:92662ms step_avg:96.93ms -step:957/1695 train_time:92758ms step_avg:96.93ms -step:958/1695 train_time:92854ms step_avg:96.92ms -step:959/1695 train_time:92950ms step_avg:96.92ms -step:960/1695 train_time:93046ms step_avg:96.92ms -step:961/1695 train_time:93142ms step_avg:96.92ms -step:962/1695 train_time:93238ms step_avg:96.92ms -step:963/1695 train_time:93334ms step_avg:96.92ms -step:964/1695 train_time:93431ms step_avg:96.92ms -step:965/1695 train_time:93526ms step_avg:96.92ms -step:966/1695 train_time:93621ms step_avg:96.92ms -step:967/1695 train_time:93718ms step_avg:96.92ms -step:968/1695 train_time:93815ms step_avg:96.92ms -step:969/1695 train_time:93911ms step_avg:96.92ms -step:970/1695 train_time:94008ms step_avg:96.92ms -step:971/1695 train_time:94104ms step_avg:96.91ms -step:972/1695 train_time:94200ms step_avg:96.91ms -step:973/1695 train_time:94295ms step_avg:96.91ms -step:974/1695 train_time:94392ms step_avg:96.91ms -step:975/1695 train_time:94489ms step_avg:96.91ms -step:976/1695 train_time:94585ms step_avg:96.91ms -step:977/1695 train_time:94681ms step_avg:96.91ms -step:978/1695 train_time:94778ms step_avg:96.91ms -step:979/1695 train_time:94874ms step_avg:96.91ms -step:980/1695 train_time:94970ms step_avg:96.91ms -step:981/1695 train_time:95066ms step_avg:96.91ms -step:982/1695 train_time:95161ms step_avg:96.91ms -step:983/1695 train_time:95257ms step_avg:96.90ms -step:984/1695 train_time:95353ms step_avg:96.90ms -step:985/1695 train_time:95450ms step_avg:96.90ms -step:986/1695 train_time:95547ms step_avg:96.90ms -step:987/1695 train_time:95644ms step_avg:96.90ms -step:988/1695 train_time:95739ms step_avg:96.90ms -step:989/1695 train_time:95835ms step_avg:96.90ms -step:990/1695 train_time:95932ms step_avg:96.90ms -step:991/1695 
train_time:96028ms step_avg:96.90ms -step:992/1695 train_time:96123ms step_avg:96.90ms -step:993/1695 train_time:96218ms step_avg:96.90ms -step:994/1695 train_time:96314ms step_avg:96.90ms -step:995/1695 train_time:96411ms step_avg:96.90ms -step:996/1695 train_time:96508ms step_avg:96.90ms -step:997/1695 train_time:96605ms step_avg:96.90ms -step:998/1695 train_time:96701ms step_avg:96.89ms -step:999/1695 train_time:96796ms step_avg:96.89ms -step:1000/1695 train_time:96893ms step_avg:96.89ms -step:1000/1695 val_loss:3.4830 train_time:96988ms step_avg:96.99ms -step:1001/1695 train_time:97014ms step_avg:96.92ms -step:1002/1695 train_time:97096ms step_avg:96.90ms -step:1003/1695 train_time:97195ms step_avg:96.90ms -step:1004/1695 train_time:97291ms step_avg:96.90ms -step:1005/1695 train_time:97387ms step_avg:96.90ms -step:1006/1695 train_time:97482ms step_avg:96.90ms -step:1007/1695 train_time:97577ms step_avg:96.90ms -step:1008/1695 train_time:97673ms step_avg:96.90ms -step:1009/1695 train_time:97768ms step_avg:96.90ms -step:1010/1695 train_time:97862ms step_avg:96.89ms -step:1011/1695 train_time:97959ms step_avg:96.89ms -step:1012/1695 train_time:98058ms step_avg:96.90ms -step:1013/1695 train_time:98157ms step_avg:96.90ms -step:1014/1695 train_time:98256ms step_avg:96.90ms -step:1015/1695 train_time:98354ms step_avg:96.90ms -step:1016/1695 train_time:98450ms step_avg:96.90ms -step:1017/1695 train_time:98545ms step_avg:96.90ms -step:1018/1695 train_time:98640ms step_avg:96.90ms -step:1019/1695 train_time:98736ms step_avg:96.89ms -step:1020/1695 train_time:98831ms step_avg:96.89ms -step:1021/1695 train_time:98927ms step_avg:96.89ms -step:1022/1695 train_time:99025ms step_avg:96.89ms -step:1023/1695 train_time:99121ms step_avg:96.89ms -step:1024/1695 train_time:99218ms step_avg:96.89ms -step:1025/1695 train_time:99316ms step_avg:96.89ms -step:1026/1695 train_time:99413ms step_avg:96.89ms -step:1027/1695 train_time:99510ms step_avg:96.89ms -step:1028/1695 train_time:99605ms step_avg:96.89ms -step:1029/1695 train_time:99700ms step_avg:96.89ms -step:1030/1695 train_time:99795ms step_avg:96.89ms -step:1031/1695 train_time:99891ms step_avg:96.89ms -step:1032/1695 train_time:99987ms step_avg:96.89ms -step:1033/1695 train_time:100085ms step_avg:96.89ms -step:1034/1695 train_time:100181ms step_avg:96.89ms -step:1035/1695 train_time:100277ms step_avg:96.89ms -step:1036/1695 train_time:100616ms step_avg:97.12ms -step:1037/1695 train_time:100781ms step_avg:97.19ms -step:1038/1695 train_time:100876ms step_avg:97.18ms -step:1039/1695 train_time:100971ms step_avg:97.18ms -step:1040/1695 train_time:101066ms step_avg:97.18ms -step:1041/1695 train_time:101161ms step_avg:97.18ms -step:1042/1695 train_time:101256ms step_avg:97.17ms -step:1043/1695 train_time:101350ms step_avg:97.17ms -step:1044/1695 train_time:101444ms step_avg:97.17ms -step:1045/1695 train_time:101539ms step_avg:97.17ms -step:1046/1695 train_time:101641ms step_avg:97.17ms -step:1047/1695 train_time:101740ms step_avg:97.17ms -step:1048/1695 train_time:101838ms step_avg:97.17ms -step:1049/1695 train_time:101934ms step_avg:97.17ms -step:1050/1695 train_time:102031ms step_avg:97.17ms -step:1051/1695 train_time:102126ms step_avg:97.17ms -step:1052/1695 train_time:102220ms step_avg:97.17ms -step:1053/1695 train_time:102315ms step_avg:97.17ms -step:1054/1695 train_time:102411ms step_avg:97.16ms -step:1055/1695 train_time:102506ms step_avg:97.16ms -step:1056/1695 train_time:102603ms step_avg:97.16ms -step:1057/1695 train_time:102700ms step_avg:97.16ms 
-step:1058/1695 train_time:102797ms step_avg:97.16ms -step:1059/1695 train_time:102894ms step_avg:97.16ms -step:1060/1695 train_time:102991ms step_avg:97.16ms -step:1061/1695 train_time:103087ms step_avg:97.16ms -step:1062/1695 train_time:103182ms step_avg:97.16ms -step:1063/1695 train_time:103277ms step_avg:97.16ms -step:1064/1695 train_time:103373ms step_avg:97.16ms -step:1065/1695 train_time:103469ms step_avg:97.15ms -step:1066/1695 train_time:103565ms step_avg:97.15ms -step:1067/1695 train_time:103661ms step_avg:97.15ms -step:1068/1695 train_time:103756ms step_avg:97.15ms -step:1069/1695 train_time:103854ms step_avg:97.15ms -step:1070/1695 train_time:103951ms step_avg:97.15ms -step:1071/1695 train_time:104048ms step_avg:97.15ms -step:1072/1695 train_time:104144ms step_avg:97.15ms -step:1073/1695 train_time:104239ms step_avg:97.15ms -step:1074/1695 train_time:104335ms step_avg:97.15ms -step:1075/1695 train_time:104431ms step_avg:97.14ms -step:1076/1695 train_time:104527ms step_avg:97.14ms -step:1077/1695 train_time:104623ms step_avg:97.14ms -step:1078/1695 train_time:104719ms step_avg:97.14ms -step:1079/1695 train_time:104816ms step_avg:97.14ms -step:1080/1695 train_time:104912ms step_avg:97.14ms -step:1081/1695 train_time:105009ms step_avg:97.14ms -step:1082/1695 train_time:105105ms step_avg:97.14ms -step:1083/1695 train_time:105200ms step_avg:97.14ms -step:1084/1695 train_time:105296ms step_avg:97.14ms -step:1085/1695 train_time:105391ms step_avg:97.13ms -step:1086/1695 train_time:105487ms step_avg:97.13ms -step:1087/1695 train_time:105583ms step_avg:97.13ms -step:1088/1695 train_time:105678ms step_avg:97.13ms -step:1089/1695 train_time:105775ms step_avg:97.13ms -step:1090/1695 train_time:105871ms step_avg:97.13ms -step:1091/1695 train_time:105967ms step_avg:97.13ms -step:1092/1695 train_time:106063ms step_avg:97.13ms -step:1093/1695 train_time:106158ms step_avg:97.13ms -step:1094/1695 train_time:106254ms step_avg:97.12ms -step:1095/1695 train_time:106350ms step_avg:97.12ms -step:1096/1695 train_time:106446ms step_avg:97.12ms -step:1097/1695 train_time:106541ms step_avg:97.12ms -step:1098/1695 train_time:106636ms step_avg:97.12ms -step:1099/1695 train_time:106732ms step_avg:97.12ms -step:1100/1695 train_time:106829ms step_avg:97.12ms -step:1101/1695 train_time:106925ms step_avg:97.12ms -step:1102/1695 train_time:107021ms step_avg:97.12ms -step:1103/1695 train_time:107117ms step_avg:97.11ms -step:1104/1695 train_time:107213ms step_avg:97.11ms -step:1105/1695 train_time:107309ms step_avg:97.11ms -step:1106/1695 train_time:107405ms step_avg:97.11ms -step:1107/1695 train_time:107501ms step_avg:97.11ms -step:1108/1695 train_time:107597ms step_avg:97.11ms -step:1109/1695 train_time:107693ms step_avg:97.11ms -step:1110/1695 train_time:107790ms step_avg:97.11ms -step:1111/1695 train_time:107886ms step_avg:97.11ms -step:1112/1695 train_time:107982ms step_avg:97.11ms -step:1113/1695 train_time:108078ms step_avg:97.11ms -step:1114/1695 train_time:108174ms step_avg:97.10ms -step:1115/1695 train_time:108272ms step_avg:97.10ms -step:1116/1695 train_time:108369ms step_avg:97.10ms -step:1117/1695 train_time:108463ms step_avg:97.10ms -step:1118/1695 train_time:108558ms step_avg:97.10ms -step:1119/1695 train_time:108653ms step_avg:97.10ms -step:1120/1695 train_time:108749ms step_avg:97.10ms -step:1121/1695 train_time:108846ms step_avg:97.10ms -step:1122/1695 train_time:108942ms step_avg:97.10ms -step:1123/1695 train_time:109037ms step_avg:97.09ms -step:1124/1695 train_time:109133ms step_avg:97.09ms 
-step:1125/1695 train_time:109229ms step_avg:97.09ms -step:1125/1695 val_loss:3.4364 train_time:109322ms step_avg:97.18ms -step:1126/1695 train_time:109349ms step_avg:97.11ms -step:1127/1695 train_time:109426ms step_avg:97.10ms -step:1128/1695 train_time:109523ms step_avg:97.09ms -step:1129/1695 train_time:109619ms step_avg:97.09ms -step:1130/1695 train_time:109715ms step_avg:97.09ms -step:1131/1695 train_time:109810ms step_avg:97.09ms -step:1132/1695 train_time:109905ms step_avg:97.09ms -step:1133/1695 train_time:110001ms step_avg:97.09ms -step:1134/1695 train_time:110098ms step_avg:97.09ms -step:1135/1695 train_time:110195ms step_avg:97.09ms -step:1136/1695 train_time:110294ms step_avg:97.09ms -step:1137/1695 train_time:110397ms step_avg:97.09ms -step:1138/1695 train_time:110496ms step_avg:97.10ms -step:1139/1695 train_time:110596ms step_avg:97.10ms -step:1140/1695 train_time:110695ms step_avg:97.10ms -step:1141/1695 train_time:110793ms step_avg:97.10ms -step:1142/1695 train_time:110891ms step_avg:97.10ms -step:1143/1695 train_time:110988ms step_avg:97.10ms -step:1144/1695 train_time:111085ms step_avg:97.10ms -step:1145/1695 train_time:111182ms step_avg:97.10ms -step:1146/1695 train_time:111280ms step_avg:97.10ms -step:1147/1695 train_time:111379ms step_avg:97.10ms -step:1148/1695 train_time:111478ms step_avg:97.11ms -step:1149/1695 train_time:111576ms step_avg:97.11ms -step:1150/1695 train_time:111674ms step_avg:97.11ms -step:1151/1695 train_time:111773ms step_avg:97.11ms -step:1152/1695 train_time:111871ms step_avg:97.11ms -step:1153/1695 train_time:111969ms step_avg:97.11ms -step:1154/1695 train_time:112065ms step_avg:97.11ms -step:1155/1695 train_time:112162ms step_avg:97.11ms -step:1156/1695 train_time:112260ms step_avg:97.11ms -step:1157/1695 train_time:112358ms step_avg:97.11ms -step:1158/1695 train_time:112458ms step_avg:97.11ms -step:1159/1695 train_time:112557ms step_avg:97.12ms -step:1160/1695 train_time:112656ms step_avg:97.12ms -step:1161/1695 train_time:112755ms step_avg:97.12ms -step:1162/1695 train_time:112852ms step_avg:97.12ms -step:1163/1695 train_time:112951ms step_avg:97.12ms -step:1164/1695 train_time:113048ms step_avg:97.12ms -step:1165/1695 train_time:113144ms step_avg:97.12ms -step:1166/1695 train_time:113242ms step_avg:97.12ms -step:1167/1695 train_time:113340ms step_avg:97.12ms -step:1168/1695 train_time:113437ms step_avg:97.12ms -step:1169/1695 train_time:113536ms step_avg:97.12ms -step:1170/1695 train_time:113634ms step_avg:97.12ms -step:1171/1695 train_time:113733ms step_avg:97.12ms -step:1172/1695 train_time:113832ms step_avg:97.13ms -step:1173/1695 train_time:113931ms step_avg:97.13ms -step:1174/1695 train_time:114029ms step_avg:97.13ms -step:1175/1695 train_time:114127ms step_avg:97.13ms -step:1176/1695 train_time:114224ms step_avg:97.13ms -step:1177/1695 train_time:114322ms step_avg:97.13ms -step:1178/1695 train_time:114419ms step_avg:97.13ms -step:1179/1695 train_time:114517ms step_avg:97.13ms -step:1180/1695 train_time:114615ms step_avg:97.13ms -step:1181/1695 train_time:114713ms step_avg:97.13ms -step:1182/1695 train_time:114810ms step_avg:97.13ms -step:1183/1695 train_time:114908ms step_avg:97.13ms -step:1184/1695 train_time:115005ms step_avg:97.13ms -step:1185/1695 train_time:115103ms step_avg:97.13ms -step:1186/1695 train_time:115200ms step_avg:97.13ms -step:1187/1695 train_time:115298ms step_avg:97.13ms -step:1188/1695 train_time:115397ms step_avg:97.14ms -step:1189/1695 train_time:115495ms step_avg:97.14ms -step:1190/1695 train_time:115593ms 
step_avg:97.14ms -step:1191/1695 train_time:115691ms step_avg:97.14ms -step:1192/1695 train_time:115789ms step_avg:97.14ms -step:1193/1695 train_time:115887ms step_avg:97.14ms -step:1194/1695 train_time:115984ms step_avg:97.14ms -step:1195/1695 train_time:116082ms step_avg:97.14ms -step:1196/1695 train_time:116180ms step_avg:97.14ms -step:1197/1695 train_time:116279ms step_avg:97.14ms -step:1198/1695 train_time:116377ms step_avg:97.14ms -step:1199/1695 train_time:116476ms step_avg:97.14ms -step:1200/1695 train_time:116574ms step_avg:97.14ms -step:1201/1695 train_time:116672ms step_avg:97.15ms -step:1202/1695 train_time:116770ms step_avg:97.15ms -step:1203/1695 train_time:116868ms step_avg:97.15ms -step:1204/1695 train_time:116965ms step_avg:97.15ms -step:1205/1695 train_time:117062ms step_avg:97.15ms -step:1206/1695 train_time:117160ms step_avg:97.15ms -step:1207/1695 train_time:117259ms step_avg:97.15ms -step:1208/1695 train_time:117635ms step_avg:97.38ms -step:1209/1695 train_time:117773ms step_avg:97.41ms -step:1210/1695 train_time:117868ms step_avg:97.41ms -step:1211/1695 train_time:117964ms step_avg:97.41ms -step:1212/1695 train_time:118061ms step_avg:97.41ms -step:1213/1695 train_time:118157ms step_avg:97.41ms -step:1214/1695 train_time:118253ms step_avg:97.41ms -step:1215/1695 train_time:118350ms step_avg:97.41ms -step:1216/1695 train_time:118446ms step_avg:97.41ms -step:1217/1695 train_time:118542ms step_avg:97.41ms -step:1218/1695 train_time:118646ms step_avg:97.41ms -step:1219/1695 train_time:118750ms step_avg:97.42ms -step:1220/1695 train_time:118849ms step_avg:97.42ms -step:1221/1695 train_time:118947ms step_avg:97.42ms -step:1222/1695 train_time:119044ms step_avg:97.42ms -step:1223/1695 train_time:119140ms step_avg:97.42ms -step:1224/1695 train_time:119238ms step_avg:97.42ms -step:1225/1695 train_time:119335ms step_avg:97.42ms -step:1226/1695 train_time:119432ms step_avg:97.42ms -step:1227/1695 train_time:119529ms step_avg:97.42ms -step:1228/1695 train_time:119626ms step_avg:97.42ms -step:1229/1695 train_time:119726ms step_avg:97.42ms -step:1230/1695 train_time:119824ms step_avg:97.42ms -step:1231/1695 train_time:119923ms step_avg:97.42ms -step:1232/1695 train_time:120021ms step_avg:97.42ms -step:1233/1695 train_time:120119ms step_avg:97.42ms -step:1234/1695 train_time:120216ms step_avg:97.42ms -step:1235/1695 train_time:120313ms step_avg:97.42ms -step:1236/1695 train_time:120410ms step_avg:97.42ms -step:1237/1695 train_time:120506ms step_avg:97.42ms -step:1238/1695 train_time:120604ms step_avg:97.42ms -step:1239/1695 train_time:120703ms step_avg:97.42ms -step:1240/1695 train_time:120801ms step_avg:97.42ms -step:1241/1695 train_time:120900ms step_avg:97.42ms -step:1242/1695 train_time:120998ms step_avg:97.42ms -step:1243/1695 train_time:121096ms step_avg:97.42ms -step:1244/1695 train_time:121194ms step_avg:97.42ms -step:1245/1695 train_time:121292ms step_avg:97.42ms -step:1246/1695 train_time:121388ms step_avg:97.42ms -step:1247/1695 train_time:121485ms step_avg:97.42ms -step:1248/1695 train_time:121583ms step_avg:97.42ms -step:1249/1695 train_time:121682ms step_avg:97.42ms -step:1250/1695 train_time:121780ms step_avg:97.42ms -step:1250/1695 val_loss:3.3889 train_time:121876ms step_avg:97.50ms -step:1251/1695 train_time:121915ms step_avg:97.45ms -step:1252/1695 train_time:121984ms step_avg:97.43ms -step:1253/1695 train_time:122082ms step_avg:97.43ms -step:1254/1695 train_time:122178ms step_avg:97.43ms -step:1255/1695 train_time:122274ms step_avg:97.43ms -step:1256/1695 
train_time:122371ms step_avg:97.43ms -step:1257/1695 train_time:122467ms step_avg:97.43ms -step:1258/1695 train_time:122564ms step_avg:97.43ms -step:1259/1695 train_time:122660ms step_avg:97.43ms -step:1260/1695 train_time:122759ms step_avg:97.43ms -step:1261/1695 train_time:122864ms step_avg:97.43ms -step:1262/1695 train_time:122963ms step_avg:97.43ms -step:1263/1695 train_time:123061ms step_avg:97.44ms -step:1264/1695 train_time:123158ms step_avg:97.43ms -step:1265/1695 train_time:123255ms step_avg:97.43ms -step:1266/1695 train_time:123352ms step_avg:97.43ms -step:1267/1695 train_time:123449ms step_avg:97.43ms -step:1268/1695 train_time:123546ms step_avg:97.43ms -step:1269/1695 train_time:123642ms step_avg:97.43ms -step:1270/1695 train_time:123739ms step_avg:97.43ms -step:1271/1695 train_time:123838ms step_avg:97.43ms -step:1272/1695 train_time:123937ms step_avg:97.43ms -step:1273/1695 train_time:124036ms step_avg:97.44ms -step:1274/1695 train_time:124135ms step_avg:97.44ms -step:1275/1695 train_time:124233ms step_avg:97.44ms -step:1276/1695 train_time:124329ms step_avg:97.44ms -step:1277/1695 train_time:124426ms step_avg:97.44ms -step:1278/1695 train_time:124524ms step_avg:97.44ms -step:1279/1695 train_time:124619ms step_avg:97.44ms -step:1280/1695 train_time:124717ms step_avg:97.43ms -step:1281/1695 train_time:124815ms step_avg:97.44ms -step:1282/1695 train_time:124914ms step_avg:97.44ms -step:1283/1695 train_time:125012ms step_avg:97.44ms -step:1284/1695 train_time:125111ms step_avg:97.44ms -step:1285/1695 train_time:125208ms step_avg:97.44ms -step:1286/1695 train_time:125306ms step_avg:97.44ms -step:1287/1695 train_time:125403ms step_avg:97.44ms -step:1288/1695 train_time:125499ms step_avg:97.44ms -step:1289/1695 train_time:125597ms step_avg:97.44ms -step:1290/1695 train_time:125693ms step_avg:97.44ms -step:1291/1695 train_time:125793ms step_avg:97.44ms -step:1292/1695 train_time:125892ms step_avg:97.44ms -step:1293/1695 train_time:125993ms step_avg:97.44ms -step:1294/1695 train_time:126092ms step_avg:97.44ms -step:1295/1695 train_time:126191ms step_avg:97.44ms -step:1296/1695 train_time:126289ms step_avg:97.45ms -step:1297/1695 train_time:126387ms step_avg:97.45ms -step:1298/1695 train_time:126485ms step_avg:97.45ms -step:1299/1695 train_time:126581ms step_avg:97.45ms -step:1300/1695 train_time:126679ms step_avg:97.45ms -step:1301/1695 train_time:126776ms step_avg:97.44ms -step:1302/1695 train_time:126874ms step_avg:97.45ms -step:1303/1695 train_time:126975ms step_avg:97.45ms -step:1304/1695 train_time:127073ms step_avg:97.45ms -step:1305/1695 train_time:127173ms step_avg:97.45ms -step:1306/1695 train_time:127270ms step_avg:97.45ms -step:1307/1695 train_time:127369ms step_avg:97.45ms -step:1308/1695 train_time:127467ms step_avg:97.45ms -step:1309/1695 train_time:127564ms step_avg:97.45ms -step:1310/1695 train_time:127661ms step_avg:97.45ms -step:1311/1695 train_time:127759ms step_avg:97.45ms -step:1312/1695 train_time:127855ms step_avg:97.45ms -step:1313/1695 train_time:127953ms step_avg:97.45ms -step:1314/1695 train_time:128051ms step_avg:97.45ms -step:1315/1695 train_time:128149ms step_avg:97.45ms -step:1316/1695 train_time:128246ms step_avg:97.45ms -step:1317/1695 train_time:128343ms step_avg:97.45ms -step:1318/1695 train_time:128440ms step_avg:97.45ms -step:1319/1695 train_time:128538ms step_avg:97.45ms -step:1320/1695 train_time:128636ms step_avg:97.45ms -step:1321/1695 train_time:128734ms step_avg:97.45ms -step:1322/1695 train_time:128833ms step_avg:97.45ms -step:1323/1695 
train_time:128932ms step_avg:97.45ms -step:1324/1695 train_time:129029ms step_avg:97.45ms -step:1325/1695 train_time:129127ms step_avg:97.45ms -step:1326/1695 train_time:129225ms step_avg:97.45ms -step:1327/1695 train_time:129322ms step_avg:97.45ms -step:1328/1695 train_time:129419ms step_avg:97.45ms -step:1329/1695 train_time:129517ms step_avg:97.45ms -step:1330/1695 train_time:129615ms step_avg:97.45ms -step:1331/1695 train_time:129713ms step_avg:97.46ms -step:1332/1695 train_time:129811ms step_avg:97.46ms -step:1333/1695 train_time:129910ms step_avg:97.46ms -step:1334/1695 train_time:130007ms step_avg:97.46ms -step:1335/1695 train_time:130105ms step_avg:97.46ms -step:1336/1695 train_time:130204ms step_avg:97.46ms -step:1337/1695 train_time:130300ms step_avg:97.46ms -step:1338/1695 train_time:130398ms step_avg:97.46ms -step:1339/1695 train_time:130495ms step_avg:97.46ms -step:1340/1695 train_time:130593ms step_avg:97.46ms -step:1341/1695 train_time:130692ms step_avg:97.46ms -step:1342/1695 train_time:130790ms step_avg:97.46ms -step:1343/1695 train_time:130888ms step_avg:97.46ms -step:1344/1695 train_time:130986ms step_avg:97.46ms -step:1345/1695 train_time:131084ms step_avg:97.46ms -step:1346/1695 train_time:131182ms step_avg:97.46ms -step:1347/1695 train_time:131280ms step_avg:97.46ms -step:1348/1695 train_time:131377ms step_avg:97.46ms -step:1349/1695 train_time:131475ms step_avg:97.46ms -step:1350/1695 train_time:131572ms step_avg:97.46ms -step:1351/1695 train_time:131672ms step_avg:97.46ms -step:1352/1695 train_time:131770ms step_avg:97.46ms -step:1353/1695 train_time:131869ms step_avg:97.46ms -step:1354/1695 train_time:131966ms step_avg:97.46ms -step:1355/1695 train_time:132064ms step_avg:97.46ms -step:1356/1695 train_time:132161ms step_avg:97.46ms -step:1357/1695 train_time:132258ms step_avg:97.46ms -step:1358/1695 train_time:132356ms step_avg:97.46ms -step:1359/1695 train_time:132454ms step_avg:97.46ms -step:1360/1695 train_time:132552ms step_avg:97.46ms -step:1361/1695 train_time:132650ms step_avg:97.46ms -step:1362/1695 train_time:132748ms step_avg:97.47ms -step:1363/1695 train_time:132845ms step_avg:97.47ms -step:1364/1695 train_time:132942ms step_avg:97.46ms -step:1365/1695 train_time:133038ms step_avg:97.46ms -step:1366/1695 train_time:133136ms step_avg:97.46ms -step:1367/1695 train_time:133235ms step_avg:97.46ms -step:1368/1695 train_time:133333ms step_avg:97.47ms -step:1369/1695 train_time:133430ms step_avg:97.47ms -step:1370/1695 train_time:133529ms step_avg:97.47ms -step:1371/1695 train_time:133630ms step_avg:97.47ms -step:1372/1695 train_time:133725ms step_avg:97.47ms -step:1373/1695 train_time:133823ms step_avg:97.47ms -step:1374/1695 train_time:133921ms step_avg:97.47ms -step:1375/1695 train_time:134023ms step_avg:97.47ms -step:1375/1695 val_loss:3.3505 train_time:134113ms step_avg:97.54ms -step:1376/1695 train_time:134162ms step_avg:97.50ms -step:1377/1695 train_time:134219ms step_avg:97.47ms -step:1378/1695 train_time:134319ms step_avg:97.47ms -step:1379/1695 train_time:134416ms step_avg:97.47ms -step:1380/1695 train_time:134512ms step_avg:97.47ms -step:1381/1695 train_time:134884ms step_avg:97.67ms -step:1382/1695 train_time:135043ms step_avg:97.72ms -step:1383/1695 train_time:135139ms step_avg:97.71ms -step:1384/1695 train_time:135235ms step_avg:97.71ms -step:1385/1695 train_time:135332ms step_avg:97.71ms -step:1386/1695 train_time:135428ms step_avg:97.71ms -step:1387/1695 train_time:135526ms step_avg:97.71ms -step:1388/1695 train_time:135621ms step_avg:97.71ms 
-step:1389/1695 train_time:135716ms step_avg:97.71ms -step:1390/1695 train_time:135813ms step_avg:97.71ms -step:1391/1695 train_time:135916ms step_avg:97.71ms -step:1392/1695 train_time:136020ms step_avg:97.72ms -step:1393/1695 train_time:136121ms step_avg:97.72ms -step:1394/1695 train_time:136218ms step_avg:97.72ms -step:1395/1695 train_time:136315ms step_avg:97.72ms -step:1396/1695 train_time:136412ms step_avg:97.72ms -step:1397/1695 train_time:136509ms step_avg:97.72ms -step:1398/1695 train_time:136607ms step_avg:97.72ms -step:1399/1695 train_time:136704ms step_avg:97.72ms -step:1400/1695 train_time:136801ms step_avg:97.71ms -step:1401/1695 train_time:136899ms step_avg:97.72ms -step:1402/1695 train_time:136997ms step_avg:97.72ms -step:1403/1695 train_time:137096ms step_avg:97.72ms -step:1404/1695 train_time:137194ms step_avg:97.72ms -step:1405/1695 train_time:137293ms step_avg:97.72ms -step:1406/1695 train_time:137390ms step_avg:97.72ms -step:1407/1695 train_time:137487ms step_avg:97.72ms -step:1408/1695 train_time:137584ms step_avg:97.72ms -step:1409/1695 train_time:137681ms step_avg:97.72ms -step:1410/1695 train_time:137778ms step_avg:97.72ms -step:1411/1695 train_time:137876ms step_avg:97.72ms -step:1412/1695 train_time:137975ms step_avg:97.72ms -step:1413/1695 train_time:138075ms step_avg:97.72ms -step:1414/1695 train_time:138173ms step_avg:97.72ms -step:1415/1695 train_time:138272ms step_avg:97.72ms -step:1416/1695 train_time:138368ms step_avg:97.72ms -step:1417/1695 train_time:138466ms step_avg:97.72ms -step:1418/1695 train_time:138563ms step_avg:97.72ms -step:1419/1695 train_time:138661ms step_avg:97.72ms -step:1420/1695 train_time:138757ms step_avg:97.72ms -step:1421/1695 train_time:138854ms step_avg:97.72ms -step:1422/1695 train_time:138953ms step_avg:97.72ms -step:1423/1695 train_time:139053ms step_avg:97.72ms -step:1424/1695 train_time:139153ms step_avg:97.72ms -step:1425/1695 train_time:139252ms step_avg:97.72ms -step:1426/1695 train_time:139350ms step_avg:97.72ms -step:1427/1695 train_time:139448ms step_avg:97.72ms -step:1428/1695 train_time:139546ms step_avg:97.72ms -step:1429/1695 train_time:139644ms step_avg:97.72ms -step:1430/1695 train_time:139742ms step_avg:97.72ms -step:1431/1695 train_time:139839ms step_avg:97.72ms -step:1432/1695 train_time:139938ms step_avg:97.72ms -step:1433/1695 train_time:140036ms step_avg:97.72ms -step:1434/1695 train_time:140134ms step_avg:97.72ms -step:1435/1695 train_time:140232ms step_avg:97.72ms -step:1436/1695 train_time:140330ms step_avg:97.72ms -step:1437/1695 train_time:140428ms step_avg:97.72ms -step:1438/1695 train_time:140526ms step_avg:97.72ms -step:1439/1695 train_time:140623ms step_avg:97.72ms -step:1440/1695 train_time:140720ms step_avg:97.72ms -step:1441/1695 train_time:140818ms step_avg:97.72ms -step:1442/1695 train_time:140917ms step_avg:97.72ms -step:1443/1695 train_time:141014ms step_avg:97.72ms -step:1444/1695 train_time:141113ms step_avg:97.72ms -step:1445/1695 train_time:141213ms step_avg:97.73ms -step:1446/1695 train_time:141312ms step_avg:97.73ms -step:1447/1695 train_time:141411ms step_avg:97.73ms -step:1448/1695 train_time:141508ms step_avg:97.73ms -step:1449/1695 train_time:141607ms step_avg:97.73ms -step:1450/1695 train_time:141706ms step_avg:97.73ms -step:1451/1695 train_time:141804ms step_avg:97.73ms -step:1452/1695 train_time:141902ms step_avg:97.73ms -step:1453/1695 train_time:141999ms step_avg:97.73ms -step:1454/1695 train_time:142097ms step_avg:97.73ms -step:1455/1695 train_time:142194ms step_avg:97.73ms 
-step:1456/1695 train_time:142292ms step_avg:97.73ms -step:1457/1695 train_time:142390ms step_avg:97.73ms -step:1458/1695 train_time:142488ms step_avg:97.73ms -step:1459/1695 train_time:142588ms step_avg:97.73ms -step:1460/1695 train_time:142685ms step_avg:97.73ms -step:1461/1695 train_time:142784ms step_avg:97.73ms -step:1462/1695 train_time:142881ms step_avg:97.73ms -step:1463/1695 train_time:142978ms step_avg:97.73ms -step:1464/1695 train_time:143076ms step_avg:97.73ms -step:1465/1695 train_time:143173ms step_avg:97.73ms -step:1466/1695 train_time:143272ms step_avg:97.73ms -step:1467/1695 train_time:143369ms step_avg:97.73ms -step:1468/1695 train_time:143467ms step_avg:97.73ms -step:1469/1695 train_time:143564ms step_avg:97.73ms -step:1470/1695 train_time:143661ms step_avg:97.73ms -step:1471/1695 train_time:143759ms step_avg:97.73ms -step:1472/1695 train_time:143856ms step_avg:97.73ms -step:1473/1695 train_time:143956ms step_avg:97.73ms -step:1474/1695 train_time:144053ms step_avg:97.73ms -step:1475/1695 train_time:144149ms step_avg:97.73ms -step:1476/1695 train_time:144248ms step_avg:97.73ms -step:1477/1695 train_time:144346ms step_avg:97.73ms -step:1478/1695 train_time:144445ms step_avg:97.73ms -step:1479/1695 train_time:144541ms step_avg:97.73ms -step:1480/1695 train_time:144639ms step_avg:97.73ms -step:1481/1695 train_time:144737ms step_avg:97.73ms -step:1482/1695 train_time:144835ms step_avg:97.73ms -step:1483/1695 train_time:144933ms step_avg:97.73ms -step:1484/1695 train_time:145031ms step_avg:97.73ms -step:1485/1695 train_time:145128ms step_avg:97.73ms -step:1486/1695 train_time:145225ms step_avg:97.73ms -step:1487/1695 train_time:145323ms step_avg:97.73ms -step:1488/1695 train_time:145421ms step_avg:97.73ms -step:1489/1695 train_time:145519ms step_avg:97.73ms -step:1490/1695 train_time:145616ms step_avg:97.73ms -step:1491/1695 train_time:145714ms step_avg:97.73ms -step:1492/1695 train_time:145813ms step_avg:97.73ms -step:1493/1695 train_time:145910ms step_avg:97.73ms -step:1494/1695 train_time:146008ms step_avg:97.73ms -step:1495/1695 train_time:146105ms step_avg:97.73ms -step:1496/1695 train_time:146203ms step_avg:97.73ms -step:1497/1695 train_time:146299ms step_avg:97.73ms -step:1498/1695 train_time:146396ms step_avg:97.73ms -step:1499/1695 train_time:146495ms step_avg:97.73ms -step:1500/1695 train_time:146593ms step_avg:97.73ms -step:1500/1695 val_loss:3.3176 train_time:146690ms step_avg:97.79ms -step:1501/1695 train_time:146740ms step_avg:97.76ms -step:1502/1695 train_time:146800ms step_avg:97.74ms -step:1503/1695 train_time:146898ms step_avg:97.74ms -step:1504/1695 train_time:146995ms step_avg:97.74ms -step:1505/1695 train_time:147094ms step_avg:97.74ms -step:1506/1695 train_time:147190ms step_avg:97.74ms -step:1507/1695 train_time:147287ms step_avg:97.74ms -step:1508/1695 train_time:147383ms step_avg:97.73ms -step:1509/1695 train_time:147481ms step_avg:97.73ms -step:1510/1695 train_time:147577ms step_avg:97.73ms -step:1511/1695 train_time:147677ms step_avg:97.73ms -step:1512/1695 train_time:147778ms step_avg:97.74ms -step:1513/1695 train_time:147878ms step_avg:97.74ms -step:1514/1695 train_time:147975ms step_avg:97.74ms -step:1515/1695 train_time:148073ms step_avg:97.74ms -step:1516/1695 train_time:148170ms step_avg:97.74ms -step:1517/1695 train_time:148269ms step_avg:97.74ms -step:1518/1695 train_time:148367ms step_avg:97.74ms -step:1519/1695 train_time:148463ms step_avg:97.74ms -step:1520/1695 train_time:148560ms step_avg:97.74ms -step:1521/1695 train_time:148658ms 
step_avg:97.74ms -step:1522/1695 train_time:148758ms step_avg:97.74ms -step:1523/1695 train_time:148857ms step_avg:97.74ms -step:1524/1695 train_time:148956ms step_avg:97.74ms -step:1525/1695 train_time:149055ms step_avg:97.74ms -step:1526/1695 train_time:149152ms step_avg:97.74ms -step:1527/1695 train_time:149250ms step_avg:97.74ms -step:1528/1695 train_time:149347ms step_avg:97.74ms -step:1529/1695 train_time:149444ms step_avg:97.74ms -step:1530/1695 train_time:149541ms step_avg:97.74ms -step:1531/1695 train_time:149638ms step_avg:97.74ms -step:1532/1695 train_time:149735ms step_avg:97.74ms -step:1533/1695 train_time:149835ms step_avg:97.74ms -step:1534/1695 train_time:149933ms step_avg:97.74ms -step:1535/1695 train_time:150032ms step_avg:97.74ms -step:1536/1695 train_time:150129ms step_avg:97.74ms -step:1537/1695 train_time:150227ms step_avg:97.74ms -step:1538/1695 train_time:150324ms step_avg:97.74ms -step:1539/1695 train_time:150421ms step_avg:97.74ms -step:1540/1695 train_time:150519ms step_avg:97.74ms -step:1541/1695 train_time:150616ms step_avg:97.74ms -step:1542/1695 train_time:150715ms step_avg:97.74ms -step:1543/1695 train_time:150813ms step_avg:97.74ms -step:1544/1695 train_time:150912ms step_avg:97.74ms -step:1545/1695 train_time:151010ms step_avg:97.74ms -step:1546/1695 train_time:151108ms step_avg:97.74ms -step:1547/1695 train_time:151206ms step_avg:97.74ms -step:1548/1695 train_time:151304ms step_avg:97.74ms -step:1549/1695 train_time:151401ms step_avg:97.74ms -step:1550/1695 train_time:151498ms step_avg:97.74ms -step:1551/1695 train_time:151596ms step_avg:97.74ms -step:1552/1695 train_time:151941ms step_avg:97.90ms -step:1553/1695 train_time:152117ms step_avg:97.95ms -step:1554/1695 train_time:152212ms step_avg:97.95ms -step:1555/1695 train_time:152308ms step_avg:97.95ms -step:1556/1695 train_time:152403ms step_avg:97.95ms -step:1557/1695 train_time:152499ms step_avg:97.94ms -step:1558/1695 train_time:152596ms step_avg:97.94ms -step:1559/1695 train_time:152693ms step_avg:97.94ms -step:1560/1695 train_time:152790ms step_avg:97.94ms -step:1561/1695 train_time:152886ms step_avg:97.94ms -step:1562/1695 train_time:152987ms step_avg:97.94ms -step:1563/1695 train_time:153089ms step_avg:97.95ms -step:1564/1695 train_time:153189ms step_avg:97.95ms -step:1565/1695 train_time:153285ms step_avg:97.95ms -step:1566/1695 train_time:153382ms step_avg:97.95ms -step:1567/1695 train_time:153479ms step_avg:97.94ms -step:1568/1695 train_time:153576ms step_avg:97.94ms -step:1569/1695 train_time:153673ms step_avg:97.94ms -step:1570/1695 train_time:153769ms step_avg:97.94ms -step:1571/1695 train_time:153866ms step_avg:97.94ms -step:1572/1695 train_time:153964ms step_avg:97.94ms -step:1573/1695 train_time:154064ms step_avg:97.94ms -step:1574/1695 train_time:154164ms step_avg:97.94ms -step:1575/1695 train_time:154262ms step_avg:97.94ms -step:1576/1695 train_time:154360ms step_avg:97.94ms -step:1577/1695 train_time:154457ms step_avg:97.94ms -step:1578/1695 train_time:154554ms step_avg:97.94ms -step:1579/1695 train_time:154651ms step_avg:97.94ms -step:1580/1695 train_time:154748ms step_avg:97.94ms -step:1581/1695 train_time:154845ms step_avg:97.94ms -step:1582/1695 train_time:154943ms step_avg:97.94ms -step:1583/1695 train_time:155042ms step_avg:97.94ms -step:1584/1695 train_time:155142ms step_avg:97.94ms -step:1585/1695 train_time:155239ms step_avg:97.94ms -step:1586/1695 train_time:155337ms step_avg:97.94ms -step:1587/1695 train_time:155435ms step_avg:97.94ms -step:1588/1695 train_time:155532ms 
step_avg:97.94ms -step:1589/1695 train_time:155630ms step_avg:97.94ms -step:1590/1695 train_time:155726ms step_avg:97.94ms -step:1591/1695 train_time:155823ms step_avg:97.94ms -step:1592/1695 train_time:155921ms step_avg:97.94ms -step:1593/1695 train_time:156019ms step_avg:97.94ms -step:1594/1695 train_time:156119ms step_avg:97.94ms -step:1595/1695 train_time:156218ms step_avg:97.94ms -step:1596/1695 train_time:156317ms step_avg:97.94ms -step:1597/1695 train_time:156416ms step_avg:97.94ms -step:1598/1695 train_time:156513ms step_avg:97.94ms -step:1599/1695 train_time:156610ms step_avg:97.94ms -step:1600/1695 train_time:156707ms step_avg:97.94ms -step:1601/1695 train_time:156804ms step_avg:97.94ms -step:1602/1695 train_time:156901ms step_avg:97.94ms -step:1603/1695 train_time:156999ms step_avg:97.94ms -step:1604/1695 train_time:157097ms step_avg:97.94ms -step:1605/1695 train_time:157196ms step_avg:97.94ms -step:1606/1695 train_time:157295ms step_avg:97.94ms -step:1607/1695 train_time:157393ms step_avg:97.94ms -step:1608/1695 train_time:157491ms step_avg:97.94ms -step:1609/1695 train_time:157590ms step_avg:97.94ms -step:1610/1695 train_time:157687ms step_avg:97.94ms -step:1611/1695 train_time:157784ms step_avg:97.94ms -step:1612/1695 train_time:157881ms step_avg:97.94ms -step:1613/1695 train_time:157978ms step_avg:97.94ms -step:1614/1695 train_time:158077ms step_avg:97.94ms -step:1615/1695 train_time:158176ms step_avg:97.94ms -step:1616/1695 train_time:158274ms step_avg:97.94ms -step:1617/1695 train_time:158373ms step_avg:97.94ms -step:1618/1695 train_time:158471ms step_avg:97.94ms -step:1619/1695 train_time:158569ms step_avg:97.94ms -step:1620/1695 train_time:158666ms step_avg:97.94ms -step:1621/1695 train_time:158765ms step_avg:97.94ms -step:1622/1695 train_time:158861ms step_avg:97.94ms -step:1623/1695 train_time:158958ms step_avg:97.94ms -step:1624/1695 train_time:159057ms step_avg:97.94ms -step:1625/1695 train_time:159155ms step_avg:97.94ms -step:1625/1695 val_loss:3.2898 train_time:159251ms step_avg:98.00ms -step:1626/1695 train_time:159278ms step_avg:97.96ms -step:1627/1695 train_time:159359ms step_avg:97.95ms -step:1628/1695 train_time:159459ms step_avg:97.95ms -step:1629/1695 train_time:159560ms step_avg:97.95ms -step:1630/1695 train_time:159658ms step_avg:97.95ms -step:1631/1695 train_time:159755ms step_avg:97.95ms -step:1632/1695 train_time:159852ms step_avg:97.95ms -step:1633/1695 train_time:159950ms step_avg:97.95ms -step:1634/1695 train_time:160046ms step_avg:97.95ms -step:1635/1695 train_time:160142ms step_avg:97.95ms -step:1636/1695 train_time:160242ms step_avg:97.95ms -step:1637/1695 train_time:160342ms step_avg:97.95ms -step:1638/1695 train_time:160441ms step_avg:97.95ms -step:1639/1695 train_time:160539ms step_avg:97.95ms -step:1640/1695 train_time:160639ms step_avg:97.95ms -step:1641/1695 train_time:160737ms step_avg:97.95ms -step:1642/1695 train_time:160835ms step_avg:97.95ms -step:1643/1695 train_time:160933ms step_avg:97.95ms -step:1644/1695 train_time:161031ms step_avg:97.95ms -step:1645/1695 train_time:161128ms step_avg:97.95ms -step:1646/1695 train_time:161225ms step_avg:97.95ms -step:1647/1695 train_time:161323ms step_avg:97.95ms -step:1648/1695 train_time:161421ms step_avg:97.95ms -step:1649/1695 train_time:161519ms step_avg:97.95ms -step:1650/1695 train_time:161618ms step_avg:97.95ms -step:1651/1695 train_time:161717ms step_avg:97.95ms -step:1652/1695 train_time:161816ms step_avg:97.95ms -step:1653/1695 train_time:161914ms step_avg:97.95ms -step:1654/1695 
train_time:162010ms step_avg:97.95ms -step:1655/1695 train_time:162107ms step_avg:97.95ms -step:1656/1695 train_time:162205ms step_avg:97.95ms -step:1657/1695 train_time:162302ms step_avg:97.95ms -step:1658/1695 train_time:162399ms step_avg:97.95ms -step:1659/1695 train_time:162496ms step_avg:97.95ms -step:1660/1695 train_time:162595ms step_avg:97.95ms -step:1661/1695 train_time:162693ms step_avg:97.95ms -step:1662/1695 train_time:162791ms step_avg:97.95ms -step:1663/1695 train_time:162888ms step_avg:97.95ms -step:1664/1695 train_time:162985ms step_avg:97.95ms -step:1665/1695 train_time:163083ms step_avg:97.95ms -step:1666/1695 train_time:163181ms step_avg:97.95ms -step:1667/1695 train_time:163280ms step_avg:97.95ms -step:1668/1695 train_time:163379ms step_avg:97.95ms -step:1669/1695 train_time:163476ms step_avg:97.95ms -step:1670/1695 train_time:163574ms step_avg:97.95ms -step:1671/1695 train_time:163672ms step_avg:97.95ms -step:1672/1695 train_time:163769ms step_avg:97.95ms -step:1673/1695 train_time:163867ms step_avg:97.95ms -step:1674/1695 train_time:163964ms step_avg:97.95ms -step:1675/1695 train_time:164062ms step_avg:97.95ms -step:1676/1695 train_time:164160ms step_avg:97.95ms -step:1677/1695 train_time:164258ms step_avg:97.95ms -step:1678/1695 train_time:164359ms step_avg:97.95ms -step:1679/1695 train_time:164455ms step_avg:97.95ms -step:1680/1695 train_time:164553ms step_avg:97.95ms -step:1681/1695 train_time:164651ms step_avg:97.95ms -step:1682/1695 train_time:164750ms step_avg:97.95ms -step:1683/1695 train_time:164849ms step_avg:97.95ms -step:1684/1695 train_time:164947ms step_avg:97.95ms -step:1685/1695 train_time:165044ms step_avg:97.95ms -step:1686/1695 train_time:165141ms step_avg:97.95ms -step:1687/1695 train_time:165239ms step_avg:97.95ms -step:1688/1695 train_time:165336ms step_avg:97.95ms -step:1689/1695 train_time:165435ms step_avg:97.95ms -step:1690/1695 train_time:165533ms step_avg:97.95ms -step:1691/1695 train_time:165631ms step_avg:97.95ms -step:1692/1695 train_time:165728ms step_avg:97.95ms -step:1693/1695 train_time:165825ms step_avg:97.95ms -step:1694/1695 train_time:165922ms step_avg:97.95ms -step:1695/1695 train_time:166021ms step_avg:97.95ms -step:1695/1695 val_loss:3.2782 train_time:166117ms step_avg:98.00ms -peak memory allocated: 34001 MiB reserved: 49716 MiB diff --git a/records/090325_FA3/44fc1276-0510-4961-92c0-730c65e5feba.txt b/records/090325_FA3/44fc1276-0510-4961-92c0-730c65e5feba.txt new file mode 100644 index 000000000..dbcf68147 --- /dev/null +++ b/records/090325_FA3/44fc1276-0510-4961-92c0-730c65e5feba.txt @@ -0,0 +1,2814 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# 
----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + 
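# (launch-grid note: the callers launch a 1-D grid of batch * num_pid_m * num_pid_n
+    # programs, so floor-dividing pid by the per-matrix tile count recovers the
+    # batch index; the remainder is then re-split into a 2-D tile coordinate and
+    # swizzled for better L2 reuse)
+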
batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + 
configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, 
meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
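# editor's note (not original): moment buffers are allocated on the rank-local slice only, +
# sharding the Adam state so each rank stores ~1/world_size of exp_avg / exp_avg_sq +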
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. 
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 19:58:00 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 30C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 29C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 32C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 29C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 31C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 28C P0 114W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 51945 C /usr/bin/python 0MiB | +| 0 N/A N/A 51946 C /usr/bin/python 0MiB | +| 0 N/A N/A 51947 C /usr/bin/python 0MiB | +| 0 N/A N/A 51948 C /usr/bin/python 0MiB | +| 0 N/A N/A 51949 C /usr/bin/python 0MiB | +| 0 N/A N/A 51950 C /usr/bin/python 0MiB | +| 0 N/A N/A 51951 C /usr/bin/python 0MiB | +| 0 N/A N/A 51952 C /usr/bin/python 0MiB | +| 1 N/A N/A 51946 C /usr/bin/python 0MiB | +| 2 N/A N/A 51947 C /usr/bin/python 0MiB | +| 3 N/A N/A 51948 C /usr/bin/python 0MiB | +| 4 N/A N/A 51949 C /usr/bin/python 0MiB | +| 5 N/A N/A 51950 C /usr/bin/python 0MiB | +| 6 N/A N/A 51951 C /usr/bin/python 0MiB | +| 7 N/A N/A 51952 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:473ms step_avg:472.96ms +step:2/1670 train_time:494ms step_avg:246.81ms +step:3/1670 train_time:568ms step_avg:189.19ms +step:4/1670 train_time:661ms step_avg:165.29ms +step:5/1670 train_time:755ms step_avg:151.05ms +step:6/1670 train_time:849ms step_avg:141.57ms +step:7/1670 train_time:944ms step_avg:134.90ms +step:8/1670 train_time:1039ms step_avg:129.87ms +step:9/1670 train_time:1134ms 
step_avg:125.95ms +step:10/1670 train_time:1228ms step_avg:122.84ms +step:11/1670 train_time:1323ms step_avg:120.30ms +step:12/1670 train_time:1421ms step_avg:118.43ms +step:13/1670 train_time:1521ms step_avg:117.02ms +step:14/1670 train_time:1618ms step_avg:115.58ms +step:15/1670 train_time:1713ms step_avg:114.21ms +step:16/1670 train_time:1808ms step_avg:113.03ms +step:17/1670 train_time:1903ms step_avg:111.96ms +step:18/1670 train_time:1998ms step_avg:111.01ms +step:19/1670 train_time:2093ms step_avg:110.16ms +step:20/1670 train_time:2188ms step_avg:109.41ms +step:21/1670 train_time:2284ms step_avg:108.74ms +step:22/1670 train_time:2381ms step_avg:108.21ms +step:23/1670 train_time:2477ms step_avg:107.71ms +step:24/1670 train_time:2574ms step_avg:107.26ms +step:25/1670 train_time:2671ms step_avg:106.83ms +step:26/1670 train_time:2766ms step_avg:106.40ms +step:27/1670 train_time:2862ms step_avg:106.01ms +step:28/1670 train_time:2957ms step_avg:105.61ms +step:29/1670 train_time:3052ms step_avg:105.24ms +step:30/1670 train_time:3146ms step_avg:104.88ms +step:31/1670 train_time:3241ms step_avg:104.56ms +step:32/1670 train_time:3337ms step_avg:104.28ms +step:33/1670 train_time:3433ms step_avg:104.04ms +step:34/1670 train_time:3530ms step_avg:103.84ms +step:35/1670 train_time:3628ms step_avg:103.66ms +step:36/1670 train_time:3724ms step_avg:103.45ms +step:37/1670 train_time:3819ms step_avg:103.23ms +step:38/1670 train_time:3914ms step_avg:103.01ms +step:39/1670 train_time:4010ms step_avg:102.81ms +step:40/1670 train_time:4105ms step_avg:102.63ms +step:41/1670 train_time:4200ms step_avg:102.45ms +step:42/1670 train_time:4296ms step_avg:102.28ms +step:43/1670 train_time:4391ms step_avg:102.11ms +step:44/1670 train_time:4487ms step_avg:101.98ms +step:45/1670 train_time:4583ms step_avg:101.84ms +step:46/1670 train_time:4679ms step_avg:101.72ms +step:47/1670 train_time:4775ms step_avg:101.59ms +step:48/1670 train_time:4871ms step_avg:101.48ms +step:49/1670 train_time:4967ms step_avg:101.36ms +step:50/1670 train_time:5062ms step_avg:101.24ms +step:51/1670 train_time:5157ms step_avg:101.12ms +step:52/1670 train_time:5252ms step_avg:101.00ms +step:53/1670 train_time:5347ms step_avg:100.89ms +step:54/1670 train_time:5443ms step_avg:100.80ms +step:55/1670 train_time:5539ms step_avg:100.71ms +step:56/1670 train_time:5635ms step_avg:100.63ms +step:57/1670 train_time:5731ms step_avg:100.54ms +step:58/1670 train_time:5827ms step_avg:100.47ms +step:59/1670 train_time:5923ms step_avg:100.38ms +step:60/1670 train_time:6018ms step_avg:100.31ms +step:61/1670 train_time:6114ms step_avg:100.22ms +step:62/1670 train_time:6209ms step_avg:100.14ms +step:63/1670 train_time:6305ms step_avg:100.07ms +step:64/1670 train_time:6399ms step_avg:99.99ms +step:65/1670 train_time:6494ms step_avg:99.91ms +step:66/1670 train_time:6590ms step_avg:99.85ms +step:67/1670 train_time:6685ms step_avg:99.78ms +step:68/1670 train_time:6781ms step_avg:99.72ms +step:69/1670 train_time:6876ms step_avg:99.66ms +step:70/1670 train_time:6972ms step_avg:99.60ms +step:71/1670 train_time:7067ms step_avg:99.54ms +step:72/1670 train_time:7163ms step_avg:99.48ms +step:73/1670 train_time:7258ms step_avg:99.43ms +step:74/1670 train_time:7354ms step_avg:99.37ms +step:75/1670 train_time:7449ms step_avg:99.32ms +step:76/1670 train_time:7546ms step_avg:99.29ms +step:77/1670 train_time:7642ms step_avg:99.25ms +step:78/1670 train_time:7737ms step_avg:99.19ms +step:79/1670 train_time:7832ms step_avg:99.15ms +step:80/1670 train_time:7928ms step_avg:99.10ms 
+step:81/1670 train_time:8025ms step_avg:99.08ms +step:82/1670 train_time:8120ms step_avg:99.02ms +step:83/1670 train_time:8215ms step_avg:98.98ms +step:84/1670 train_time:8311ms step_avg:98.94ms +step:85/1670 train_time:8406ms step_avg:98.90ms +step:86/1670 train_time:8502ms step_avg:98.86ms +step:87/1670 train_time:8598ms step_avg:98.82ms +step:88/1670 train_time:8693ms step_avg:98.79ms +step:89/1670 train_time:8789ms step_avg:98.75ms +step:90/1670 train_time:8885ms step_avg:98.72ms +step:91/1670 train_time:8980ms step_avg:98.68ms +step:92/1670 train_time:9075ms step_avg:98.65ms +step:93/1670 train_time:9171ms step_avg:98.62ms +step:94/1670 train_time:9267ms step_avg:98.59ms +step:95/1670 train_time:9362ms step_avg:98.55ms +step:96/1670 train_time:9458ms step_avg:98.52ms +step:97/1670 train_time:9553ms step_avg:98.48ms +step:98/1670 train_time:9648ms step_avg:98.45ms +step:99/1670 train_time:9745ms step_avg:98.43ms +step:100/1670 train_time:9840ms step_avg:98.40ms +step:101/1670 train_time:9936ms step_avg:98.37ms +step:102/1670 train_time:10031ms step_avg:98.34ms +step:103/1670 train_time:10127ms step_avg:98.32ms +step:104/1670 train_time:10222ms step_avg:98.29ms +step:105/1670 train_time:10317ms step_avg:98.26ms +step:106/1670 train_time:10413ms step_avg:98.23ms +step:107/1670 train_time:10508ms step_avg:98.21ms +step:108/1670 train_time:10605ms step_avg:98.19ms +step:109/1670 train_time:10701ms step_avg:98.17ms +step:110/1670 train_time:10797ms step_avg:98.15ms +step:111/1670 train_time:10893ms step_avg:98.13ms +step:112/1670 train_time:10989ms step_avg:98.11ms +step:113/1670 train_time:11085ms step_avg:98.09ms +step:114/1670 train_time:11180ms step_avg:98.07ms +step:115/1670 train_time:11275ms step_avg:98.04ms +step:116/1670 train_time:11370ms step_avg:98.02ms +step:117/1670 train_time:11466ms step_avg:98.00ms +step:118/1670 train_time:11562ms step_avg:97.98ms +step:119/1670 train_time:11658ms step_avg:97.96ms +step:120/1670 train_time:11754ms step_avg:97.95ms +step:121/1670 train_time:11850ms step_avg:97.93ms +step:122/1670 train_time:11946ms step_avg:97.92ms +step:123/1670 train_time:12042ms step_avg:97.90ms +step:124/1670 train_time:12137ms step_avg:97.88ms +step:125/1670 train_time:12233ms step_avg:97.86ms +step:125/1670 val_loss:4.2975 train_time:12327ms step_avg:98.62ms +step:126/1670 train_time:12350ms step_avg:98.02ms +step:127/1670 train_time:12434ms step_avg:97.90ms +step:128/1670 train_time:12538ms step_avg:97.95ms +step:129/1670 train_time:12635ms step_avg:97.95ms +step:130/1670 train_time:12730ms step_avg:97.93ms +step:131/1670 train_time:12825ms step_avg:97.90ms +step:132/1670 train_time:12919ms step_avg:97.87ms +step:133/1670 train_time:13014ms step_avg:97.85ms +step:134/1670 train_time:13109ms step_avg:97.83ms +step:135/1670 train_time:13204ms step_avg:97.81ms +step:136/1670 train_time:13298ms step_avg:97.78ms +step:137/1670 train_time:13394ms step_avg:97.77ms +step:138/1670 train_time:13493ms step_avg:97.77ms +step:139/1670 train_time:13591ms step_avg:97.78ms +step:140/1670 train_time:13687ms step_avg:97.76ms +step:141/1670 train_time:13783ms step_avg:97.75ms +step:142/1670 train_time:13878ms step_avg:97.73ms +step:143/1670 train_time:13973ms step_avg:97.71ms +step:144/1670 train_time:14068ms step_avg:97.70ms +step:145/1670 train_time:14163ms step_avg:97.68ms +step:146/1670 train_time:14258ms step_avg:97.66ms +step:147/1670 train_time:14353ms step_avg:97.64ms +step:148/1670 train_time:14451ms step_avg:97.64ms +step:149/1670 train_time:14548ms step_avg:97.64ms 
+step:150/1670 train_time:14644ms step_avg:97.62ms +step:151/1670 train_time:14739ms step_avg:97.61ms +step:152/1670 train_time:14835ms step_avg:97.60ms +step:153/1670 train_time:14930ms step_avg:97.58ms +step:154/1670 train_time:15025ms step_avg:97.56ms +step:155/1670 train_time:15120ms step_avg:97.55ms +step:156/1670 train_time:15215ms step_avg:97.53ms +step:157/1670 train_time:15311ms step_avg:97.52ms +step:158/1670 train_time:15407ms step_avg:97.51ms +step:159/1670 train_time:15502ms step_avg:97.50ms +step:160/1670 train_time:15598ms step_avg:97.49ms +step:161/1670 train_time:15695ms step_avg:97.48ms +step:162/1670 train_time:15791ms step_avg:97.47ms +step:163/1670 train_time:15887ms step_avg:97.46ms +step:164/1670 train_time:15983ms step_avg:97.46ms +step:165/1670 train_time:16077ms step_avg:97.44ms +step:166/1670 train_time:16172ms step_avg:97.42ms +step:167/1670 train_time:16268ms step_avg:97.41ms +step:168/1670 train_time:16363ms step_avg:97.40ms +step:169/1670 train_time:16459ms step_avg:97.39ms +step:170/1670 train_time:16554ms step_avg:97.38ms +step:171/1670 train_time:16650ms step_avg:97.37ms +step:172/1670 train_time:16746ms step_avg:97.36ms +step:173/1670 train_time:16842ms step_avg:97.35ms +step:174/1670 train_time:16937ms step_avg:97.34ms +step:175/1670 train_time:17033ms step_avg:97.33ms +step:176/1670 train_time:17128ms step_avg:97.32ms +step:177/1670 train_time:17224ms step_avg:97.31ms +step:178/1670 train_time:17320ms step_avg:97.30ms +step:179/1670 train_time:17414ms step_avg:97.29ms +step:180/1670 train_time:17511ms step_avg:97.29ms +step:181/1670 train_time:17608ms step_avg:97.28ms +step:182/1670 train_time:17704ms step_avg:97.27ms +step:183/1670 train_time:17799ms step_avg:97.26ms +step:184/1670 train_time:17895ms step_avg:97.26ms +step:185/1670 train_time:17991ms step_avg:97.25ms +step:186/1670 train_time:18087ms step_avg:97.24ms +step:187/1670 train_time:18182ms step_avg:97.23ms +step:188/1670 train_time:18277ms step_avg:97.22ms +step:189/1670 train_time:18372ms step_avg:97.21ms +step:190/1670 train_time:18468ms step_avg:97.20ms +step:191/1670 train_time:18564ms step_avg:97.19ms +step:192/1670 train_time:18660ms step_avg:97.18ms +step:193/1670 train_time:18755ms step_avg:97.18ms +step:194/1670 train_time:18851ms step_avg:97.17ms +step:195/1670 train_time:18947ms step_avg:97.16ms +step:196/1670 train_time:19042ms step_avg:97.15ms +step:197/1670 train_time:19137ms step_avg:97.14ms +step:198/1670 train_time:19232ms step_avg:97.13ms +step:199/1670 train_time:19327ms step_avg:97.12ms +step:200/1670 train_time:19423ms step_avg:97.11ms +step:201/1670 train_time:19518ms step_avg:97.10ms +step:202/1670 train_time:19613ms step_avg:97.10ms +step:203/1670 train_time:19709ms step_avg:97.09ms +step:204/1670 train_time:19805ms step_avg:97.08ms +step:205/1670 train_time:19900ms step_avg:97.07ms +step:206/1670 train_time:19995ms step_avg:97.06ms +step:207/1670 train_time:20091ms step_avg:97.06ms +step:208/1670 train_time:20186ms step_avg:97.05ms +step:209/1670 train_time:20282ms step_avg:97.04ms +step:210/1670 train_time:20377ms step_avg:97.03ms +step:211/1670 train_time:20472ms step_avg:97.03ms +step:212/1670 train_time:20568ms step_avg:97.02ms +step:213/1670 train_time:20870ms step_avg:97.98ms +step:214/1670 train_time:20943ms step_avg:97.86ms +step:215/1670 train_time:21036ms step_avg:97.84ms +step:216/1670 train_time:21131ms step_avg:97.83ms +step:217/1670 train_time:21226ms step_avg:97.81ms +step:218/1670 train_time:21321ms step_avg:97.80ms +step:219/1670 train_time:21415ms 
step_avg:97.79ms +step:220/1670 train_time:21510ms step_avg:97.77ms +step:221/1670 train_time:21605ms step_avg:97.76ms +step:222/1670 train_time:21699ms step_avg:97.74ms +step:223/1670 train_time:21798ms step_avg:97.75ms +step:224/1670 train_time:21896ms step_avg:97.75ms +step:225/1670 train_time:21992ms step_avg:97.74ms +step:226/1670 train_time:22087ms step_avg:97.73ms +step:227/1670 train_time:22182ms step_avg:97.72ms +step:228/1670 train_time:22276ms step_avg:97.70ms +step:229/1670 train_time:22371ms step_avg:97.69ms +step:230/1670 train_time:22466ms step_avg:97.68ms +step:231/1670 train_time:22560ms step_avg:97.66ms +step:232/1670 train_time:22654ms step_avg:97.65ms +step:233/1670 train_time:22751ms step_avg:97.64ms +step:234/1670 train_time:22849ms step_avg:97.65ms +step:235/1670 train_time:22946ms step_avg:97.64ms +step:236/1670 train_time:23042ms step_avg:97.63ms +step:237/1670 train_time:23137ms step_avg:97.63ms +step:238/1670 train_time:23233ms step_avg:97.62ms +step:239/1670 train_time:23328ms step_avg:97.61ms +step:240/1670 train_time:23423ms step_avg:97.60ms +step:241/1670 train_time:23518ms step_avg:97.58ms +step:242/1670 train_time:23613ms step_avg:97.57ms +step:243/1670 train_time:23709ms step_avg:97.57ms +step:244/1670 train_time:23806ms step_avg:97.56ms +step:245/1670 train_time:23902ms step_avg:97.56ms +step:246/1670 train_time:23998ms step_avg:97.55ms +step:247/1670 train_time:24094ms step_avg:97.55ms +step:248/1670 train_time:24190ms step_avg:97.54ms +step:249/1670 train_time:24286ms step_avg:97.53ms +step:250/1670 train_time:24381ms step_avg:97.53ms +step:250/1670 val_loss:3.9606 train_time:24475ms step_avg:97.90ms +step:251/1670 train_time:24495ms step_avg:97.59ms +step:252/1670 train_time:24577ms step_avg:97.53ms +step:253/1670 train_time:24676ms step_avg:97.53ms +step:254/1670 train_time:24771ms step_avg:97.52ms +step:255/1670 train_time:24866ms step_avg:97.51ms +step:256/1670 train_time:24961ms step_avg:97.50ms +step:257/1670 train_time:25055ms step_avg:97.49ms +step:258/1670 train_time:25150ms step_avg:97.48ms +step:259/1670 train_time:25244ms step_avg:97.47ms +step:260/1670 train_time:25339ms step_avg:97.46ms +step:261/1670 train_time:25434ms step_avg:97.45ms +step:262/1670 train_time:25530ms step_avg:97.44ms +step:263/1670 train_time:25628ms step_avg:97.44ms +step:264/1670 train_time:25726ms step_avg:97.45ms +step:265/1670 train_time:25822ms step_avg:97.44ms +step:266/1670 train_time:25917ms step_avg:97.43ms +step:267/1670 train_time:26012ms step_avg:97.42ms +step:268/1670 train_time:26106ms step_avg:97.41ms +step:269/1670 train_time:26201ms step_avg:97.40ms +step:270/1670 train_time:26295ms step_avg:97.39ms +step:271/1670 train_time:26390ms step_avg:97.38ms +step:272/1670 train_time:26486ms step_avg:97.37ms +step:273/1670 train_time:26583ms step_avg:97.37ms +step:274/1670 train_time:26680ms step_avg:97.37ms +step:275/1670 train_time:26775ms step_avg:97.37ms +step:276/1670 train_time:26871ms step_avg:97.36ms +step:277/1670 train_time:26966ms step_avg:97.35ms +step:278/1670 train_time:27061ms step_avg:97.34ms +step:279/1670 train_time:27156ms step_avg:97.33ms +step:280/1670 train_time:27252ms step_avg:97.33ms +step:281/1670 train_time:27347ms step_avg:97.32ms +step:282/1670 train_time:27443ms step_avg:97.31ms +step:283/1670 train_time:27539ms step_avg:97.31ms +step:284/1670 train_time:27634ms step_avg:97.30ms +step:285/1670 train_time:27730ms step_avg:97.30ms +step:286/1670 train_time:27826ms step_avg:97.29ms +step:287/1670 train_time:27922ms step_avg:97.29ms 
+step:288/1670 train_time:28017ms step_avg:97.28ms +step:289/1670 train_time:28112ms step_avg:97.27ms +step:290/1670 train_time:28206ms step_avg:97.26ms +step:291/1670 train_time:28302ms step_avg:97.26ms +step:292/1670 train_time:28397ms step_avg:97.25ms +step:293/1670 train_time:28492ms step_avg:97.24ms +step:294/1670 train_time:28588ms step_avg:97.24ms +step:295/1670 train_time:28686ms step_avg:97.24ms +step:296/1670 train_time:28782ms step_avg:97.24ms +step:297/1670 train_time:28878ms step_avg:97.23ms +step:298/1670 train_time:28973ms step_avg:97.22ms +step:299/1670 train_time:29068ms step_avg:97.22ms +step:300/1670 train_time:29163ms step_avg:97.21ms +step:301/1670 train_time:29258ms step_avg:97.20ms +step:302/1670 train_time:29352ms step_avg:97.19ms +step:303/1670 train_time:29448ms step_avg:97.19ms +step:304/1670 train_time:29544ms step_avg:97.18ms +step:305/1670 train_time:29640ms step_avg:97.18ms +step:306/1670 train_time:29736ms step_avg:97.18ms +step:307/1670 train_time:29833ms step_avg:97.17ms +step:308/1670 train_time:29928ms step_avg:97.17ms +step:309/1670 train_time:30024ms step_avg:97.16ms +step:310/1670 train_time:30119ms step_avg:97.16ms +step:311/1670 train_time:30214ms step_avg:97.15ms +step:312/1670 train_time:30309ms step_avg:97.14ms +step:313/1670 train_time:30404ms step_avg:97.14ms +step:314/1670 train_time:30499ms step_avg:97.13ms +step:315/1670 train_time:30595ms step_avg:97.13ms +step:316/1670 train_time:30690ms step_avg:97.12ms +step:317/1670 train_time:30787ms step_avg:97.12ms +step:318/1670 train_time:30884ms step_avg:97.12ms +step:319/1670 train_time:30979ms step_avg:97.11ms +step:320/1670 train_time:31074ms step_avg:97.11ms +step:321/1670 train_time:31170ms step_avg:97.10ms +step:322/1670 train_time:31265ms step_avg:97.10ms +step:323/1670 train_time:31361ms step_avg:97.09ms +step:324/1670 train_time:31456ms step_avg:97.09ms +step:325/1670 train_time:31551ms step_avg:97.08ms +step:326/1670 train_time:31647ms step_avg:97.08ms +step:327/1670 train_time:31743ms step_avg:97.07ms +step:328/1670 train_time:31839ms step_avg:97.07ms +step:329/1670 train_time:31934ms step_avg:97.06ms +step:330/1670 train_time:32030ms step_avg:97.06ms +step:331/1670 train_time:32125ms step_avg:97.05ms +step:332/1670 train_time:32221ms step_avg:97.05ms +step:333/1670 train_time:32317ms step_avg:97.05ms +step:334/1670 train_time:32412ms step_avg:97.04ms +step:335/1670 train_time:32507ms step_avg:97.04ms +step:336/1670 train_time:32603ms step_avg:97.03ms +step:337/1670 train_time:32698ms step_avg:97.03ms +step:338/1670 train_time:32793ms step_avg:97.02ms +step:339/1670 train_time:32888ms step_avg:97.01ms +step:340/1670 train_time:32984ms step_avg:97.01ms +step:341/1670 train_time:33080ms step_avg:97.01ms +step:342/1670 train_time:33176ms step_avg:97.01ms +step:343/1670 train_time:33271ms step_avg:97.00ms +step:344/1670 train_time:33367ms step_avg:97.00ms +step:345/1670 train_time:33463ms step_avg:96.99ms +step:346/1670 train_time:33558ms step_avg:96.99ms +step:347/1670 train_time:33654ms step_avg:96.98ms +step:348/1670 train_time:33749ms step_avg:96.98ms +step:349/1670 train_time:33845ms step_avg:96.98ms +step:350/1670 train_time:33940ms step_avg:96.97ms +step:351/1670 train_time:34036ms step_avg:96.97ms +step:352/1670 train_time:34131ms step_avg:96.96ms +step:353/1670 train_time:34228ms step_avg:96.96ms +step:354/1670 train_time:34323ms step_avg:96.96ms +step:355/1670 train_time:34419ms step_avg:96.95ms +step:356/1670 train_time:34514ms step_avg:96.95ms +step:357/1670 train_time:34609ms 
step_avg:96.94ms +step:358/1670 train_time:34705ms step_avg:96.94ms +step:359/1670 train_time:34801ms step_avg:96.94ms +step:360/1670 train_time:34896ms step_avg:96.93ms +step:361/1670 train_time:34991ms step_avg:96.93ms +step:362/1670 train_time:35086ms step_avg:96.92ms +step:363/1670 train_time:35182ms step_avg:96.92ms +step:364/1670 train_time:35278ms step_avg:96.92ms +step:365/1670 train_time:35372ms step_avg:96.91ms +step:366/1670 train_time:35468ms step_avg:96.91ms +step:367/1670 train_time:35563ms step_avg:96.90ms +step:368/1670 train_time:35659ms step_avg:96.90ms +step:369/1670 train_time:35755ms step_avg:96.90ms +step:370/1670 train_time:35850ms step_avg:96.89ms +step:371/1670 train_time:35946ms step_avg:96.89ms +step:372/1670 train_time:36042ms step_avg:96.89ms +step:373/1670 train_time:36137ms step_avg:96.88ms +step:374/1670 train_time:36232ms step_avg:96.88ms +step:375/1670 train_time:36327ms step_avg:96.87ms +step:375/1670 val_loss:3.8096 train_time:36422ms step_avg:97.13ms +step:376/1670 train_time:36443ms step_avg:96.92ms +step:377/1670 train_time:36526ms step_avg:96.89ms +step:378/1670 train_time:36623ms step_avg:96.89ms +step:379/1670 train_time:36718ms step_avg:96.88ms +step:380/1670 train_time:36813ms step_avg:96.88ms +step:381/1670 train_time:36908ms step_avg:96.87ms +step:382/1670 train_time:37003ms step_avg:96.87ms +step:383/1670 train_time:37097ms step_avg:96.86ms +step:384/1670 train_time:37194ms step_avg:96.86ms +step:385/1670 train_time:37288ms step_avg:96.85ms +step:386/1670 train_time:37383ms step_avg:96.85ms +step:387/1670 train_time:37481ms step_avg:96.85ms +step:388/1670 train_time:37579ms step_avg:96.85ms +step:389/1670 train_time:37675ms step_avg:96.85ms +step:390/1670 train_time:37771ms step_avg:96.85ms +step:391/1670 train_time:37866ms step_avg:96.84ms +step:392/1670 train_time:37961ms step_avg:96.84ms +step:393/1670 train_time:38056ms step_avg:96.83ms +step:394/1670 train_time:38152ms step_avg:96.83ms +step:395/1670 train_time:38248ms step_avg:96.83ms +step:396/1670 train_time:38343ms step_avg:96.83ms +step:397/1670 train_time:38439ms step_avg:96.82ms +step:398/1670 train_time:38537ms step_avg:96.83ms +step:399/1670 train_time:38633ms step_avg:96.82ms +step:400/1670 train_time:38728ms step_avg:96.82ms +step:401/1670 train_time:38823ms step_avg:96.82ms +step:402/1670 train_time:38918ms step_avg:96.81ms +step:403/1670 train_time:39014ms step_avg:96.81ms +step:404/1670 train_time:39110ms step_avg:96.81ms +step:405/1670 train_time:39206ms step_avg:96.80ms +step:406/1670 train_time:39301ms step_avg:96.80ms +step:407/1670 train_time:39397ms step_avg:96.80ms +step:408/1670 train_time:39495ms step_avg:96.80ms +step:409/1670 train_time:39592ms step_avg:96.80ms +step:410/1670 train_time:39688ms step_avg:96.80ms +step:411/1670 train_time:39783ms step_avg:96.79ms +step:412/1670 train_time:39878ms step_avg:96.79ms +step:413/1670 train_time:39973ms step_avg:96.79ms +step:414/1670 train_time:40068ms step_avg:96.78ms +step:415/1670 train_time:40163ms step_avg:96.78ms +step:416/1670 train_time:40258ms step_avg:96.77ms +step:417/1670 train_time:40354ms step_avg:96.77ms +step:418/1670 train_time:40450ms step_avg:96.77ms +step:419/1670 train_time:40546ms step_avg:96.77ms +step:420/1670 train_time:40642ms step_avg:96.77ms +step:421/1670 train_time:40738ms step_avg:96.76ms +step:422/1670 train_time:40834ms step_avg:96.76ms +step:423/1670 train_time:40929ms step_avg:96.76ms +step:424/1670 train_time:41024ms step_avg:96.76ms +step:425/1670 train_time:41351ms step_avg:97.30ms 
+step:426/1670 train_time:41425ms step_avg:97.24ms +step:427/1670 train_time:41519ms step_avg:97.23ms +step:428/1670 train_time:41613ms step_avg:97.23ms +step:429/1670 train_time:41708ms step_avg:97.22ms +step:430/1670 train_time:41803ms step_avg:97.22ms +step:431/1670 train_time:41897ms step_avg:97.21ms +step:432/1670 train_time:41992ms step_avg:97.20ms +step:433/1670 train_time:42086ms step_avg:97.20ms +step:434/1670 train_time:42180ms step_avg:97.19ms +step:435/1670 train_time:42281ms step_avg:97.20ms +step:436/1670 train_time:42381ms step_avg:97.20ms +step:437/1670 train_time:42479ms step_avg:97.21ms +step:438/1670 train_time:42575ms step_avg:97.20ms +step:439/1670 train_time:42671ms step_avg:97.20ms +step:440/1670 train_time:42765ms step_avg:97.19ms +step:441/1670 train_time:42859ms step_avg:97.19ms +step:442/1670 train_time:42954ms step_avg:97.18ms +step:443/1670 train_time:43048ms step_avg:97.17ms +step:444/1670 train_time:43143ms step_avg:97.17ms +step:445/1670 train_time:43239ms step_avg:97.17ms +step:446/1670 train_time:43337ms step_avg:97.17ms +step:447/1670 train_time:43435ms step_avg:97.17ms +step:448/1670 train_time:43532ms step_avg:97.17ms +step:449/1670 train_time:43628ms step_avg:97.17ms +step:450/1670 train_time:43722ms step_avg:97.16ms +step:451/1670 train_time:43817ms step_avg:97.15ms +step:452/1670 train_time:43912ms step_avg:97.15ms +step:453/1670 train_time:44006ms step_avg:97.14ms +step:454/1670 train_time:44101ms step_avg:97.14ms +step:455/1670 train_time:44197ms step_avg:97.14ms +step:456/1670 train_time:44293ms step_avg:97.13ms +step:457/1670 train_time:44389ms step_avg:97.13ms +step:458/1670 train_time:44486ms step_avg:97.13ms +step:459/1670 train_time:44583ms step_avg:97.13ms +step:460/1670 train_time:44678ms step_avg:97.13ms +step:461/1670 train_time:44773ms step_avg:97.12ms +step:462/1670 train_time:44868ms step_avg:97.12ms +step:463/1670 train_time:44962ms step_avg:97.11ms +step:464/1670 train_time:45057ms step_avg:97.11ms +step:465/1670 train_time:45153ms step_avg:97.10ms +step:466/1670 train_time:45249ms step_avg:97.10ms +step:467/1670 train_time:45345ms step_avg:97.10ms +step:468/1670 train_time:45440ms step_avg:97.09ms +step:469/1670 train_time:45537ms step_avg:97.09ms +step:470/1670 train_time:45634ms step_avg:97.09ms +step:471/1670 train_time:45729ms step_avg:97.09ms +step:472/1670 train_time:45824ms step_avg:97.08ms +step:473/1670 train_time:45919ms step_avg:97.08ms +step:474/1670 train_time:46014ms step_avg:97.08ms +step:475/1670 train_time:46110ms step_avg:97.07ms +step:476/1670 train_time:46205ms step_avg:97.07ms +step:477/1670 train_time:46300ms step_avg:97.07ms +step:478/1670 train_time:46395ms step_avg:97.06ms +step:479/1670 train_time:46491ms step_avg:97.06ms +step:480/1670 train_time:46588ms step_avg:97.06ms +step:481/1670 train_time:46683ms step_avg:97.05ms +step:482/1670 train_time:46779ms step_avg:97.05ms +step:483/1670 train_time:46875ms step_avg:97.05ms +step:484/1670 train_time:46970ms step_avg:97.05ms +step:485/1670 train_time:47065ms step_avg:97.04ms +step:486/1670 train_time:47160ms step_avg:97.04ms +step:487/1670 train_time:47256ms step_avg:97.04ms +step:488/1670 train_time:47353ms step_avg:97.03ms +step:489/1670 train_time:47449ms step_avg:97.03ms +step:490/1670 train_time:47544ms step_avg:97.03ms +step:491/1670 train_time:47639ms step_avg:97.02ms +step:492/1670 train_time:47735ms step_avg:97.02ms +step:493/1670 train_time:47830ms step_avg:97.02ms +step:494/1670 train_time:47925ms step_avg:97.01ms +step:495/1670 train_time:48020ms 
step_avg:97.01ms +step:496/1670 train_time:48116ms step_avg:97.01ms +step:497/1670 train_time:48212ms step_avg:97.01ms +step:498/1670 train_time:48307ms step_avg:97.00ms +step:499/1670 train_time:48402ms step_avg:97.00ms +step:500/1670 train_time:48498ms step_avg:97.00ms +step:500/1670 val_loss:3.7107 train_time:48594ms step_avg:97.19ms +step:501/1670 train_time:48615ms step_avg:97.04ms +step:502/1670 train_time:48697ms step_avg:97.01ms +step:503/1670 train_time:48795ms step_avg:97.01ms +step:504/1670 train_time:48891ms step_avg:97.01ms +step:505/1670 train_time:48985ms step_avg:97.00ms +step:506/1670 train_time:49080ms step_avg:97.00ms +step:507/1670 train_time:49174ms step_avg:96.99ms +step:508/1670 train_time:49269ms step_avg:96.99ms +step:509/1670 train_time:49364ms step_avg:96.98ms +step:510/1670 train_time:49459ms step_avg:96.98ms +step:511/1670 train_time:49553ms step_avg:96.97ms +step:512/1670 train_time:49651ms step_avg:96.97ms +step:513/1670 train_time:49749ms step_avg:96.98ms +step:514/1670 train_time:49848ms step_avg:96.98ms +step:515/1670 train_time:49944ms step_avg:96.98ms +step:516/1670 train_time:50040ms step_avg:96.98ms +step:517/1670 train_time:50134ms step_avg:96.97ms +step:518/1670 train_time:50229ms step_avg:96.97ms +step:519/1670 train_time:50324ms step_avg:96.96ms +step:520/1670 train_time:50419ms step_avg:96.96ms +step:521/1670 train_time:50514ms step_avg:96.96ms +step:522/1670 train_time:50610ms step_avg:96.95ms +step:523/1670 train_time:50707ms step_avg:96.95ms +step:524/1670 train_time:50805ms step_avg:96.96ms +step:525/1670 train_time:50901ms step_avg:96.95ms +step:526/1670 train_time:50996ms step_avg:96.95ms +step:527/1670 train_time:51091ms step_avg:96.95ms +step:528/1670 train_time:51187ms step_avg:96.94ms +step:529/1670 train_time:51281ms step_avg:96.94ms +step:530/1670 train_time:51376ms step_avg:96.94ms +step:531/1670 train_time:51471ms step_avg:96.93ms +step:532/1670 train_time:51567ms step_avg:96.93ms +step:533/1670 train_time:51663ms step_avg:96.93ms +step:534/1670 train_time:51760ms step_avg:96.93ms +step:535/1670 train_time:51856ms step_avg:96.93ms +step:536/1670 train_time:51951ms step_avg:96.92ms +step:537/1670 train_time:52047ms step_avg:96.92ms +step:538/1670 train_time:52142ms step_avg:96.92ms +step:539/1670 train_time:52237ms step_avg:96.92ms +step:540/1670 train_time:52332ms step_avg:96.91ms +step:541/1670 train_time:52428ms step_avg:96.91ms +step:542/1670 train_time:52523ms step_avg:96.91ms +step:543/1670 train_time:52619ms step_avg:96.90ms +step:544/1670 train_time:52714ms step_avg:96.90ms +step:545/1670 train_time:52810ms step_avg:96.90ms +step:546/1670 train_time:52906ms step_avg:96.90ms +step:547/1670 train_time:53003ms step_avg:96.90ms +step:548/1670 train_time:53099ms step_avg:96.90ms +step:549/1670 train_time:53194ms step_avg:96.89ms +step:550/1670 train_time:53289ms step_avg:96.89ms +step:551/1670 train_time:53384ms step_avg:96.89ms +step:552/1670 train_time:53480ms step_avg:96.88ms +step:553/1670 train_time:53575ms step_avg:96.88ms +step:554/1670 train_time:53670ms step_avg:96.88ms +step:555/1670 train_time:53766ms step_avg:96.88ms +step:556/1670 train_time:53862ms step_avg:96.87ms +step:557/1670 train_time:53958ms step_avg:96.87ms +step:558/1670 train_time:54054ms step_avg:96.87ms +step:559/1670 train_time:54150ms step_avg:96.87ms +step:560/1670 train_time:54247ms step_avg:96.87ms +step:561/1670 train_time:54344ms step_avg:96.87ms +step:562/1670 train_time:54441ms step_avg:96.87ms +step:563/1670 train_time:54537ms step_avg:96.87ms 
+step:564/1670 train_time:54634ms step_avg:96.87ms +step:565/1670 train_time:54730ms step_avg:96.87ms +step:566/1670 train_time:54827ms step_avg:96.87ms +step:567/1670 train_time:54925ms step_avg:96.87ms +step:568/1670 train_time:55023ms step_avg:96.87ms +step:569/1670 train_time:55119ms step_avg:96.87ms +step:570/1670 train_time:55216ms step_avg:96.87ms +step:571/1670 train_time:55312ms step_avg:96.87ms +step:572/1670 train_time:55409ms step_avg:96.87ms +step:573/1670 train_time:55508ms step_avg:96.87ms +step:574/1670 train_time:55606ms step_avg:96.87ms +step:575/1670 train_time:55703ms step_avg:96.88ms +step:576/1670 train_time:55800ms step_avg:96.87ms +step:577/1670 train_time:55897ms step_avg:96.87ms +step:578/1670 train_time:55993ms step_avg:96.87ms +step:579/1670 train_time:56090ms step_avg:96.87ms +step:580/1670 train_time:56188ms step_avg:96.88ms +step:581/1670 train_time:56286ms step_avg:96.88ms +step:582/1670 train_time:56383ms step_avg:96.88ms +step:583/1670 train_time:56480ms step_avg:96.88ms +step:584/1670 train_time:56576ms step_avg:96.88ms +step:585/1670 train_time:56673ms step_avg:96.88ms +step:586/1670 train_time:56770ms step_avg:96.88ms +step:587/1670 train_time:56867ms step_avg:96.88ms +step:588/1670 train_time:56965ms step_avg:96.88ms +step:589/1670 train_time:57061ms step_avg:96.88ms +step:590/1670 train_time:57158ms step_avg:96.88ms +step:591/1670 train_time:57254ms step_avg:96.88ms +step:592/1670 train_time:57350ms step_avg:96.87ms +step:593/1670 train_time:57447ms step_avg:96.88ms +step:594/1670 train_time:57546ms step_avg:96.88ms +step:595/1670 train_time:57645ms step_avg:96.88ms +step:596/1670 train_time:57741ms step_avg:96.88ms +step:597/1670 train_time:57838ms step_avg:96.88ms +step:598/1670 train_time:57934ms step_avg:96.88ms +step:599/1670 train_time:58030ms step_avg:96.88ms +step:600/1670 train_time:58129ms step_avg:96.88ms +step:601/1670 train_time:58227ms step_avg:96.88ms +step:602/1670 train_time:58324ms step_avg:96.88ms +step:603/1670 train_time:58420ms step_avg:96.88ms +step:604/1670 train_time:58517ms step_avg:96.88ms +step:605/1670 train_time:58614ms step_avg:96.88ms +step:606/1670 train_time:58711ms step_avg:96.88ms +step:607/1670 train_time:58809ms step_avg:96.88ms +step:608/1670 train_time:58907ms step_avg:96.89ms +step:609/1670 train_time:59004ms step_avg:96.89ms +step:610/1670 train_time:59100ms step_avg:96.89ms +step:611/1670 train_time:59197ms step_avg:96.89ms +step:612/1670 train_time:59293ms step_avg:96.88ms +step:613/1670 train_time:59390ms step_avg:96.88ms +step:614/1670 train_time:59487ms step_avg:96.88ms +step:615/1670 train_time:59585ms step_avg:96.89ms +step:616/1670 train_time:59681ms step_avg:96.89ms +step:617/1670 train_time:59778ms step_avg:96.89ms +step:618/1670 train_time:59875ms step_avg:96.89ms +step:619/1670 train_time:59972ms step_avg:96.89ms +step:620/1670 train_time:60069ms step_avg:96.89ms +step:621/1670 train_time:60166ms step_avg:96.89ms +step:622/1670 train_time:60263ms step_avg:96.89ms +step:623/1670 train_time:60360ms step_avg:96.89ms +step:624/1670 train_time:60457ms step_avg:96.89ms +step:625/1670 train_time:60554ms step_avg:96.89ms +step:625/1670 val_loss:3.6104 train_time:60650ms step_avg:97.04ms +step:626/1670 train_time:60671ms step_avg:96.92ms +step:627/1670 train_time:60759ms step_avg:96.90ms +step:628/1670 train_time:60858ms step_avg:96.91ms +step:629/1670 train_time:60955ms step_avg:96.91ms +step:630/1670 train_time:61050ms step_avg:96.91ms +step:631/1670 train_time:61146ms step_avg:96.90ms +step:632/1670 
train_time:61242ms step_avg:96.90ms +step:633/1670 train_time:61338ms step_avg:96.90ms +step:634/1670 train_time:61433ms step_avg:96.90ms +step:635/1670 train_time:61529ms step_avg:96.90ms +step:636/1670 train_time:61628ms step_avg:96.90ms +step:637/1670 train_time:61730ms step_avg:96.91ms +step:638/1670 train_time:61830ms step_avg:96.91ms +step:639/1670 train_time:62201ms step_avg:97.34ms +step:640/1670 train_time:62304ms step_avg:97.35ms +step:641/1670 train_time:62399ms step_avg:97.35ms +step:642/1670 train_time:62495ms step_avg:97.34ms +step:643/1670 train_time:62590ms step_avg:97.34ms +step:644/1670 train_time:62687ms step_avg:97.34ms +step:645/1670 train_time:62783ms step_avg:97.34ms +step:646/1670 train_time:62878ms step_avg:97.33ms +step:647/1670 train_time:62973ms step_avg:97.33ms +step:648/1670 train_time:63069ms step_avg:97.33ms +step:649/1670 train_time:63166ms step_avg:97.33ms +step:650/1670 train_time:63268ms step_avg:97.34ms +step:651/1670 train_time:63368ms step_avg:97.34ms +step:652/1670 train_time:63466ms step_avg:97.34ms +step:653/1670 train_time:63563ms step_avg:97.34ms +step:654/1670 train_time:63659ms step_avg:97.34ms +step:655/1670 train_time:63754ms step_avg:97.33ms +step:656/1670 train_time:63850ms step_avg:97.33ms +step:657/1670 train_time:63948ms step_avg:97.33ms +step:658/1670 train_time:64045ms step_avg:97.33ms +step:659/1670 train_time:64141ms step_avg:97.33ms +step:660/1670 train_time:64240ms step_avg:97.33ms +step:661/1670 train_time:64338ms step_avg:97.34ms +step:662/1670 train_time:64436ms step_avg:97.34ms +step:663/1670 train_time:64533ms step_avg:97.33ms +step:664/1670 train_time:64630ms step_avg:97.33ms +step:665/1670 train_time:64726ms step_avg:97.33ms +step:666/1670 train_time:64822ms step_avg:97.33ms +step:667/1670 train_time:64918ms step_avg:97.33ms +step:668/1670 train_time:65014ms step_avg:97.33ms +step:669/1670 train_time:65110ms step_avg:97.33ms +step:670/1670 train_time:65207ms step_avg:97.32ms +step:671/1670 train_time:65306ms step_avg:97.33ms +step:672/1670 train_time:65404ms step_avg:97.33ms +step:673/1670 train_time:65503ms step_avg:97.33ms +step:674/1670 train_time:65600ms step_avg:97.33ms +step:675/1670 train_time:65698ms step_avg:97.33ms +step:676/1670 train_time:65793ms step_avg:97.33ms +step:677/1670 train_time:65889ms step_avg:97.33ms +step:678/1670 train_time:65986ms step_avg:97.32ms +step:679/1670 train_time:66082ms step_avg:97.32ms +step:680/1670 train_time:66178ms step_avg:97.32ms +step:681/1670 train_time:66275ms step_avg:97.32ms +step:682/1670 train_time:66373ms step_avg:97.32ms +step:683/1670 train_time:66470ms step_avg:97.32ms +step:684/1670 train_time:66568ms step_avg:97.32ms +step:685/1670 train_time:66666ms step_avg:97.32ms +step:686/1670 train_time:66764ms step_avg:97.32ms +step:687/1670 train_time:66861ms step_avg:97.32ms +step:688/1670 train_time:66956ms step_avg:97.32ms +step:689/1670 train_time:67052ms step_avg:97.32ms +step:690/1670 train_time:67149ms step_avg:97.32ms +step:691/1670 train_time:67246ms step_avg:97.32ms +step:692/1670 train_time:67344ms step_avg:97.32ms +step:693/1670 train_time:67440ms step_avg:97.32ms +step:694/1670 train_time:67538ms step_avg:97.32ms +step:695/1670 train_time:67634ms step_avg:97.32ms +step:696/1670 train_time:67731ms step_avg:97.31ms +step:697/1670 train_time:67828ms step_avg:97.31ms +step:698/1670 train_time:67926ms step_avg:97.31ms +step:699/1670 train_time:68022ms step_avg:97.31ms +step:700/1670 train_time:68119ms step_avg:97.31ms +step:701/1670 train_time:68215ms step_avg:97.31ms 
+step:702/1670 train_time:68312ms step_avg:97.31ms +step:703/1670 train_time:68408ms step_avg:97.31ms +step:704/1670 train_time:68506ms step_avg:97.31ms +step:705/1670 train_time:68603ms step_avg:97.31ms +step:706/1670 train_time:68700ms step_avg:97.31ms +step:707/1670 train_time:68798ms step_avg:97.31ms +step:708/1670 train_time:68894ms step_avg:97.31ms +step:709/1670 train_time:68991ms step_avg:97.31ms +step:710/1670 train_time:69088ms step_avg:97.31ms +step:711/1670 train_time:69186ms step_avg:97.31ms +step:712/1670 train_time:69284ms step_avg:97.31ms +step:713/1670 train_time:69380ms step_avg:97.31ms +step:714/1670 train_time:69477ms step_avg:97.31ms +step:715/1670 train_time:69573ms step_avg:97.30ms +step:716/1670 train_time:69669ms step_avg:97.30ms +step:717/1670 train_time:69768ms step_avg:97.31ms +step:718/1670 train_time:69865ms step_avg:97.31ms +step:719/1670 train_time:69961ms step_avg:97.30ms +step:720/1670 train_time:70058ms step_avg:97.30ms +step:721/1670 train_time:70155ms step_avg:97.30ms +step:722/1670 train_time:70253ms step_avg:97.30ms +step:723/1670 train_time:70350ms step_avg:97.30ms +step:724/1670 train_time:70447ms step_avg:97.30ms +step:725/1670 train_time:70545ms step_avg:97.30ms +step:726/1670 train_time:70642ms step_avg:97.30ms +step:727/1670 train_time:70739ms step_avg:97.30ms +step:728/1670 train_time:70836ms step_avg:97.30ms +step:729/1670 train_time:70932ms step_avg:97.30ms +step:730/1670 train_time:71029ms step_avg:97.30ms +step:731/1670 train_time:71127ms step_avg:97.30ms +step:732/1670 train_time:71225ms step_avg:97.30ms +step:733/1670 train_time:71323ms step_avg:97.30ms +step:734/1670 train_time:71420ms step_avg:97.30ms +step:735/1670 train_time:71517ms step_avg:97.30ms +step:736/1670 train_time:71614ms step_avg:97.30ms +step:737/1670 train_time:71710ms step_avg:97.30ms +step:738/1670 train_time:71808ms step_avg:97.30ms +step:739/1670 train_time:71906ms step_avg:97.30ms +step:740/1670 train_time:72003ms step_avg:97.30ms +step:741/1670 train_time:72100ms step_avg:97.30ms +step:742/1670 train_time:72196ms step_avg:97.30ms +step:743/1670 train_time:72293ms step_avg:97.30ms +step:744/1670 train_time:72390ms step_avg:97.30ms +step:745/1670 train_time:72488ms step_avg:97.30ms +step:746/1670 train_time:72584ms step_avg:97.30ms +step:747/1670 train_time:72682ms step_avg:97.30ms +step:748/1670 train_time:72778ms step_avg:97.30ms +step:749/1670 train_time:72876ms step_avg:97.30ms +step:750/1670 train_time:72973ms step_avg:97.30ms +step:750/1670 val_loss:3.5576 train_time:73069ms step_avg:97.43ms +step:751/1670 train_time:73089ms step_avg:97.32ms +step:752/1670 train_time:73172ms step_avg:97.30ms +step:753/1670 train_time:73269ms step_avg:97.30ms +step:754/1670 train_time:73364ms step_avg:97.30ms +step:755/1670 train_time:73460ms step_avg:97.30ms +step:756/1670 train_time:73556ms step_avg:97.30ms +step:757/1670 train_time:73653ms step_avg:97.30ms +step:758/1670 train_time:73750ms step_avg:97.29ms +step:759/1670 train_time:73847ms step_avg:97.30ms +step:760/1670 train_time:73943ms step_avg:97.29ms +step:761/1670 train_time:74043ms step_avg:97.30ms +step:762/1670 train_time:74144ms step_avg:97.30ms +step:763/1670 train_time:74243ms step_avg:97.30ms +step:764/1670 train_time:74341ms step_avg:97.30ms +step:765/1670 train_time:74438ms step_avg:97.30ms +step:766/1670 train_time:74534ms step_avg:97.30ms +step:767/1670 train_time:74630ms step_avg:97.30ms +step:768/1670 train_time:74726ms step_avg:97.30ms +step:769/1670 train_time:74824ms step_avg:97.30ms +step:770/1670 
train_time:74921ms step_avg:97.30ms +step:771/1670 train_time:75020ms step_avg:97.30ms +step:772/1670 train_time:75121ms step_avg:97.31ms +step:773/1670 train_time:75221ms step_avg:97.31ms +step:774/1670 train_time:75319ms step_avg:97.31ms +step:775/1670 train_time:75416ms step_avg:97.31ms +step:776/1670 train_time:75512ms step_avg:97.31ms +step:777/1670 train_time:75608ms step_avg:97.31ms +step:778/1670 train_time:75704ms step_avg:97.31ms +step:779/1670 train_time:75801ms step_avg:97.31ms +step:780/1670 train_time:75898ms step_avg:97.31ms +step:781/1670 train_time:75997ms step_avg:97.31ms +step:782/1670 train_time:76095ms step_avg:97.31ms +step:783/1670 train_time:76192ms step_avg:97.31ms +step:784/1670 train_time:76289ms step_avg:97.31ms +step:785/1670 train_time:76385ms step_avg:97.31ms +step:786/1670 train_time:76483ms step_avg:97.31ms +step:787/1670 train_time:76580ms step_avg:97.31ms +step:788/1670 train_time:76678ms step_avg:97.31ms +step:789/1670 train_time:76774ms step_avg:97.31ms +step:790/1670 train_time:76871ms step_avg:97.30ms +step:791/1670 train_time:76967ms step_avg:97.30ms +step:792/1670 train_time:77065ms step_avg:97.30ms +step:793/1670 train_time:77163ms step_avg:97.30ms +step:794/1670 train_time:77262ms step_avg:97.31ms +step:795/1670 train_time:77359ms step_avg:97.31ms +step:796/1670 train_time:77456ms step_avg:97.31ms +step:797/1670 train_time:77553ms step_avg:97.31ms +step:798/1670 train_time:77649ms step_avg:97.30ms +step:799/1670 train_time:77745ms step_avg:97.30ms +step:800/1670 train_time:77843ms step_avg:97.30ms +step:801/1670 train_time:77941ms step_avg:97.30ms +step:802/1670 train_time:78038ms step_avg:97.30ms +step:803/1670 train_time:78136ms step_avg:97.31ms +step:804/1670 train_time:78233ms step_avg:97.30ms +step:805/1670 train_time:78331ms step_avg:97.31ms +step:806/1670 train_time:78427ms step_avg:97.30ms +step:807/1670 train_time:78524ms step_avg:97.30ms +step:808/1670 train_time:78621ms step_avg:97.30ms +step:809/1670 train_time:78718ms step_avg:97.30ms +step:810/1670 train_time:78815ms step_avg:97.30ms +step:811/1670 train_time:78912ms step_avg:97.30ms +step:812/1670 train_time:79008ms step_avg:97.30ms +step:813/1670 train_time:79106ms step_avg:97.30ms +step:814/1670 train_time:79204ms step_avg:97.30ms +step:815/1670 train_time:79303ms step_avg:97.30ms +step:816/1670 train_time:79402ms step_avg:97.31ms +step:817/1670 train_time:79499ms step_avg:97.31ms +step:818/1670 train_time:79595ms step_avg:97.30ms +step:819/1670 train_time:79692ms step_avg:97.30ms +step:820/1670 train_time:79788ms step_avg:97.30ms +step:821/1670 train_time:79884ms step_avg:97.30ms +step:822/1670 train_time:79982ms step_avg:97.30ms +step:823/1670 train_time:80080ms step_avg:97.30ms +step:824/1670 train_time:80178ms step_avg:97.30ms +step:825/1670 train_time:80275ms step_avg:97.30ms +step:826/1670 train_time:80372ms step_avg:97.30ms +step:827/1670 train_time:80469ms step_avg:97.30ms +step:828/1670 train_time:80565ms step_avg:97.30ms +step:829/1670 train_time:80662ms step_avg:97.30ms +step:830/1670 train_time:80760ms step_avg:97.30ms +step:831/1670 train_time:80856ms step_avg:97.30ms +step:832/1670 train_time:80953ms step_avg:97.30ms +step:833/1670 train_time:81050ms step_avg:97.30ms +step:834/1670 train_time:81146ms step_avg:97.30ms +step:835/1670 train_time:81244ms step_avg:97.30ms +step:836/1670 train_time:81344ms step_avg:97.30ms +step:837/1670 train_time:81442ms step_avg:97.30ms +step:838/1670 train_time:81539ms step_avg:97.30ms +step:839/1670 train_time:81637ms step_avg:97.30ms 
+step:840/1670 train_time:81734ms step_avg:97.30ms +step:841/1670 train_time:81831ms step_avg:97.30ms +step:842/1670 train_time:81927ms step_avg:97.30ms +step:843/1670 train_time:82024ms step_avg:97.30ms +step:844/1670 train_time:82120ms step_avg:97.30ms +step:845/1670 train_time:82217ms step_avg:97.30ms +step:846/1670 train_time:82314ms step_avg:97.30ms +step:847/1670 train_time:82411ms step_avg:97.30ms +step:848/1670 train_time:82508ms step_avg:97.30ms +step:849/1670 train_time:82606ms step_avg:97.30ms +step:850/1670 train_time:82702ms step_avg:97.30ms +step:851/1670 train_time:83014ms step_avg:97.55ms +step:852/1670 train_time:83090ms step_avg:97.52ms +step:853/1670 train_time:83185ms step_avg:97.52ms +step:854/1670 train_time:83281ms step_avg:97.52ms +step:855/1670 train_time:83377ms step_avg:97.52ms +step:856/1670 train_time:83472ms step_avg:97.51ms +step:857/1670 train_time:83568ms step_avg:97.51ms +step:858/1670 train_time:83664ms step_avg:97.51ms +step:859/1670 train_time:83761ms step_avg:97.51ms +step:860/1670 train_time:83857ms step_avg:97.51ms +step:861/1670 train_time:83955ms step_avg:97.51ms +step:862/1670 train_time:84058ms step_avg:97.52ms +step:863/1670 train_time:84157ms step_avg:97.52ms +step:864/1670 train_time:84254ms step_avg:97.52ms +step:865/1670 train_time:84350ms step_avg:97.51ms +step:866/1670 train_time:84445ms step_avg:97.51ms +step:867/1670 train_time:84542ms step_avg:97.51ms +step:868/1670 train_time:84638ms step_avg:97.51ms +step:869/1670 train_time:84734ms step_avg:97.51ms +step:870/1670 train_time:84829ms step_avg:97.51ms +step:871/1670 train_time:84926ms step_avg:97.50ms +step:872/1670 train_time:85025ms step_avg:97.51ms +step:873/1670 train_time:85124ms step_avg:97.51ms +step:874/1670 train_time:85222ms step_avg:97.51ms +step:875/1670 train_time:85320ms step_avg:97.51ms +step:875/1670 val_loss:3.5173 train_time:85416ms step_avg:97.62ms +step:876/1670 train_time:85436ms step_avg:97.53ms +step:877/1670 train_time:85522ms step_avg:97.52ms +step:878/1670 train_time:85624ms step_avg:97.52ms +step:879/1670 train_time:85722ms step_avg:97.52ms +step:880/1670 train_time:85818ms step_avg:97.52ms +step:881/1670 train_time:85915ms step_avg:97.52ms +step:882/1670 train_time:86010ms step_avg:97.52ms +step:883/1670 train_time:86106ms step_avg:97.52ms +step:884/1670 train_time:86202ms step_avg:97.51ms +step:885/1670 train_time:86299ms step_avg:97.51ms +step:886/1670 train_time:86397ms step_avg:97.51ms +step:887/1670 train_time:86496ms step_avg:97.52ms +step:888/1670 train_time:86596ms step_avg:97.52ms +step:889/1670 train_time:86694ms step_avg:97.52ms +step:890/1670 train_time:86790ms step_avg:97.52ms +step:891/1670 train_time:86886ms step_avg:97.52ms +step:892/1670 train_time:86982ms step_avg:97.51ms +step:893/1670 train_time:87079ms step_avg:97.51ms +step:894/1670 train_time:87175ms step_avg:97.51ms +step:895/1670 train_time:87271ms step_avg:97.51ms +step:896/1670 train_time:87368ms step_avg:97.51ms +step:897/1670 train_time:87465ms step_avg:97.51ms +step:898/1670 train_time:87564ms step_avg:97.51ms +step:899/1670 train_time:87662ms step_avg:97.51ms +step:900/1670 train_time:87761ms step_avg:97.51ms +step:901/1670 train_time:87860ms step_avg:97.51ms +step:902/1670 train_time:87957ms step_avg:97.51ms +step:903/1670 train_time:88053ms step_avg:97.51ms +step:904/1670 train_time:88149ms step_avg:97.51ms +step:905/1670 train_time:88245ms step_avg:97.51ms +step:906/1670 train_time:88342ms step_avg:97.51ms +step:907/1670 train_time:88439ms step_avg:97.51ms +step:908/1670 
train_time:88537ms step_avg:97.51ms +step:909/1670 train_time:88634ms step_avg:97.51ms +step:910/1670 train_time:88733ms step_avg:97.51ms +step:911/1670 train_time:88831ms step_avg:97.51ms +step:912/1670 train_time:88929ms step_avg:97.51ms +step:913/1670 train_time:89025ms step_avg:97.51ms +step:914/1670 train_time:89122ms step_avg:97.51ms +step:915/1670 train_time:89219ms step_avg:97.51ms +step:916/1670 train_time:89315ms step_avg:97.51ms +step:917/1670 train_time:89412ms step_avg:97.50ms +step:918/1670 train_time:89509ms step_avg:97.50ms +step:919/1670 train_time:89605ms step_avg:97.50ms +step:920/1670 train_time:89703ms step_avg:97.50ms +step:921/1670 train_time:89801ms step_avg:97.50ms +step:922/1670 train_time:89900ms step_avg:97.51ms +step:923/1670 train_time:89998ms step_avg:97.51ms +step:924/1670 train_time:90096ms step_avg:97.51ms +step:925/1670 train_time:90192ms step_avg:97.51ms +step:926/1670 train_time:90288ms step_avg:97.50ms +step:927/1670 train_time:90384ms step_avg:97.50ms +step:928/1670 train_time:90481ms step_avg:97.50ms +step:929/1670 train_time:90579ms step_avg:97.50ms +step:930/1670 train_time:90677ms step_avg:97.50ms +step:931/1670 train_time:90774ms step_avg:97.50ms +step:932/1670 train_time:90872ms step_avg:97.50ms +step:933/1670 train_time:90969ms step_avg:97.50ms +step:934/1670 train_time:91066ms step_avg:97.50ms +step:935/1670 train_time:91163ms step_avg:97.50ms +step:936/1670 train_time:91261ms step_avg:97.50ms +step:937/1670 train_time:91358ms step_avg:97.50ms +step:938/1670 train_time:91455ms step_avg:97.50ms +step:939/1670 train_time:91551ms step_avg:97.50ms +step:940/1670 train_time:91647ms step_avg:97.50ms +step:941/1670 train_time:91744ms step_avg:97.50ms +step:942/1670 train_time:91843ms step_avg:97.50ms +step:943/1670 train_time:91940ms step_avg:97.50ms +step:944/1670 train_time:92038ms step_avg:97.50ms +step:945/1670 train_time:92134ms step_avg:97.50ms +step:946/1670 train_time:92232ms step_avg:97.50ms +step:947/1670 train_time:92328ms step_avg:97.50ms +step:948/1670 train_time:92425ms step_avg:97.50ms +step:949/1670 train_time:92523ms step_avg:97.50ms +step:950/1670 train_time:92621ms step_avg:97.50ms +step:951/1670 train_time:92719ms step_avg:97.50ms +step:952/1670 train_time:92816ms step_avg:97.50ms +step:953/1670 train_time:92913ms step_avg:97.50ms +step:954/1670 train_time:93012ms step_avg:97.50ms +step:955/1670 train_time:93108ms step_avg:97.49ms +step:956/1670 train_time:93204ms step_avg:97.49ms +step:957/1670 train_time:93301ms step_avg:97.49ms +step:958/1670 train_time:93398ms step_avg:97.49ms +step:959/1670 train_time:93495ms step_avg:97.49ms +step:960/1670 train_time:93592ms step_avg:97.49ms +step:961/1670 train_time:93689ms step_avg:97.49ms +step:962/1670 train_time:93787ms step_avg:97.49ms +step:963/1670 train_time:93885ms step_avg:97.49ms +step:964/1670 train_time:93983ms step_avg:97.49ms +step:965/1670 train_time:94080ms step_avg:97.49ms +step:966/1670 train_time:94179ms step_avg:97.49ms +step:967/1670 train_time:94276ms step_avg:97.49ms +step:968/1670 train_time:94373ms step_avg:97.49ms +step:969/1670 train_time:94469ms step_avg:97.49ms +step:970/1670 train_time:94565ms step_avg:97.49ms +step:971/1670 train_time:94663ms step_avg:97.49ms +step:972/1670 train_time:94760ms step_avg:97.49ms +step:973/1670 train_time:94858ms step_avg:97.49ms +step:974/1670 train_time:94955ms step_avg:97.49ms +step:975/1670 train_time:95053ms step_avg:97.49ms +step:976/1670 train_time:95150ms step_avg:97.49ms +step:977/1670 train_time:95246ms step_avg:97.49ms 
+step:978/1670 train_time:95342ms step_avg:97.49ms +step:979/1670 train_time:95440ms step_avg:97.49ms +step:980/1670 train_time:95537ms step_avg:97.49ms +step:981/1670 train_time:95634ms step_avg:97.49ms +step:982/1670 train_time:95732ms step_avg:97.49ms +step:983/1670 train_time:95829ms step_avg:97.49ms +step:984/1670 train_time:95926ms step_avg:97.49ms +step:985/1670 train_time:96022ms step_avg:97.48ms +step:986/1670 train_time:96120ms step_avg:97.48ms +step:987/1670 train_time:96218ms step_avg:97.49ms +step:988/1670 train_time:96315ms step_avg:97.48ms +step:989/1670 train_time:96412ms step_avg:97.48ms +step:990/1670 train_time:96508ms step_avg:97.48ms +step:991/1670 train_time:96604ms step_avg:97.48ms +step:992/1670 train_time:96702ms step_avg:97.48ms +step:993/1670 train_time:96800ms step_avg:97.48ms +step:994/1670 train_time:96898ms step_avg:97.48ms +step:995/1670 train_time:96996ms step_avg:97.48ms +step:996/1670 train_time:97094ms step_avg:97.48ms +step:997/1670 train_time:97190ms step_avg:97.48ms +step:998/1670 train_time:97287ms step_avg:97.48ms +step:999/1670 train_time:97384ms step_avg:97.48ms +step:1000/1670 train_time:97481ms step_avg:97.48ms +step:1000/1670 val_loss:3.4755 train_time:97578ms step_avg:97.58ms +step:1001/1670 train_time:97599ms step_avg:97.50ms +step:1002/1670 train_time:97679ms step_avg:97.48ms +step:1003/1670 train_time:97777ms step_avg:97.48ms +step:1004/1670 train_time:97874ms step_avg:97.48ms +step:1005/1670 train_time:97969ms step_avg:97.48ms +step:1006/1670 train_time:98066ms step_avg:97.48ms +step:1007/1670 train_time:98162ms step_avg:97.48ms +step:1008/1670 train_time:98257ms step_avg:97.48ms +step:1009/1670 train_time:98353ms step_avg:97.48ms +step:1010/1670 train_time:98449ms step_avg:97.47ms +step:1011/1670 train_time:98549ms step_avg:97.48ms +step:1012/1670 train_time:98649ms step_avg:97.48ms +step:1013/1670 train_time:98749ms step_avg:97.48ms +step:1014/1670 train_time:98847ms step_avg:97.48ms +step:1015/1670 train_time:98944ms step_avg:97.48ms +step:1016/1670 train_time:99040ms step_avg:97.48ms +step:1017/1670 train_time:99136ms step_avg:97.48ms +step:1018/1670 train_time:99232ms step_avg:97.48ms +step:1019/1670 train_time:99329ms step_avg:97.48ms +step:1020/1670 train_time:99426ms step_avg:97.48ms +step:1021/1670 train_time:99522ms step_avg:97.48ms +step:1022/1670 train_time:99620ms step_avg:97.48ms +step:1023/1670 train_time:99719ms step_avg:97.48ms +step:1024/1670 train_time:99817ms step_avg:97.48ms +step:1025/1670 train_time:99914ms step_avg:97.48ms +step:1026/1670 train_time:100011ms step_avg:97.48ms +step:1027/1670 train_time:100107ms step_avg:97.48ms +step:1028/1670 train_time:100204ms step_avg:97.47ms +step:1029/1670 train_time:100300ms step_avg:97.47ms +step:1030/1670 train_time:100396ms step_avg:97.47ms +step:1031/1670 train_time:100493ms step_avg:97.47ms +step:1032/1670 train_time:100591ms step_avg:97.47ms +step:1033/1670 train_time:100689ms step_avg:97.47ms +step:1034/1670 train_time:100789ms step_avg:97.47ms +step:1035/1670 train_time:100887ms step_avg:97.48ms +step:1036/1670 train_time:100984ms step_avg:97.47ms +step:1037/1670 train_time:101081ms step_avg:97.47ms +step:1038/1670 train_time:101177ms step_avg:97.47ms +step:1039/1670 train_time:101273ms step_avg:97.47ms +step:1040/1670 train_time:101370ms step_avg:97.47ms +step:1041/1670 train_time:101467ms step_avg:97.47ms +step:1042/1670 train_time:101564ms step_avg:97.47ms +step:1043/1670 train_time:101661ms step_avg:97.47ms +step:1044/1670 train_time:101759ms step_avg:97.47ms 
+step:1045/1670 train_time:101856ms step_avg:97.47ms +step:1046/1670 train_time:101953ms step_avg:97.47ms +step:1047/1670 train_time:102051ms step_avg:97.47ms +step:1048/1670 train_time:102147ms step_avg:97.47ms +step:1049/1670 train_time:102244ms step_avg:97.47ms +step:1050/1670 train_time:102341ms step_avg:97.47ms +step:1051/1670 train_time:102439ms step_avg:97.47ms +step:1052/1670 train_time:102535ms step_avg:97.47ms +step:1053/1670 train_time:102631ms step_avg:97.47ms +step:1054/1670 train_time:102729ms step_avg:97.47ms +step:1055/1670 train_time:102827ms step_avg:97.47ms +step:1056/1670 train_time:102925ms step_avg:97.47ms +step:1057/1670 train_time:103022ms step_avg:97.47ms +step:1058/1670 train_time:103120ms step_avg:97.47ms +step:1059/1670 train_time:103217ms step_avg:97.47ms +step:1060/1670 train_time:103313ms step_avg:97.46ms +step:1061/1670 train_time:103410ms step_avg:97.46ms +step:1062/1670 train_time:103685ms step_avg:97.63ms +step:1063/1670 train_time:103760ms step_avg:97.61ms +step:1064/1670 train_time:103856ms step_avg:97.61ms +step:1065/1670 train_time:103952ms step_avg:97.61ms +step:1066/1670 train_time:104048ms step_avg:97.61ms +step:1067/1670 train_time:104143ms step_avg:97.60ms +step:1068/1670 train_time:104240ms step_avg:97.60ms +step:1069/1670 train_time:104335ms step_avg:97.60ms +step:1070/1670 train_time:104431ms step_avg:97.60ms +step:1071/1670 train_time:104527ms step_avg:97.60ms +step:1072/1670 train_time:104632ms step_avg:97.60ms +step:1073/1670 train_time:104731ms step_avg:97.61ms +step:1074/1670 train_time:104829ms step_avg:97.61ms +step:1075/1670 train_time:104928ms step_avg:97.61ms +step:1076/1670 train_time:105024ms step_avg:97.61ms +step:1077/1670 train_time:105120ms step_avg:97.60ms +step:1078/1670 train_time:105216ms step_avg:97.60ms +step:1079/1670 train_time:105312ms step_avg:97.60ms +step:1080/1670 train_time:105408ms step_avg:97.60ms +step:1081/1670 train_time:105505ms step_avg:97.60ms +step:1082/1670 train_time:105604ms step_avg:97.60ms +step:1083/1670 train_time:105704ms step_avg:97.60ms +step:1084/1670 train_time:105803ms step_avg:97.60ms +step:1085/1670 train_time:105901ms step_avg:97.60ms +step:1086/1670 train_time:105997ms step_avg:97.60ms +step:1087/1670 train_time:106094ms step_avg:97.60ms +step:1088/1670 train_time:106190ms step_avg:97.60ms +step:1089/1670 train_time:106287ms step_avg:97.60ms +step:1090/1670 train_time:106383ms step_avg:97.60ms +step:1091/1670 train_time:106479ms step_avg:97.60ms +step:1092/1670 train_time:106576ms step_avg:97.60ms +step:1093/1670 train_time:106673ms step_avg:97.60ms +step:1094/1670 train_time:106771ms step_avg:97.60ms +step:1095/1670 train_time:106870ms step_avg:97.60ms +step:1096/1670 train_time:106967ms step_avg:97.60ms +step:1097/1670 train_time:107065ms step_avg:97.60ms +step:1098/1670 train_time:107161ms step_avg:97.60ms +step:1099/1670 train_time:107257ms step_avg:97.59ms +step:1100/1670 train_time:107352ms step_avg:97.59ms +step:1101/1670 train_time:107449ms step_avg:97.59ms +step:1102/1670 train_time:107547ms step_avg:97.59ms +step:1103/1670 train_time:107645ms step_avg:97.59ms +step:1104/1670 train_time:107742ms step_avg:97.59ms +step:1105/1670 train_time:107841ms step_avg:97.59ms +step:1106/1670 train_time:107938ms step_avg:97.59ms +step:1107/1670 train_time:108035ms step_avg:97.59ms +step:1108/1670 train_time:108132ms step_avg:97.59ms +step:1109/1670 train_time:108229ms step_avg:97.59ms +step:1110/1670 train_time:108325ms step_avg:97.59ms +step:1111/1670 train_time:108421ms step_avg:97.59ms 
+step:1112/1670 train_time:108518ms step_avg:97.59ms +step:1113/1670 train_time:108614ms step_avg:97.59ms +step:1114/1670 train_time:108711ms step_avg:97.59ms +step:1115/1670 train_time:108808ms step_avg:97.59ms +step:1116/1670 train_time:108906ms step_avg:97.59ms +step:1117/1670 train_time:109005ms step_avg:97.59ms +step:1118/1670 train_time:109106ms step_avg:97.59ms +step:1119/1670 train_time:109204ms step_avg:97.59ms +step:1120/1670 train_time:109301ms step_avg:97.59ms +step:1121/1670 train_time:109399ms step_avg:97.59ms +step:1122/1670 train_time:109496ms step_avg:97.59ms +step:1123/1670 train_time:109593ms step_avg:97.59ms +step:1124/1670 train_time:109691ms step_avg:97.59ms +step:1125/1670 train_time:109789ms step_avg:97.59ms +step:1125/1670 val_loss:3.4209 train_time:109887ms step_avg:97.68ms +step:1126/1670 train_time:109909ms step_avg:97.61ms +step:1127/1670 train_time:109995ms step_avg:97.60ms +step:1128/1670 train_time:110094ms step_avg:97.60ms +step:1129/1670 train_time:110192ms step_avg:97.60ms +step:1130/1670 train_time:110288ms step_avg:97.60ms +step:1131/1670 train_time:110385ms step_avg:97.60ms +step:1132/1670 train_time:110481ms step_avg:97.60ms +step:1133/1670 train_time:110578ms step_avg:97.60ms +step:1134/1670 train_time:110674ms step_avg:97.60ms +step:1135/1670 train_time:110772ms step_avg:97.60ms +step:1136/1670 train_time:110873ms step_avg:97.60ms +step:1137/1670 train_time:110974ms step_avg:97.60ms +step:1138/1670 train_time:111073ms step_avg:97.60ms +step:1139/1670 train_time:111171ms step_avg:97.60ms +step:1140/1670 train_time:111269ms step_avg:97.60ms +step:1141/1670 train_time:111366ms step_avg:97.60ms +step:1142/1670 train_time:111462ms step_avg:97.60ms +step:1143/1670 train_time:111558ms step_avg:97.60ms +step:1144/1670 train_time:111655ms step_avg:97.60ms +step:1145/1670 train_time:111753ms step_avg:97.60ms +step:1146/1670 train_time:111851ms step_avg:97.60ms +step:1147/1670 train_time:111952ms step_avg:97.60ms +step:1148/1670 train_time:112053ms step_avg:97.61ms +step:1149/1670 train_time:112151ms step_avg:97.61ms +step:1150/1670 train_time:112249ms step_avg:97.61ms +step:1151/1670 train_time:112345ms step_avg:97.61ms +step:1152/1670 train_time:112442ms step_avg:97.61ms +step:1153/1670 train_time:112538ms step_avg:97.60ms +step:1154/1670 train_time:112635ms step_avg:97.60ms +step:1155/1670 train_time:112733ms step_avg:97.60ms +step:1156/1670 train_time:112832ms step_avg:97.61ms +step:1157/1670 train_time:112931ms step_avg:97.61ms +step:1158/1670 train_time:113030ms step_avg:97.61ms +step:1159/1670 train_time:113128ms step_avg:97.61ms +step:1160/1670 train_time:113226ms step_avg:97.61ms +step:1161/1670 train_time:113324ms step_avg:97.61ms +step:1162/1670 train_time:113421ms step_avg:97.61ms +step:1163/1670 train_time:113518ms step_avg:97.61ms +step:1164/1670 train_time:113614ms step_avg:97.61ms +step:1165/1670 train_time:113713ms step_avg:97.61ms +step:1166/1670 train_time:113811ms step_avg:97.61ms +step:1167/1670 train_time:113909ms step_avg:97.61ms +step:1168/1670 train_time:114008ms step_avg:97.61ms +step:1169/1670 train_time:114106ms step_avg:97.61ms +step:1170/1670 train_time:114204ms step_avg:97.61ms +step:1171/1670 train_time:114301ms step_avg:97.61ms +step:1172/1670 train_time:114398ms step_avg:97.61ms +step:1173/1670 train_time:114495ms step_avg:97.61ms +step:1174/1670 train_time:114592ms step_avg:97.61ms +step:1175/1670 train_time:114689ms step_avg:97.61ms +step:1176/1670 train_time:114787ms step_avg:97.61ms +step:1177/1670 train_time:114885ms 
step_avg:97.61ms +step:1178/1670 train_time:114983ms step_avg:97.61ms +step:1179/1670 train_time:115081ms step_avg:97.61ms +step:1180/1670 train_time:115178ms step_avg:97.61ms +step:1181/1670 train_time:115277ms step_avg:97.61ms +step:1182/1670 train_time:115376ms step_avg:97.61ms +step:1183/1670 train_time:115473ms step_avg:97.61ms +step:1184/1670 train_time:115571ms step_avg:97.61ms +step:1185/1670 train_time:115669ms step_avg:97.61ms +step:1186/1670 train_time:115766ms step_avg:97.61ms +step:1187/1670 train_time:115864ms step_avg:97.61ms +step:1188/1670 train_time:115961ms step_avg:97.61ms +step:1189/1670 train_time:116059ms step_avg:97.61ms +step:1190/1670 train_time:116157ms step_avg:97.61ms +step:1191/1670 train_time:116254ms step_avg:97.61ms +step:1192/1670 train_time:116353ms step_avg:97.61ms +step:1193/1670 train_time:116451ms step_avg:97.61ms +step:1194/1670 train_time:116548ms step_avg:97.61ms +step:1195/1670 train_time:116646ms step_avg:97.61ms +step:1196/1670 train_time:116743ms step_avg:97.61ms +step:1197/1670 train_time:116841ms step_avg:97.61ms +step:1198/1670 train_time:116938ms step_avg:97.61ms +step:1199/1670 train_time:117037ms step_avg:97.61ms +step:1200/1670 train_time:117134ms step_avg:97.61ms +step:1201/1670 train_time:117233ms step_avg:97.61ms +step:1202/1670 train_time:117330ms step_avg:97.61ms +step:1203/1670 train_time:117428ms step_avg:97.61ms +step:1204/1670 train_time:117525ms step_avg:97.61ms +step:1205/1670 train_time:117623ms step_avg:97.61ms +step:1206/1670 train_time:117721ms step_avg:97.61ms +step:1207/1670 train_time:117818ms step_avg:97.61ms +step:1208/1670 train_time:117916ms step_avg:97.61ms +step:1209/1670 train_time:118014ms step_avg:97.61ms +step:1210/1670 train_time:118112ms step_avg:97.61ms +step:1211/1670 train_time:118209ms step_avg:97.61ms +step:1212/1670 train_time:118306ms step_avg:97.61ms +step:1213/1670 train_time:118403ms step_avg:97.61ms +step:1214/1670 train_time:118500ms step_avg:97.61ms +step:1215/1670 train_time:118599ms step_avg:97.61ms +step:1216/1670 train_time:118697ms step_avg:97.61ms +step:1217/1670 train_time:118795ms step_avg:97.61ms +step:1218/1670 train_time:118893ms step_avg:97.61ms +step:1219/1670 train_time:118991ms step_avg:97.61ms +step:1220/1670 train_time:119088ms step_avg:97.61ms +step:1221/1670 train_time:119186ms step_avg:97.61ms +step:1222/1670 train_time:119283ms step_avg:97.61ms +step:1223/1670 train_time:119380ms step_avg:97.61ms +step:1224/1670 train_time:119477ms step_avg:97.61ms +step:1225/1670 train_time:119575ms step_avg:97.61ms +step:1226/1670 train_time:119673ms step_avg:97.61ms +step:1227/1670 train_time:119771ms step_avg:97.61ms +step:1228/1670 train_time:119870ms step_avg:97.61ms +step:1229/1670 train_time:119967ms step_avg:97.61ms +step:1230/1670 train_time:120065ms step_avg:97.61ms +step:1231/1670 train_time:120162ms step_avg:97.61ms +step:1232/1670 train_time:120259ms step_avg:97.61ms +step:1233/1670 train_time:120356ms step_avg:97.61ms +step:1234/1670 train_time:120454ms step_avg:97.61ms +step:1235/1670 train_time:120554ms step_avg:97.61ms +step:1236/1670 train_time:120653ms step_avg:97.62ms +step:1237/1670 train_time:120751ms step_avg:97.62ms +step:1238/1670 train_time:120849ms step_avg:97.62ms +step:1239/1670 train_time:120947ms step_avg:97.62ms +step:1240/1670 train_time:121044ms step_avg:97.62ms +step:1241/1670 train_time:121142ms step_avg:97.62ms +step:1242/1670 train_time:121239ms step_avg:97.62ms +step:1243/1670 train_time:121336ms step_avg:97.62ms +step:1244/1670 train_time:121434ms 
step_avg:97.62ms +step:1245/1670 train_time:121533ms step_avg:97.62ms +step:1246/1670 train_time:121632ms step_avg:97.62ms +step:1247/1670 train_time:121730ms step_avg:97.62ms +step:1248/1670 train_time:121828ms step_avg:97.62ms +step:1249/1670 train_time:121926ms step_avg:97.62ms +step:1250/1670 train_time:122024ms step_avg:97.62ms +step:1250/1670 val_loss:3.3788 train_time:122121ms step_avg:97.70ms +step:1251/1670 train_time:122142ms step_avg:97.64ms +step:1252/1670 train_time:122227ms step_avg:97.63ms +step:1253/1670 train_time:122326ms step_avg:97.63ms +step:1254/1670 train_time:122424ms step_avg:97.63ms +step:1255/1670 train_time:122520ms step_avg:97.63ms +step:1256/1670 train_time:122617ms step_avg:97.63ms +step:1257/1670 train_time:122714ms step_avg:97.62ms +step:1258/1670 train_time:122810ms step_avg:97.62ms +step:1259/1670 train_time:122907ms step_avg:97.62ms +step:1260/1670 train_time:123003ms step_avg:97.62ms +step:1261/1670 train_time:123102ms step_avg:97.62ms +step:1262/1670 train_time:123204ms step_avg:97.63ms +step:1263/1670 train_time:123302ms step_avg:97.63ms +step:1264/1670 train_time:123401ms step_avg:97.63ms +step:1265/1670 train_time:123499ms step_avg:97.63ms +step:1266/1670 train_time:123596ms step_avg:97.63ms +step:1267/1670 train_time:123693ms step_avg:97.63ms +step:1268/1670 train_time:123790ms step_avg:97.63ms +step:1269/1670 train_time:123887ms step_avg:97.63ms +step:1270/1670 train_time:123984ms step_avg:97.63ms +step:1271/1670 train_time:124082ms step_avg:97.63ms +step:1272/1670 train_time:124181ms step_avg:97.63ms +step:1273/1670 train_time:124281ms step_avg:97.63ms +step:1274/1670 train_time:124553ms step_avg:97.77ms +step:1275/1670 train_time:124721ms step_avg:97.82ms +step:1276/1670 train_time:124816ms step_avg:97.82ms +step:1277/1670 train_time:124912ms step_avg:97.82ms +step:1278/1670 train_time:125009ms step_avg:97.82ms +step:1279/1670 train_time:125104ms step_avg:97.81ms +step:1280/1670 train_time:125201ms step_avg:97.81ms +step:1281/1670 train_time:125298ms step_avg:97.81ms +step:1282/1670 train_time:125395ms step_avg:97.81ms +step:1283/1670 train_time:125491ms step_avg:97.81ms +step:1284/1670 train_time:125590ms step_avg:97.81ms +step:1285/1670 train_time:125693ms step_avg:97.82ms +step:1286/1670 train_time:125792ms step_avg:97.82ms +step:1287/1670 train_time:125889ms step_avg:97.82ms +step:1288/1670 train_time:125986ms step_avg:97.81ms +step:1289/1670 train_time:126083ms step_avg:97.81ms +step:1290/1670 train_time:126180ms step_avg:97.81ms +step:1291/1670 train_time:126277ms step_avg:97.81ms +step:1292/1670 train_time:126374ms step_avg:97.81ms +step:1293/1670 train_time:126470ms step_avg:97.81ms +step:1294/1670 train_time:126568ms step_avg:97.81ms +step:1295/1670 train_time:126667ms step_avg:97.81ms +step:1296/1670 train_time:126765ms step_avg:97.81ms +step:1297/1670 train_time:126864ms step_avg:97.81ms +step:1298/1670 train_time:126962ms step_avg:97.81ms +step:1299/1670 train_time:127059ms step_avg:97.81ms +step:1300/1670 train_time:127157ms step_avg:97.81ms +step:1301/1670 train_time:127253ms step_avg:97.81ms +step:1302/1670 train_time:127350ms step_avg:97.81ms +step:1303/1670 train_time:127447ms step_avg:97.81ms +step:1304/1670 train_time:127544ms step_avg:97.81ms +step:1305/1670 train_time:127642ms step_avg:97.81ms +step:1306/1670 train_time:127741ms step_avg:97.81ms +step:1307/1670 train_time:127840ms step_avg:97.81ms +step:1308/1670 train_time:127940ms step_avg:97.81ms +step:1309/1670 train_time:128038ms step_avg:97.81ms +step:1310/1670 
train_time:128136ms step_avg:97.81ms +step:1311/1670 train_time:128233ms step_avg:97.81ms +step:1312/1670 train_time:128330ms step_avg:97.81ms +step:1313/1670 train_time:128427ms step_avg:97.81ms +step:1314/1670 train_time:128524ms step_avg:97.81ms +step:1315/1670 train_time:128622ms step_avg:97.81ms +step:1316/1670 train_time:128720ms step_avg:97.81ms +step:1317/1670 train_time:128818ms step_avg:97.81ms +step:1318/1670 train_time:128917ms step_avg:97.81ms +step:1319/1670 train_time:129014ms step_avg:97.81ms +step:1320/1670 train_time:129112ms step_avg:97.81ms +step:1321/1670 train_time:129209ms step_avg:97.81ms +step:1322/1670 train_time:129306ms step_avg:97.81ms +step:1323/1670 train_time:129404ms step_avg:97.81ms +step:1324/1670 train_time:129501ms step_avg:97.81ms +step:1325/1670 train_time:129598ms step_avg:97.81ms +step:1326/1670 train_time:129696ms step_avg:97.81ms +step:1327/1670 train_time:129795ms step_avg:97.81ms +step:1328/1670 train_time:129894ms step_avg:97.81ms +step:1329/1670 train_time:129991ms step_avg:97.81ms +step:1330/1670 train_time:130090ms step_avg:97.81ms +step:1331/1670 train_time:130187ms step_avg:97.81ms +step:1332/1670 train_time:130284ms step_avg:97.81ms +step:1333/1670 train_time:130381ms step_avg:97.81ms +step:1334/1670 train_time:130478ms step_avg:97.81ms +step:1335/1670 train_time:130576ms step_avg:97.81ms +step:1336/1670 train_time:130674ms step_avg:97.81ms +step:1337/1670 train_time:130773ms step_avg:97.81ms +step:1338/1670 train_time:130871ms step_avg:97.81ms +step:1339/1670 train_time:130969ms step_avg:97.81ms +step:1340/1670 train_time:131066ms step_avg:97.81ms +step:1341/1670 train_time:131164ms step_avg:97.81ms +step:1342/1670 train_time:131262ms step_avg:97.81ms +step:1343/1670 train_time:131360ms step_avg:97.81ms +step:1344/1670 train_time:131458ms step_avg:97.81ms +step:1345/1670 train_time:131556ms step_avg:97.81ms +step:1346/1670 train_time:131655ms step_avg:97.81ms +step:1347/1670 train_time:131753ms step_avg:97.81ms +step:1348/1670 train_time:131851ms step_avg:97.81ms +step:1349/1670 train_time:131949ms step_avg:97.81ms +step:1350/1670 train_time:132046ms step_avg:97.81ms +step:1351/1670 train_time:132144ms step_avg:97.81ms +step:1352/1670 train_time:132241ms step_avg:97.81ms +step:1353/1670 train_time:132339ms step_avg:97.81ms +step:1354/1670 train_time:132437ms step_avg:97.81ms +step:1355/1670 train_time:132535ms step_avg:97.81ms +step:1356/1670 train_time:132632ms step_avg:97.81ms +step:1357/1670 train_time:132730ms step_avg:97.81ms +step:1358/1670 train_time:132827ms step_avg:97.81ms +step:1359/1670 train_time:132925ms step_avg:97.81ms +step:1360/1670 train_time:133023ms step_avg:97.81ms +step:1361/1670 train_time:133121ms step_avg:97.81ms +step:1362/1670 train_time:133220ms step_avg:97.81ms +step:1363/1670 train_time:133318ms step_avg:97.81ms +step:1364/1670 train_time:133415ms step_avg:97.81ms +step:1365/1670 train_time:133512ms step_avg:97.81ms +step:1366/1670 train_time:133609ms step_avg:97.81ms +step:1367/1670 train_time:133706ms step_avg:97.81ms +step:1368/1670 train_time:133804ms step_avg:97.81ms +step:1369/1670 train_time:133902ms step_avg:97.81ms +step:1370/1670 train_time:134000ms step_avg:97.81ms +step:1371/1670 train_time:134097ms step_avg:97.81ms +step:1372/1670 train_time:134195ms step_avg:97.81ms +step:1373/1670 train_time:134293ms step_avg:97.81ms +step:1374/1670 train_time:134390ms step_avg:97.81ms +step:1375/1670 train_time:134488ms step_avg:97.81ms +step:1375/1670 val_loss:3.3416 train_time:134585ms step_avg:97.88ms 
+step:1376/1670 train_time:134606ms step_avg:97.82ms +step:1377/1670 train_time:134692ms step_avg:97.82ms +step:1378/1670 train_time:134791ms step_avg:97.82ms +step:1379/1670 train_time:134889ms step_avg:97.82ms +step:1380/1670 train_time:134986ms step_avg:97.82ms +step:1381/1670 train_time:135082ms step_avg:97.81ms +step:1382/1670 train_time:135180ms step_avg:97.81ms +step:1383/1670 train_time:135276ms step_avg:97.81ms +step:1384/1670 train_time:135374ms step_avg:97.81ms +step:1385/1670 train_time:135472ms step_avg:97.81ms +step:1386/1670 train_time:135570ms step_avg:97.81ms +step:1387/1670 train_time:135669ms step_avg:97.81ms +step:1388/1670 train_time:135769ms step_avg:97.82ms +step:1389/1670 train_time:135867ms step_avg:97.82ms +step:1390/1670 train_time:135965ms step_avg:97.82ms +step:1391/1670 train_time:136062ms step_avg:97.82ms +step:1392/1670 train_time:136159ms step_avg:97.82ms +step:1393/1670 train_time:136256ms step_avg:97.82ms +step:1394/1670 train_time:136353ms step_avg:97.81ms +step:1395/1670 train_time:136451ms step_avg:97.81ms +step:1396/1670 train_time:136549ms step_avg:97.81ms +step:1397/1670 train_time:136649ms step_avg:97.82ms +step:1398/1670 train_time:136749ms step_avg:97.82ms +step:1399/1670 train_time:136848ms step_avg:97.82ms +step:1400/1670 train_time:136946ms step_avg:97.82ms +step:1401/1670 train_time:137043ms step_avg:97.82ms +step:1402/1670 train_time:137140ms step_avg:97.82ms +step:1403/1670 train_time:137236ms step_avg:97.82ms +step:1404/1670 train_time:137333ms step_avg:97.82ms +step:1405/1670 train_time:137431ms step_avg:97.82ms +step:1406/1670 train_time:137528ms step_avg:97.82ms +step:1407/1670 train_time:137626ms step_avg:97.82ms +step:1408/1670 train_time:137726ms step_avg:97.82ms +step:1409/1670 train_time:137826ms step_avg:97.82ms +step:1410/1670 train_time:137924ms step_avg:97.82ms +step:1411/1670 train_time:138022ms step_avg:97.82ms +step:1412/1670 train_time:138120ms step_avg:97.82ms +step:1413/1670 train_time:138218ms step_avg:97.82ms +step:1414/1670 train_time:138314ms step_avg:97.82ms +step:1415/1670 train_time:138412ms step_avg:97.82ms +step:1416/1670 train_time:138509ms step_avg:97.82ms +step:1417/1670 train_time:138607ms step_avg:97.82ms +step:1418/1670 train_time:138705ms step_avg:97.82ms +step:1419/1670 train_time:138804ms step_avg:97.82ms +step:1420/1670 train_time:138903ms step_avg:97.82ms +step:1421/1670 train_time:139001ms step_avg:97.82ms +step:1422/1670 train_time:139100ms step_avg:97.82ms +step:1423/1670 train_time:139199ms step_avg:97.82ms +step:1424/1670 train_time:139296ms step_avg:97.82ms +step:1425/1670 train_time:139393ms step_avg:97.82ms +step:1426/1670 train_time:139490ms step_avg:97.82ms +step:1427/1670 train_time:139588ms step_avg:97.82ms +step:1428/1670 train_time:139686ms step_avg:97.82ms +step:1429/1670 train_time:139784ms step_avg:97.82ms +step:1430/1670 train_time:139883ms step_avg:97.82ms +step:1431/1670 train_time:139980ms step_avg:97.82ms +step:1432/1670 train_time:140078ms step_avg:97.82ms +step:1433/1670 train_time:140175ms step_avg:97.82ms +step:1434/1670 train_time:140273ms step_avg:97.82ms +step:1435/1670 train_time:140370ms step_avg:97.82ms +step:1436/1670 train_time:140467ms step_avg:97.82ms +step:1437/1670 train_time:140565ms step_avg:97.82ms +step:1438/1670 train_time:140663ms step_avg:97.82ms +step:1439/1670 train_time:140761ms step_avg:97.82ms +step:1440/1670 train_time:140859ms step_avg:97.82ms +step:1441/1670 train_time:140956ms step_avg:97.82ms +step:1442/1670 train_time:141054ms step_avg:97.82ms 
+step:1443/1670 train_time:141152ms step_avg:97.82ms +step:1444/1670 train_time:141250ms step_avg:97.82ms +step:1445/1670 train_time:141347ms step_avg:97.82ms +step:1446/1670 train_time:141444ms step_avg:97.82ms +step:1447/1670 train_time:141542ms step_avg:97.82ms +step:1448/1670 train_time:141640ms step_avg:97.82ms +step:1449/1670 train_time:141737ms step_avg:97.82ms +step:1450/1670 train_time:141835ms step_avg:97.82ms +step:1451/1670 train_time:141932ms step_avg:97.82ms +step:1452/1670 train_time:142030ms step_avg:97.82ms +step:1453/1670 train_time:142128ms step_avg:97.82ms +step:1454/1670 train_time:142227ms step_avg:97.82ms +step:1455/1670 train_time:142325ms step_avg:97.82ms +step:1456/1670 train_time:142423ms step_avg:97.82ms +step:1457/1670 train_time:142521ms step_avg:97.82ms +step:1458/1670 train_time:142618ms step_avg:97.82ms +step:1459/1670 train_time:142716ms step_avg:97.82ms +step:1460/1670 train_time:142813ms step_avg:97.82ms +step:1461/1670 train_time:142910ms step_avg:97.82ms +step:1462/1670 train_time:143008ms step_avg:97.82ms +step:1463/1670 train_time:143106ms step_avg:97.82ms +step:1464/1670 train_time:143205ms step_avg:97.82ms +step:1465/1670 train_time:143304ms step_avg:97.82ms +step:1466/1670 train_time:143402ms step_avg:97.82ms +step:1467/1670 train_time:143500ms step_avg:97.82ms +step:1468/1670 train_time:143597ms step_avg:97.82ms +step:1469/1670 train_time:143695ms step_avg:97.82ms +step:1470/1670 train_time:143793ms step_avg:97.82ms +step:1471/1670 train_time:143890ms step_avg:97.82ms +step:1472/1670 train_time:143988ms step_avg:97.82ms +step:1473/1670 train_time:144086ms step_avg:97.82ms +step:1474/1670 train_time:144184ms step_avg:97.82ms +step:1475/1670 train_time:144282ms step_avg:97.82ms +step:1476/1670 train_time:144379ms step_avg:97.82ms +step:1477/1670 train_time:144476ms step_avg:97.82ms +step:1478/1670 train_time:144574ms step_avg:97.82ms +step:1479/1670 train_time:144671ms step_avg:97.82ms +step:1480/1670 train_time:144770ms step_avg:97.82ms +step:1481/1670 train_time:144867ms step_avg:97.82ms +step:1482/1670 train_time:144965ms step_avg:97.82ms +step:1483/1670 train_time:145063ms step_avg:97.82ms +step:1484/1670 train_time:145161ms step_avg:97.82ms +step:1485/1670 train_time:145431ms step_avg:97.93ms +step:1486/1670 train_time:145516ms step_avg:97.92ms +step:1487/1670 train_time:145612ms step_avg:97.92ms +step:1488/1670 train_time:145709ms step_avg:97.92ms +step:1489/1670 train_time:145806ms step_avg:97.92ms +step:1490/1670 train_time:145902ms step_avg:97.92ms +step:1491/1670 train_time:145998ms step_avg:97.92ms +step:1492/1670 train_time:146094ms step_avg:97.92ms +step:1493/1670 train_time:146191ms step_avg:97.92ms +step:1494/1670 train_time:146288ms step_avg:97.92ms +step:1495/1670 train_time:146392ms step_avg:97.92ms +step:1496/1670 train_time:146493ms step_avg:97.92ms +step:1497/1670 train_time:146592ms step_avg:97.92ms +step:1498/1670 train_time:146689ms step_avg:97.92ms +step:1499/1670 train_time:146787ms step_avg:97.92ms +step:1500/1670 train_time:146883ms step_avg:97.92ms +step:1500/1670 val_loss:3.3100 train_time:146979ms step_avg:97.99ms +step:1501/1670 train_time:146999ms step_avg:97.93ms +step:1502/1670 train_time:147083ms step_avg:97.92ms +step:1503/1670 train_time:147185ms step_avg:97.93ms +step:1504/1670 train_time:147282ms step_avg:97.93ms +step:1505/1670 train_time:147379ms step_avg:97.93ms +step:1506/1670 train_time:147476ms step_avg:97.93ms +step:1507/1670 train_time:147573ms step_avg:97.93ms +step:1508/1670 train_time:147671ms 
step_avg:97.92ms +step:1509/1670 train_time:147768ms step_avg:97.92ms +step:1510/1670 train_time:147865ms step_avg:97.92ms +step:1511/1670 train_time:147963ms step_avg:97.92ms +step:1512/1670 train_time:148063ms step_avg:97.93ms +step:1513/1670 train_time:148164ms step_avg:97.93ms +step:1514/1670 train_time:148262ms step_avg:97.93ms +step:1515/1670 train_time:148359ms step_avg:97.93ms +step:1516/1670 train_time:148456ms step_avg:97.93ms +step:1517/1670 train_time:148553ms step_avg:97.93ms +step:1518/1670 train_time:148650ms step_avg:97.92ms +step:1519/1670 train_time:148747ms step_avg:97.92ms +step:1520/1670 train_time:148845ms step_avg:97.92ms +step:1521/1670 train_time:148942ms step_avg:97.92ms +step:1522/1670 train_time:149041ms step_avg:97.92ms +step:1523/1670 train_time:149140ms step_avg:97.93ms +step:1524/1670 train_time:149239ms step_avg:97.93ms +step:1525/1670 train_time:149338ms step_avg:97.93ms +step:1526/1670 train_time:149435ms step_avg:97.93ms +step:1527/1670 train_time:149533ms step_avg:97.93ms +step:1528/1670 train_time:149631ms step_avg:97.93ms +step:1529/1670 train_time:149728ms step_avg:97.93ms +step:1530/1670 train_time:149825ms step_avg:97.92ms +step:1531/1670 train_time:149922ms step_avg:97.92ms +step:1532/1670 train_time:150021ms step_avg:97.92ms +step:1533/1670 train_time:150119ms step_avg:97.93ms +step:1534/1670 train_time:150218ms step_avg:97.93ms +step:1535/1670 train_time:150316ms step_avg:97.93ms +step:1536/1670 train_time:150413ms step_avg:97.93ms +step:1537/1670 train_time:150510ms step_avg:97.92ms +step:1538/1670 train_time:150607ms step_avg:97.92ms +step:1539/1670 train_time:150704ms step_avg:97.92ms +step:1540/1670 train_time:150801ms step_avg:97.92ms +step:1541/1670 train_time:150899ms step_avg:97.92ms +step:1542/1670 train_time:150998ms step_avg:97.92ms +step:1543/1670 train_time:151096ms step_avg:97.92ms +step:1544/1670 train_time:151195ms step_avg:97.92ms +step:1545/1670 train_time:151294ms step_avg:97.92ms +step:1546/1670 train_time:151392ms step_avg:97.92ms +step:1547/1670 train_time:151489ms step_avg:97.92ms +step:1548/1670 train_time:151587ms step_avg:97.92ms +step:1549/1670 train_time:151684ms step_avg:97.92ms +step:1550/1670 train_time:151781ms step_avg:97.92ms +step:1551/1670 train_time:151879ms step_avg:97.92ms +step:1552/1670 train_time:151975ms step_avg:97.92ms +step:1553/1670 train_time:152075ms step_avg:97.92ms +step:1554/1670 train_time:152174ms step_avg:97.92ms +step:1555/1670 train_time:152272ms step_avg:97.92ms +step:1556/1670 train_time:152370ms step_avg:97.92ms +step:1557/1670 train_time:152467ms step_avg:97.92ms +step:1558/1670 train_time:152564ms step_avg:97.92ms +step:1559/1670 train_time:152661ms step_avg:97.92ms +step:1560/1670 train_time:152758ms step_avg:97.92ms +step:1561/1670 train_time:152856ms step_avg:97.92ms +step:1562/1670 train_time:152955ms step_avg:97.92ms +step:1563/1670 train_time:153053ms step_avg:97.92ms +step:1564/1670 train_time:153150ms step_avg:97.92ms +step:1565/1670 train_time:153249ms step_avg:97.92ms +step:1566/1670 train_time:153348ms step_avg:97.92ms +step:1567/1670 train_time:153447ms step_avg:97.92ms +step:1568/1670 train_time:153544ms step_avg:97.92ms +step:1569/1670 train_time:153641ms step_avg:97.92ms +step:1570/1670 train_time:153738ms step_avg:97.92ms +step:1571/1670 train_time:153836ms step_avg:97.92ms +step:1572/1670 train_time:153934ms step_avg:97.92ms +step:1573/1670 train_time:154032ms step_avg:97.92ms +step:1574/1670 train_time:154130ms step_avg:97.92ms +step:1575/1670 train_time:154228ms 
step_avg:97.92ms +step:1576/1670 train_time:154326ms step_avg:97.92ms +step:1577/1670 train_time:154423ms step_avg:97.92ms +step:1578/1670 train_time:154521ms step_avg:97.92ms +step:1579/1670 train_time:154618ms step_avg:97.92ms +step:1580/1670 train_time:154715ms step_avg:97.92ms +step:1581/1670 train_time:154814ms step_avg:97.92ms +step:1582/1670 train_time:154912ms step_avg:97.92ms +step:1583/1670 train_time:155010ms step_avg:97.92ms +step:1584/1670 train_time:155108ms step_avg:97.92ms +step:1585/1670 train_time:155205ms step_avg:97.92ms +step:1586/1670 train_time:155303ms step_avg:97.92ms +step:1587/1670 train_time:155401ms step_avg:97.92ms +step:1588/1670 train_time:155498ms step_avg:97.92ms +step:1589/1670 train_time:155597ms step_avg:97.92ms +step:1590/1670 train_time:155696ms step_avg:97.92ms +step:1591/1670 train_time:155793ms step_avg:97.92ms +step:1592/1670 train_time:155891ms step_avg:97.92ms +step:1593/1670 train_time:155988ms step_avg:97.92ms +step:1594/1670 train_time:156085ms step_avg:97.92ms +step:1595/1670 train_time:156182ms step_avg:97.92ms +step:1596/1670 train_time:156280ms step_avg:97.92ms +step:1597/1670 train_time:156378ms step_avg:97.92ms +step:1598/1670 train_time:156477ms step_avg:97.92ms +step:1599/1670 train_time:156575ms step_avg:97.92ms +step:1600/1670 train_time:156673ms step_avg:97.92ms +step:1601/1670 train_time:156771ms step_avg:97.92ms +step:1602/1670 train_time:156868ms step_avg:97.92ms +step:1603/1670 train_time:156966ms step_avg:97.92ms +step:1604/1670 train_time:157062ms step_avg:97.92ms +step:1605/1670 train_time:157160ms step_avg:97.92ms +step:1606/1670 train_time:157258ms step_avg:97.92ms +step:1607/1670 train_time:157356ms step_avg:97.92ms +step:1608/1670 train_time:157454ms step_avg:97.92ms +step:1609/1670 train_time:157553ms step_avg:97.92ms +step:1610/1670 train_time:157650ms step_avg:97.92ms +step:1611/1670 train_time:157748ms step_avg:97.92ms +step:1612/1670 train_time:157846ms step_avg:97.92ms +step:1613/1670 train_time:157944ms step_avg:97.92ms +step:1614/1670 train_time:158041ms step_avg:97.92ms +step:1615/1670 train_time:158138ms step_avg:97.92ms +step:1616/1670 train_time:158236ms step_avg:97.92ms +step:1617/1670 train_time:158334ms step_avg:97.92ms +step:1618/1670 train_time:158432ms step_avg:97.92ms +step:1619/1670 train_time:158530ms step_avg:97.92ms +step:1620/1670 train_time:158627ms step_avg:97.92ms +step:1621/1670 train_time:158724ms step_avg:97.92ms +step:1622/1670 train_time:158822ms step_avg:97.92ms +step:1623/1670 train_time:158919ms step_avg:97.92ms +step:1624/1670 train_time:159017ms step_avg:97.92ms +step:1625/1670 train_time:159115ms step_avg:97.92ms +step:1625/1670 val_loss:3.2831 train_time:159213ms step_avg:97.98ms +step:1626/1670 train_time:159234ms step_avg:97.93ms +step:1627/1670 train_time:159317ms step_avg:97.92ms +step:1628/1670 train_time:159417ms step_avg:97.92ms +step:1629/1670 train_time:159515ms step_avg:97.92ms +step:1630/1670 train_time:159612ms step_avg:97.92ms +step:1631/1670 train_time:159709ms step_avg:97.92ms +step:1632/1670 train_time:159806ms step_avg:97.92ms +step:1633/1670 train_time:159902ms step_avg:97.92ms +step:1634/1670 train_time:160000ms step_avg:97.92ms +step:1635/1670 train_time:160097ms step_avg:97.92ms +step:1636/1670 train_time:160196ms step_avg:97.92ms +step:1637/1670 train_time:160296ms step_avg:97.92ms +step:1638/1670 train_time:160394ms step_avg:97.92ms +step:1639/1670 train_time:160493ms step_avg:97.92ms +step:1640/1670 train_time:160591ms step_avg:97.92ms +step:1641/1670 
train_time:160688ms step_avg:97.92ms +step:1642/1670 train_time:160784ms step_avg:97.92ms +step:1643/1670 train_time:160881ms step_avg:97.92ms +step:1644/1670 train_time:160979ms step_avg:97.92ms +step:1645/1670 train_time:161077ms step_avg:97.92ms +step:1646/1670 train_time:161174ms step_avg:97.92ms +step:1647/1670 train_time:161272ms step_avg:97.92ms +step:1648/1670 train_time:161371ms step_avg:97.92ms +step:1649/1670 train_time:161469ms step_avg:97.92ms +step:1650/1670 train_time:161567ms step_avg:97.92ms +step:1651/1670 train_time:161664ms step_avg:97.92ms +step:1652/1670 train_time:161762ms step_avg:97.92ms +step:1653/1670 train_time:161859ms step_avg:97.92ms +step:1654/1670 train_time:161957ms step_avg:97.92ms +step:1655/1670 train_time:162054ms step_avg:97.92ms +step:1656/1670 train_time:162151ms step_avg:97.92ms +step:1657/1670 train_time:162250ms step_avg:97.92ms +step:1658/1670 train_time:162348ms step_avg:97.92ms +step:1659/1670 train_time:162446ms step_avg:97.92ms +step:1660/1670 train_time:162545ms step_avg:97.92ms +step:1661/1670 train_time:162644ms step_avg:97.92ms +step:1662/1670 train_time:162742ms step_avg:97.92ms +step:1663/1670 train_time:162840ms step_avg:97.92ms +step:1664/1670 train_time:162937ms step_avg:97.92ms +step:1665/1670 train_time:163034ms step_avg:97.92ms +step:1666/1670 train_time:163132ms step_avg:97.92ms +step:1667/1670 train_time:163229ms step_avg:97.92ms +step:1668/1670 train_time:163328ms step_avg:97.92ms +step:1669/1670 train_time:163426ms step_avg:97.92ms +step:1670/1670 train_time:163524ms step_avg:97.92ms +step:1670/1670 val_loss:3.2755 train_time:163621ms step_avg:97.98ms +peak memory allocated: 34217 MiB reserved: 49936 MiB
diff --git a/records/090325_FA3/4c2f3422-1b2e-4b62-be78-f09cac5730b8.txt b/records/090325_FA3/4c2f3422-1b2e-4b62-be78-f09cac5730b8.txt
new file mode 100644
index 000000000..a11cc06d9
--- /dev/null
+++ b/records/090325_FA3/4c2f3422-1b2e-4b62-be78-f09cac5730b8.txt
@@ -0,0 +1,2814 @@
+import os
+import sys
+with open(sys.argv[0]) as f:
+    code = f.read() # read the code of this file ASAP, for logging
+import uuid
+import time
+import copy
+import glob
+from dataclasses import dataclass
+from functools import lru_cache
+from pathlib import Path
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+import torch
+torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
+from torch import Tensor, nn
+import torch.nn.functional as F
+import torch.distributed as dist
+#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
+import numpy as np
+import triton
+import triton.language as tl
+from flash_attn_interface import flash_attn_varlen_func
+import torch._dynamo as dynamo
+dynamo.config.recompile_limit = 64
+
+# -----------------------------------------------------------------------------
+# Custom operators: FP8 matmul by @YouJiacheng
+
+@torch.library.custom_op("nanogpt::mm", mutates_args=())
+def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
+    @torch.compile
+    def impl(x: Tensor, w: Tensor):
+        assert x.is_contiguous() and w.is_contiguous()
+        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
+        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
+        out = torch._scaled_mm(
+            x_f8,
+            w_f8.T,
+            out_dtype=torch.bfloat16,
+            scale_a=x.new_tensor(x_s, dtype=torch.float32),
+            scale_b=x.new_tensor(w_s, dtype=torch.float32),
+            use_fast_accum=True,
+        )
+        return out, x_f8, w_f8
+
+    return impl(x, w)
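
A minimal usage sketch of the nanogpt::mm op defined above, assuming the definitions in this record have been executed on an FP8-capable GPU (e.g. an H100); the shapes and unit scales below are illustrative, not the ones the training script uses.

# Hypothetical standalone call to the custom FP8 matmul op.
import torch

x = torch.randn(128, 768, device="cuda", dtype=torch.bfloat16)    # activations
w = torch.randn(50304, 768, device="cuda", dtype=torch.bfloat16)  # weight, (d_out, d_in)

# Forward: both operands are quantized to float8_e4m3fn and multiplied with
# torch._scaled_mm; the quantized copies are returned for use in the backward pass.
out, x_f8, w_f8 = torch.ops.nanogpt.mm(x, w, 1.0, 1.0, 1.0)

ref = x @ w.T  # plain bf16 reference
print((out.float() - ref.float()).abs().mean())  # small nonzero FP8 round-trip error
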
x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: 
tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, 
GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = 
torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
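+ # In brief: within each same-shape parameter group, parameters are handled in chunks of world_size, with rank r owning parameter base_i + r. + # An async reduce_scatter hands each rank the cross-rank average of its owned parameter's gradient (zero gradients pad out short chunks). + # The owning rank applies weight decay, momentum, and the Newton-Schulz orthogonalization locally, and an async all_gather then + # redistributes the updated parameters so every rank finishes the step with identical weights.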
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS tokens ahead of index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world_size * grad_accum_steps" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400, so this leaves headroom + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # the last document got one extra token for the _targets offset; trim it back + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "Num tokens must be divisible by world_size * grad_accum_steps" + num_tokens = new_num_tokens // new_grad_accum_steps + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data?
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 19:39:51 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 39C P0 129W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 38C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 38C P0 123W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 48531 C /usr/bin/python 0MiB | +| 0 N/A N/A 48532 C /usr/bin/python 0MiB | +| 0 N/A N/A 48533 C /usr/bin/python 0MiB | +| 0 N/A N/A 48534 C /usr/bin/python 0MiB | +| 0 N/A N/A 48535 C /usr/bin/python 0MiB | +| 0 N/A N/A 48536 C /usr/bin/python 0MiB | +| 0 N/A N/A 48537 C /usr/bin/python 0MiB | +| 0 N/A N/A 48538 C /usr/bin/python 0MiB | +| 1 N/A N/A 48532 C /usr/bin/python 0MiB | +| 2 N/A N/A 48533 C /usr/bin/python 0MiB | +| 3 N/A N/A 48534 C /usr/bin/python 0MiB | +| 4 N/A N/A 48535 C /usr/bin/python 0MiB | +| 5 N/A N/A 48536 C /usr/bin/python 0MiB | +| 6 N/A N/A 48537 C /usr/bin/python 0MiB | +| 7 N/A N/A 48538 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:407ms step_avg:407.50ms +step:2/1670 train_time:427ms step_avg:213.68ms +step:3/1670 train_time:501ms step_avg:167.04ms +step:4/1670 train_time:594ms step_avg:148.59ms +step:5/1670 train_time:689ms step_avg:137.71ms +step:6/1670 train_time:784ms step_avg:130.59ms +step:7/1670 train_time:878ms step_avg:125.44ms +step:8/1670 train_time:973ms step_avg:121.64ms +step:9/1670 train_time:1068ms 
step_avg:118.69ms +step:10/1670 train_time:1163ms step_avg:116.33ms +step:11/1670 train_time:1258ms step_avg:114.39ms +step:12/1670 train_time:1355ms step_avg:112.93ms +step:13/1670 train_time:1454ms step_avg:111.86ms +step:14/1670 train_time:1551ms step_avg:110.78ms +step:15/1670 train_time:1648ms step_avg:109.84ms +step:16/1670 train_time:1742ms step_avg:108.90ms +step:17/1670 train_time:1838ms step_avg:108.10ms +step:18/1670 train_time:1933ms step_avg:107.37ms +step:19/1670 train_time:2028ms step_avg:106.73ms +step:20/1670 train_time:2123ms step_avg:106.16ms +step:21/1670 train_time:2218ms step_avg:105.64ms +step:22/1670 train_time:2314ms step_avg:105.20ms +step:23/1670 train_time:2412ms step_avg:104.86ms +step:24/1670 train_time:2508ms step_avg:104.52ms +step:25/1670 train_time:2606ms step_avg:104.22ms +step:26/1670 train_time:2702ms step_avg:103.94ms +step:27/1670 train_time:2799ms step_avg:103.66ms +step:28/1670 train_time:2894ms step_avg:103.34ms +step:29/1670 train_time:2989ms step_avg:103.06ms +step:30/1670 train_time:3084ms step_avg:102.79ms +step:31/1670 train_time:3180ms step_avg:102.59ms +step:32/1670 train_time:3276ms step_avg:102.37ms +step:33/1670 train_time:3372ms step_avg:102.18ms +step:34/1670 train_time:3469ms step_avg:102.02ms +step:35/1670 train_time:3566ms step_avg:101.87ms +step:36/1670 train_time:3662ms step_avg:101.73ms +step:37/1670 train_time:3758ms step_avg:101.58ms +step:38/1670 train_time:3854ms step_avg:101.41ms +step:39/1670 train_time:3950ms step_avg:101.28ms +step:40/1670 train_time:4046ms step_avg:101.14ms +step:41/1670 train_time:4141ms step_avg:101.00ms +step:42/1670 train_time:4237ms step_avg:100.87ms +step:43/1670 train_time:4332ms step_avg:100.74ms +step:44/1670 train_time:4428ms step_avg:100.64ms +step:45/1670 train_time:4525ms step_avg:100.55ms +step:46/1670 train_time:4620ms step_avg:100.44ms +step:47/1670 train_time:4717ms step_avg:100.35ms +step:48/1670 train_time:4812ms step_avg:100.25ms +step:49/1670 train_time:4908ms step_avg:100.17ms +step:50/1670 train_time:5004ms step_avg:100.09ms +step:51/1670 train_time:5100ms step_avg:100.00ms +step:52/1670 train_time:5195ms step_avg:99.90ms +step:53/1670 train_time:5291ms step_avg:99.83ms +step:54/1670 train_time:5387ms step_avg:99.77ms +step:55/1670 train_time:5483ms step_avg:99.69ms +step:56/1670 train_time:5579ms step_avg:99.62ms +step:57/1670 train_time:5674ms step_avg:99.55ms +step:58/1670 train_time:5770ms step_avg:99.47ms +step:59/1670 train_time:5865ms step_avg:99.41ms +step:60/1670 train_time:5962ms step_avg:99.36ms +step:61/1670 train_time:6057ms step_avg:99.30ms +step:62/1670 train_time:6154ms step_avg:99.25ms +step:63/1670 train_time:6249ms step_avg:99.20ms +step:64/1670 train_time:6345ms step_avg:99.14ms +step:65/1670 train_time:6442ms step_avg:99.10ms +step:66/1670 train_time:6538ms step_avg:99.06ms +step:67/1670 train_time:6634ms step_avg:99.01ms +step:68/1670 train_time:6730ms step_avg:98.97ms +step:69/1670 train_time:6826ms step_avg:98.92ms +step:70/1670 train_time:6922ms step_avg:98.88ms +step:71/1670 train_time:7018ms step_avg:98.84ms +step:72/1670 train_time:7113ms step_avg:98.79ms +step:73/1670 train_time:7209ms step_avg:98.75ms +step:74/1670 train_time:7304ms step_avg:98.70ms +step:75/1670 train_time:7400ms step_avg:98.67ms +step:76/1670 train_time:7495ms step_avg:98.62ms +step:77/1670 train_time:7592ms step_avg:98.60ms +step:78/1670 train_time:7688ms step_avg:98.57ms +step:79/1670 train_time:7784ms step_avg:98.53ms +step:80/1670 train_time:7879ms step_avg:98.49ms +step:81/1670 
train_time:7975ms step_avg:98.46ms +step:82/1670 train_time:8071ms step_avg:98.43ms +step:83/1670 train_time:8167ms step_avg:98.40ms +step:84/1670 train_time:8263ms step_avg:98.37ms +step:85/1670 train_time:8358ms step_avg:98.33ms +step:86/1670 train_time:8454ms step_avg:98.30ms +step:87/1670 train_time:8549ms step_avg:98.27ms +step:88/1670 train_time:8645ms step_avg:98.24ms +step:89/1670 train_time:8741ms step_avg:98.21ms +step:90/1670 train_time:8836ms step_avg:98.18ms +step:91/1670 train_time:8932ms step_avg:98.15ms +step:92/1670 train_time:9028ms step_avg:98.13ms +step:93/1670 train_time:9125ms step_avg:98.11ms +step:94/1670 train_time:9220ms step_avg:98.08ms +step:95/1670 train_time:9315ms step_avg:98.06ms +step:96/1670 train_time:9411ms step_avg:98.03ms +step:97/1670 train_time:9507ms step_avg:98.01ms +step:98/1670 train_time:9603ms step_avg:97.99ms +step:99/1670 train_time:9699ms step_avg:97.97ms +step:100/1670 train_time:9795ms step_avg:97.95ms +step:101/1670 train_time:9892ms step_avg:97.94ms +step:102/1670 train_time:9987ms step_avg:97.91ms +step:103/1670 train_time:10082ms step_avg:97.88ms +step:104/1670 train_time:10177ms step_avg:97.86ms +step:105/1670 train_time:10273ms step_avg:97.84ms +step:106/1670 train_time:10369ms step_avg:97.83ms +step:107/1670 train_time:10465ms step_avg:97.81ms +step:108/1670 train_time:10561ms step_avg:97.79ms +step:109/1670 train_time:10656ms step_avg:97.76ms +step:110/1670 train_time:10752ms step_avg:97.75ms +step:111/1670 train_time:10848ms step_avg:97.73ms +step:112/1670 train_time:10944ms step_avg:97.72ms +step:113/1670 train_time:11039ms step_avg:97.69ms +step:114/1670 train_time:11134ms step_avg:97.67ms +step:115/1670 train_time:11230ms step_avg:97.65ms +step:116/1670 train_time:11326ms step_avg:97.64ms +step:117/1670 train_time:11422ms step_avg:97.62ms +step:118/1670 train_time:11518ms step_avg:97.61ms +step:119/1670 train_time:11614ms step_avg:97.59ms +step:120/1670 train_time:11709ms step_avg:97.58ms +step:121/1670 train_time:11805ms step_avg:97.57ms +step:122/1670 train_time:11901ms step_avg:97.55ms +step:123/1670 train_time:11997ms step_avg:97.53ms +step:124/1670 train_time:12092ms step_avg:97.52ms +step:125/1670 train_time:12189ms step_avg:97.51ms +step:125/1670 val_loss:4.2887 train_time:12284ms step_avg:98.27ms +step:126/1670 train_time:12306ms step_avg:97.67ms +step:127/1670 train_time:12388ms step_avg:97.54ms +step:128/1670 train_time:12494ms step_avg:97.61ms +step:129/1670 train_time:12591ms step_avg:97.61ms +step:130/1670 train_time:12686ms step_avg:97.59ms +step:131/1670 train_time:12781ms step_avg:97.56ms +step:132/1670 train_time:12876ms step_avg:97.55ms +step:133/1670 train_time:12971ms step_avg:97.53ms +step:134/1670 train_time:13066ms step_avg:97.51ms +step:135/1670 train_time:13160ms step_avg:97.48ms +step:136/1670 train_time:13255ms step_avg:97.46ms +step:137/1670 train_time:13353ms step_avg:97.47ms +step:138/1670 train_time:13452ms step_avg:97.47ms +step:139/1670 train_time:13550ms step_avg:97.48ms +step:140/1670 train_time:13646ms step_avg:97.47ms +step:141/1670 train_time:13742ms step_avg:97.46ms +step:142/1670 train_time:13838ms step_avg:97.45ms +step:143/1670 train_time:13933ms step_avg:97.43ms +step:144/1670 train_time:14028ms step_avg:97.42ms +step:145/1670 train_time:14123ms step_avg:97.40ms +step:146/1670 train_time:14218ms step_avg:97.38ms +step:147/1670 train_time:14314ms step_avg:97.37ms +step:148/1670 train_time:14410ms step_avg:97.37ms +step:149/1670 train_time:14507ms step_avg:97.36ms +step:150/1670 
train_time:14603ms step_avg:97.35ms +step:151/1670 train_time:14698ms step_avg:97.34ms +step:152/1670 train_time:14794ms step_avg:97.33ms +step:153/1670 train_time:14890ms step_avg:97.32ms +step:154/1670 train_time:14985ms step_avg:97.31ms +step:155/1670 train_time:15080ms step_avg:97.29ms +step:156/1670 train_time:15175ms step_avg:97.28ms +step:157/1670 train_time:15271ms step_avg:97.27ms +step:158/1670 train_time:15367ms step_avg:97.26ms +step:159/1670 train_time:15463ms step_avg:97.25ms +step:160/1670 train_time:15559ms step_avg:97.25ms +step:161/1670 train_time:15656ms step_avg:97.24ms +step:162/1670 train_time:15753ms step_avg:97.24ms +step:163/1670 train_time:15849ms step_avg:97.23ms +step:164/1670 train_time:15945ms step_avg:97.23ms +step:165/1670 train_time:16040ms step_avg:97.21ms +step:166/1670 train_time:16136ms step_avg:97.20ms +step:167/1670 train_time:16232ms step_avg:97.19ms +step:168/1670 train_time:16327ms step_avg:97.18ms +step:169/1670 train_time:16422ms step_avg:97.17ms +step:170/1670 train_time:16518ms step_avg:97.17ms +step:171/1670 train_time:16615ms step_avg:97.16ms +step:172/1670 train_time:16711ms step_avg:97.16ms +step:173/1670 train_time:16807ms step_avg:97.15ms +step:174/1670 train_time:16902ms step_avg:97.14ms +step:175/1670 train_time:16998ms step_avg:97.13ms +step:176/1670 train_time:17093ms step_avg:97.12ms +step:177/1670 train_time:17189ms step_avg:97.11ms +step:178/1670 train_time:17284ms step_avg:97.10ms +step:179/1670 train_time:17379ms step_avg:97.09ms +step:180/1670 train_time:17476ms step_avg:97.09ms +step:181/1670 train_time:17573ms step_avg:97.09ms +step:182/1670 train_time:17669ms step_avg:97.08ms +step:183/1670 train_time:17764ms step_avg:97.07ms +step:184/1670 train_time:17859ms step_avg:97.06ms +step:185/1670 train_time:17955ms step_avg:97.05ms +step:186/1670 train_time:18050ms step_avg:97.04ms +step:187/1670 train_time:18146ms step_avg:97.04ms +step:188/1670 train_time:18241ms step_avg:97.03ms +step:189/1670 train_time:18337ms step_avg:97.02ms +step:190/1670 train_time:18432ms step_avg:97.01ms +step:191/1670 train_time:18528ms step_avg:97.00ms +step:192/1670 train_time:18623ms step_avg:96.99ms +step:193/1670 train_time:18719ms step_avg:96.99ms +step:194/1670 train_time:18816ms step_avg:96.99ms +step:195/1670 train_time:18911ms step_avg:96.98ms +step:196/1670 train_time:19006ms step_avg:96.97ms +step:197/1670 train_time:19101ms step_avg:96.96ms +step:198/1670 train_time:19197ms step_avg:96.96ms +step:199/1670 train_time:19293ms step_avg:96.95ms +step:200/1670 train_time:19388ms step_avg:96.94ms +step:201/1670 train_time:19483ms step_avg:96.93ms +step:202/1670 train_time:19579ms step_avg:96.93ms +step:203/1670 train_time:19675ms step_avg:96.92ms +step:204/1670 train_time:19771ms step_avg:96.92ms +step:205/1670 train_time:19866ms step_avg:96.91ms +step:206/1670 train_time:19962ms step_avg:96.90ms +step:207/1670 train_time:20057ms step_avg:96.89ms +step:208/1670 train_time:20153ms step_avg:96.89ms +step:209/1670 train_time:20249ms step_avg:96.89ms +step:210/1670 train_time:20345ms step_avg:96.88ms +step:211/1670 train_time:20440ms step_avg:96.87ms +step:212/1670 train_time:20536ms step_avg:96.87ms +step:213/1670 train_time:20868ms step_avg:97.97ms +step:214/1670 train_time:20942ms step_avg:97.86ms +step:215/1670 train_time:21036ms step_avg:97.84ms +step:216/1670 train_time:21132ms step_avg:97.83ms +step:217/1670 train_time:21226ms step_avg:97.81ms +step:218/1670 train_time:21320ms step_avg:97.80ms +step:219/1670 train_time:21415ms step_avg:97.78ms 
+step:220/1670 train_time:21509ms step_avg:97.77ms +step:221/1670 train_time:21604ms step_avg:97.75ms +step:222/1670 train_time:21698ms step_avg:97.74ms +step:223/1670 train_time:21794ms step_avg:97.73ms +step:224/1670 train_time:21897ms step_avg:97.75ms +step:225/1670 train_time:21995ms step_avg:97.75ms +step:226/1670 train_time:22091ms step_avg:97.75ms +step:227/1670 train_time:22186ms step_avg:97.74ms +step:228/1670 train_time:22281ms step_avg:97.72ms +step:229/1670 train_time:22377ms step_avg:97.71ms +step:230/1670 train_time:22471ms step_avg:97.70ms +step:231/1670 train_time:22566ms step_avg:97.69ms +step:232/1670 train_time:22660ms step_avg:97.67ms +step:233/1670 train_time:22755ms step_avg:97.66ms +step:234/1670 train_time:22852ms step_avg:97.66ms +step:235/1670 train_time:22950ms step_avg:97.66ms +step:236/1670 train_time:23046ms step_avg:97.65ms +step:237/1670 train_time:23141ms step_avg:97.64ms +step:238/1670 train_time:23236ms step_avg:97.63ms +step:239/1670 train_time:23332ms step_avg:97.62ms +step:240/1670 train_time:23428ms step_avg:97.62ms +step:241/1670 train_time:23523ms step_avg:97.61ms +step:242/1670 train_time:23617ms step_avg:97.59ms +step:243/1670 train_time:23713ms step_avg:97.58ms +step:244/1670 train_time:23809ms step_avg:97.58ms +step:245/1670 train_time:23905ms step_avg:97.57ms +step:246/1670 train_time:24000ms step_avg:97.56ms +step:247/1670 train_time:24097ms step_avg:97.56ms +step:248/1670 train_time:24194ms step_avg:97.55ms +step:249/1670 train_time:24289ms step_avg:97.55ms +step:250/1670 train_time:24384ms step_avg:97.54ms +step:250/1670 val_loss:3.9604 train_time:24478ms step_avg:97.91ms +step:251/1670 train_time:24499ms step_avg:97.61ms +step:252/1670 train_time:24580ms step_avg:97.54ms +step:253/1670 train_time:24678ms step_avg:97.54ms +step:254/1670 train_time:24773ms step_avg:97.53ms +step:255/1670 train_time:24868ms step_avg:97.52ms +step:256/1670 train_time:24963ms step_avg:97.51ms +step:257/1670 train_time:25058ms step_avg:97.50ms +step:258/1670 train_time:25152ms step_avg:97.49ms +step:259/1670 train_time:25248ms step_avg:97.48ms +step:260/1670 train_time:25342ms step_avg:97.47ms +step:261/1670 train_time:25438ms step_avg:97.46ms +step:262/1670 train_time:25535ms step_avg:97.46ms +step:263/1670 train_time:25633ms step_avg:97.46ms +step:264/1670 train_time:25730ms step_avg:97.46ms +step:265/1670 train_time:25826ms step_avg:97.46ms +step:266/1670 train_time:25920ms step_avg:97.45ms +step:267/1670 train_time:26015ms step_avg:97.43ms +step:268/1670 train_time:26110ms step_avg:97.42ms +step:269/1670 train_time:26204ms step_avg:97.41ms +step:270/1670 train_time:26299ms step_avg:97.40ms +step:271/1670 train_time:26394ms step_avg:97.40ms +step:272/1670 train_time:26491ms step_avg:97.39ms +step:273/1670 train_time:26589ms step_avg:97.39ms +step:274/1670 train_time:26685ms step_avg:97.39ms +step:275/1670 train_time:26781ms step_avg:97.39ms +step:276/1670 train_time:26876ms step_avg:97.38ms +step:277/1670 train_time:26971ms step_avg:97.37ms +step:278/1670 train_time:27067ms step_avg:97.36ms +step:279/1670 train_time:27163ms step_avg:97.36ms +step:280/1670 train_time:27257ms step_avg:97.35ms +step:281/1670 train_time:27352ms step_avg:97.34ms +step:282/1670 train_time:27449ms step_avg:97.34ms +step:283/1670 train_time:27545ms step_avg:97.33ms +step:284/1670 train_time:27641ms step_avg:97.33ms +step:285/1670 train_time:27737ms step_avg:97.32ms +step:286/1670 train_time:27832ms step_avg:97.32ms +step:287/1670 train_time:27929ms step_avg:97.31ms +step:288/1670 
train_time:28024ms step_avg:97.30ms +step:289/1670 train_time:28119ms step_avg:97.30ms +step:290/1670 train_time:28213ms step_avg:97.29ms +step:291/1670 train_time:28308ms step_avg:97.28ms +step:292/1670 train_time:28404ms step_avg:97.27ms +step:293/1670 train_time:28500ms step_avg:97.27ms +step:294/1670 train_time:28596ms step_avg:97.26ms +step:295/1670 train_time:28692ms step_avg:97.26ms +step:296/1670 train_time:28788ms step_avg:97.26ms +step:297/1670 train_time:28884ms step_avg:97.25ms +step:298/1670 train_time:28979ms step_avg:97.25ms +step:299/1670 train_time:29075ms step_avg:97.24ms +step:300/1670 train_time:29170ms step_avg:97.23ms +step:301/1670 train_time:29265ms step_avg:97.22ms +step:302/1670 train_time:29361ms step_avg:97.22ms +step:303/1670 train_time:29456ms step_avg:97.22ms +step:304/1670 train_time:29552ms step_avg:97.21ms +step:305/1670 train_time:29648ms step_avg:97.21ms +step:306/1670 train_time:29744ms step_avg:97.20ms +step:307/1670 train_time:29839ms step_avg:97.20ms +step:308/1670 train_time:29935ms step_avg:97.19ms +step:309/1670 train_time:30031ms step_avg:97.19ms +step:310/1670 train_time:30126ms step_avg:97.18ms +step:311/1670 train_time:30222ms step_avg:97.18ms +step:312/1670 train_time:30317ms step_avg:97.17ms +step:313/1670 train_time:30412ms step_avg:97.16ms +step:314/1670 train_time:30507ms step_avg:97.16ms +step:315/1670 train_time:30603ms step_avg:97.15ms +step:316/1670 train_time:30699ms step_avg:97.15ms +step:317/1670 train_time:30794ms step_avg:97.14ms +step:318/1670 train_time:30889ms step_avg:97.14ms +step:319/1670 train_time:30986ms step_avg:97.13ms +step:320/1670 train_time:31082ms step_avg:97.13ms +step:321/1670 train_time:31177ms step_avg:97.13ms +step:322/1670 train_time:31273ms step_avg:97.12ms +step:323/1670 train_time:31367ms step_avg:97.11ms +step:324/1670 train_time:31463ms step_avg:97.11ms +step:325/1670 train_time:31558ms step_avg:97.10ms +step:326/1670 train_time:31653ms step_avg:97.10ms +step:327/1670 train_time:31750ms step_avg:97.09ms +step:328/1670 train_time:31847ms step_avg:97.09ms +step:329/1670 train_time:31943ms step_avg:97.09ms +step:330/1670 train_time:32039ms step_avg:97.09ms +step:331/1670 train_time:32134ms step_avg:97.08ms +step:332/1670 train_time:32230ms step_avg:97.08ms +step:333/1670 train_time:32326ms step_avg:97.08ms +step:334/1670 train_time:32421ms step_avg:97.07ms +step:335/1670 train_time:32516ms step_avg:97.06ms +step:336/1670 train_time:32611ms step_avg:97.06ms +step:337/1670 train_time:32707ms step_avg:97.05ms +step:338/1670 train_time:32803ms step_avg:97.05ms +step:339/1670 train_time:32899ms step_avg:97.05ms +step:340/1670 train_time:32994ms step_avg:97.04ms +step:341/1670 train_time:33089ms step_avg:97.04ms +step:342/1670 train_time:33185ms step_avg:97.03ms +step:343/1670 train_time:33281ms step_avg:97.03ms +step:344/1670 train_time:33376ms step_avg:97.02ms +step:345/1670 train_time:33471ms step_avg:97.02ms +step:346/1670 train_time:33567ms step_avg:97.01ms +step:347/1670 train_time:33663ms step_avg:97.01ms +step:348/1670 train_time:33758ms step_avg:97.00ms +step:349/1670 train_time:33854ms step_avg:97.00ms +step:350/1670 train_time:33950ms step_avg:97.00ms +step:351/1670 train_time:34045ms step_avg:96.99ms +step:352/1670 train_time:34141ms step_avg:96.99ms +step:353/1670 train_time:34236ms step_avg:96.99ms +step:354/1670 train_time:34331ms step_avg:96.98ms +step:355/1670 train_time:34427ms step_avg:96.98ms +step:356/1670 train_time:34523ms step_avg:96.97ms +step:357/1670 train_time:34618ms step_avg:96.97ms 
+step:358/1670 train_time:34713ms step_avg:96.96ms +step:359/1670 train_time:34809ms step_avg:96.96ms +step:360/1670 train_time:34905ms step_avg:96.96ms +step:361/1670 train_time:35001ms step_avg:96.95ms +step:362/1670 train_time:35096ms step_avg:96.95ms +step:363/1670 train_time:35192ms step_avg:96.95ms +step:364/1670 train_time:35288ms step_avg:96.94ms +step:365/1670 train_time:35384ms step_avg:96.94ms +step:366/1670 train_time:35480ms step_avg:96.94ms +step:367/1670 train_time:35575ms step_avg:96.93ms +step:368/1670 train_time:35671ms step_avg:96.93ms +step:369/1670 train_time:35766ms step_avg:96.93ms +step:370/1670 train_time:35862ms step_avg:96.92ms +step:371/1670 train_time:35957ms step_avg:96.92ms +step:372/1670 train_time:36052ms step_avg:96.91ms +step:373/1670 train_time:36149ms step_avg:96.91ms +step:374/1670 train_time:36245ms step_avg:96.91ms +step:375/1670 train_time:36341ms step_avg:96.91ms +step:375/1670 val_loss:3.8122 train_time:36435ms step_avg:97.16ms +step:376/1670 train_time:36456ms step_avg:96.96ms +step:377/1670 train_time:36540ms step_avg:96.92ms +step:378/1670 train_time:36639ms step_avg:96.93ms +step:379/1670 train_time:36736ms step_avg:96.93ms +step:380/1670 train_time:36831ms step_avg:96.92ms +step:381/1670 train_time:36926ms step_avg:96.92ms +step:382/1670 train_time:37021ms step_avg:96.91ms +step:383/1670 train_time:37116ms step_avg:96.91ms +step:384/1670 train_time:37210ms step_avg:96.90ms +step:385/1670 train_time:37304ms step_avg:96.89ms +step:386/1670 train_time:37400ms step_avg:96.89ms +step:387/1670 train_time:37497ms step_avg:96.89ms +step:388/1670 train_time:37595ms step_avg:96.89ms +step:389/1670 train_time:37691ms step_avg:96.89ms +step:390/1670 train_time:37786ms step_avg:96.89ms +step:391/1670 train_time:37882ms step_avg:96.89ms +step:392/1670 train_time:37978ms step_avg:96.88ms +step:393/1670 train_time:38073ms step_avg:96.88ms +step:394/1670 train_time:38167ms step_avg:96.87ms +step:395/1670 train_time:38262ms step_avg:96.87ms +step:396/1670 train_time:38357ms step_avg:96.86ms +step:397/1670 train_time:38453ms step_avg:96.86ms +step:398/1670 train_time:38549ms step_avg:96.86ms +step:399/1670 train_time:38646ms step_avg:96.86ms +step:400/1670 train_time:38743ms step_avg:96.86ms +step:401/1670 train_time:38840ms step_avg:96.86ms +step:402/1670 train_time:38936ms step_avg:96.85ms +step:403/1670 train_time:39031ms step_avg:96.85ms +step:404/1670 train_time:39126ms step_avg:96.85ms +step:405/1670 train_time:39221ms step_avg:96.84ms +step:406/1670 train_time:39317ms step_avg:96.84ms +step:407/1670 train_time:39412ms step_avg:96.84ms +step:408/1670 train_time:39508ms step_avg:96.83ms +step:409/1670 train_time:39605ms step_avg:96.83ms +step:410/1670 train_time:39701ms step_avg:96.83ms +step:411/1670 train_time:39797ms step_avg:96.83ms +step:412/1670 train_time:39893ms step_avg:96.83ms +step:413/1670 train_time:39988ms step_avg:96.82ms +step:414/1670 train_time:40083ms step_avg:96.82ms +step:415/1670 train_time:40178ms step_avg:96.82ms +step:416/1670 train_time:40274ms step_avg:96.81ms +step:417/1670 train_time:40370ms step_avg:96.81ms +step:418/1670 train_time:40465ms step_avg:96.81ms +step:419/1670 train_time:40561ms step_avg:96.80ms +step:420/1670 train_time:40658ms step_avg:96.80ms +step:421/1670 train_time:40753ms step_avg:96.80ms +step:422/1670 train_time:40849ms step_avg:96.80ms +step:423/1670 train_time:40945ms step_avg:96.80ms +step:424/1670 train_time:41041ms step_avg:96.80ms +step:425/1670 train_time:41337ms step_avg:97.26ms +step:426/1670 
train_time:41450ms step_avg:97.30ms +step:427/1670 train_time:41543ms step_avg:97.29ms +step:428/1670 train_time:41638ms step_avg:97.29ms +step:429/1670 train_time:41733ms step_avg:97.28ms +step:430/1670 train_time:41827ms step_avg:97.27ms +step:431/1670 train_time:41922ms step_avg:97.27ms +step:432/1670 train_time:42017ms step_avg:97.26ms +step:433/1670 train_time:42112ms step_avg:97.26ms +step:434/1670 train_time:42206ms step_avg:97.25ms +step:435/1670 train_time:42303ms step_avg:97.25ms +step:436/1670 train_time:42404ms step_avg:97.26ms +step:437/1670 train_time:42504ms step_avg:97.26ms +step:438/1670 train_time:42600ms step_avg:97.26ms +step:439/1670 train_time:42695ms step_avg:97.25ms +step:440/1670 train_time:42789ms step_avg:97.25ms +step:441/1670 train_time:42884ms step_avg:97.24ms +step:442/1670 train_time:42979ms step_avg:97.24ms +step:443/1670 train_time:43073ms step_avg:97.23ms +step:444/1670 train_time:43168ms step_avg:97.22ms +step:445/1670 train_time:43264ms step_avg:97.22ms +step:446/1670 train_time:43362ms step_avg:97.22ms +step:447/1670 train_time:43460ms step_avg:97.23ms +step:448/1670 train_time:43557ms step_avg:97.23ms +step:449/1670 train_time:43653ms step_avg:97.22ms +step:450/1670 train_time:43748ms step_avg:97.22ms +step:451/1670 train_time:43844ms step_avg:97.22ms +step:452/1670 train_time:43940ms step_avg:97.21ms +step:453/1670 train_time:44034ms step_avg:97.21ms +step:454/1670 train_time:44129ms step_avg:97.20ms +step:455/1670 train_time:44224ms step_avg:97.19ms +step:456/1670 train_time:44320ms step_avg:97.19ms +step:457/1670 train_time:44417ms step_avg:97.19ms +step:458/1670 train_time:44513ms step_avg:97.19ms +step:459/1670 train_time:44609ms step_avg:97.19ms +step:460/1670 train_time:44704ms step_avg:97.18ms +step:461/1670 train_time:44800ms step_avg:97.18ms +step:462/1670 train_time:44895ms step_avg:97.18ms +step:463/1670 train_time:44990ms step_avg:97.17ms +step:464/1670 train_time:45084ms step_avg:97.16ms +step:465/1670 train_time:45179ms step_avg:97.16ms +step:466/1670 train_time:45274ms step_avg:97.16ms +step:467/1670 train_time:45370ms step_avg:97.15ms +step:468/1670 train_time:45466ms step_avg:97.15ms +step:469/1670 train_time:45563ms step_avg:97.15ms +step:470/1670 train_time:45660ms step_avg:97.15ms +step:471/1670 train_time:45756ms step_avg:97.15ms +step:472/1670 train_time:45851ms step_avg:97.14ms +step:473/1670 train_time:45946ms step_avg:97.14ms +step:474/1670 train_time:46041ms step_avg:97.13ms +step:475/1670 train_time:46136ms step_avg:97.13ms +step:476/1670 train_time:46231ms step_avg:97.12ms +step:477/1670 train_time:46326ms step_avg:97.12ms +step:478/1670 train_time:46422ms step_avg:97.12ms +step:479/1670 train_time:46519ms step_avg:97.12ms +step:480/1670 train_time:46615ms step_avg:97.11ms +step:481/1670 train_time:46710ms step_avg:97.11ms +step:482/1670 train_time:46806ms step_avg:97.11ms +step:483/1670 train_time:46901ms step_avg:97.10ms +step:484/1670 train_time:46997ms step_avg:97.10ms +step:485/1670 train_time:47092ms step_avg:97.10ms +step:486/1670 train_time:47186ms step_avg:97.09ms +step:487/1670 train_time:47282ms step_avg:97.09ms +step:488/1670 train_time:47378ms step_avg:97.09ms +step:489/1670 train_time:47473ms step_avg:97.08ms +step:490/1670 train_time:47568ms step_avg:97.08ms +step:491/1670 train_time:47665ms step_avg:97.08ms +step:492/1670 train_time:47761ms step_avg:97.07ms +step:493/1670 train_time:47856ms step_avg:97.07ms +step:494/1670 train_time:47952ms step_avg:97.07ms +step:495/1670 train_time:48046ms step_avg:97.06ms 
+step:496/1670 train_time:48142ms step_avg:97.06ms +step:497/1670 train_time:48238ms step_avg:97.06ms +step:498/1670 train_time:48334ms step_avg:97.06ms +step:499/1670 train_time:48429ms step_avg:97.05ms +step:500/1670 train_time:48524ms step_avg:97.05ms +step:500/1670 val_loss:3.7107 train_time:48620ms step_avg:97.24ms +step:501/1670 train_time:48640ms step_avg:97.09ms +step:502/1670 train_time:48722ms step_avg:97.06ms +step:503/1670 train_time:48823ms step_avg:97.06ms +step:504/1670 train_time:48919ms step_avg:97.06ms +step:505/1670 train_time:49014ms step_avg:97.06ms +step:506/1670 train_time:49109ms step_avg:97.05ms +step:507/1670 train_time:49204ms step_avg:97.05ms +step:508/1670 train_time:49298ms step_avg:97.04ms +step:509/1670 train_time:49393ms step_avg:97.04ms +step:510/1670 train_time:49489ms step_avg:97.04ms +step:511/1670 train_time:49583ms step_avg:97.03ms +step:512/1670 train_time:49679ms step_avg:97.03ms +step:513/1670 train_time:49777ms step_avg:97.03ms +step:514/1670 train_time:49876ms step_avg:97.03ms +step:515/1670 train_time:49973ms step_avg:97.03ms +step:516/1670 train_time:50068ms step_avg:97.03ms +step:517/1670 train_time:50164ms step_avg:97.03ms +step:518/1670 train_time:50259ms step_avg:97.02ms +step:519/1670 train_time:50353ms step_avg:97.02ms +step:520/1670 train_time:50448ms step_avg:97.02ms +step:521/1670 train_time:50544ms step_avg:97.01ms +step:522/1670 train_time:50639ms step_avg:97.01ms +step:523/1670 train_time:50735ms step_avg:97.01ms +step:524/1670 train_time:50832ms step_avg:97.01ms +step:525/1670 train_time:50929ms step_avg:97.01ms +step:526/1670 train_time:51025ms step_avg:97.01ms +step:527/1670 train_time:51119ms step_avg:97.00ms +step:528/1670 train_time:51214ms step_avg:97.00ms +step:529/1670 train_time:51310ms step_avg:96.99ms +step:530/1670 train_time:51406ms step_avg:96.99ms +step:531/1670 train_time:51500ms step_avg:96.99ms +step:532/1670 train_time:51596ms step_avg:96.98ms +step:533/1670 train_time:51691ms step_avg:96.98ms +step:534/1670 train_time:51788ms step_avg:96.98ms +step:535/1670 train_time:51884ms step_avg:96.98ms +step:536/1670 train_time:51980ms step_avg:96.98ms +step:537/1670 train_time:52076ms step_avg:96.98ms +step:538/1670 train_time:52172ms step_avg:96.97ms +step:539/1670 train_time:52267ms step_avg:96.97ms +step:540/1670 train_time:52362ms step_avg:96.97ms +step:541/1670 train_time:52456ms step_avg:96.96ms +step:542/1670 train_time:52552ms step_avg:96.96ms +step:543/1670 train_time:52648ms step_avg:96.96ms +step:544/1670 train_time:52744ms step_avg:96.96ms +step:545/1670 train_time:52840ms step_avg:96.95ms +step:546/1670 train_time:52936ms step_avg:96.95ms +step:547/1670 train_time:53031ms step_avg:96.95ms +step:548/1670 train_time:53127ms step_avg:96.95ms +step:549/1670 train_time:53223ms step_avg:96.95ms +step:550/1670 train_time:53318ms step_avg:96.94ms +step:551/1670 train_time:53413ms step_avg:96.94ms +step:552/1670 train_time:53509ms step_avg:96.94ms +step:553/1670 train_time:53604ms step_avg:96.93ms +step:554/1670 train_time:53699ms step_avg:96.93ms +step:555/1670 train_time:53796ms step_avg:96.93ms +step:556/1670 train_time:53891ms step_avg:96.93ms +step:557/1670 train_time:53987ms step_avg:96.93ms +step:558/1670 train_time:54084ms step_avg:96.92ms +step:559/1670 train_time:54180ms step_avg:96.92ms +step:560/1670 train_time:54277ms step_avg:96.92ms +step:561/1670 train_time:54374ms step_avg:96.92ms +step:562/1670 train_time:54471ms step_avg:96.92ms +step:563/1670 train_time:54567ms step_avg:96.92ms +step:564/1670 
train_time:54665ms step_avg:96.92ms +step:565/1670 train_time:54761ms step_avg:96.92ms +step:566/1670 train_time:54857ms step_avg:96.92ms +step:567/1670 train_time:54954ms step_avg:96.92ms +step:568/1670 train_time:55052ms step_avg:96.92ms +step:569/1670 train_time:55150ms step_avg:96.92ms +step:570/1670 train_time:55247ms step_avg:96.92ms +step:571/1670 train_time:55344ms step_avg:96.92ms +step:572/1670 train_time:55441ms step_avg:96.92ms +step:573/1670 train_time:55537ms step_avg:96.92ms +step:574/1670 train_time:55635ms step_avg:96.92ms +step:575/1670 train_time:55732ms step_avg:96.93ms +step:576/1670 train_time:55830ms step_avg:96.93ms +step:577/1670 train_time:55927ms step_avg:96.93ms +step:578/1670 train_time:56024ms step_avg:96.93ms +step:579/1670 train_time:56121ms step_avg:96.93ms +step:580/1670 train_time:56217ms step_avg:96.93ms +step:581/1670 train_time:56315ms step_avg:96.93ms +step:582/1670 train_time:56412ms step_avg:96.93ms +step:583/1670 train_time:56510ms step_avg:96.93ms +step:584/1670 train_time:56607ms step_avg:96.93ms +step:585/1670 train_time:56704ms step_avg:96.93ms +step:586/1670 train_time:56801ms step_avg:96.93ms +step:587/1670 train_time:56898ms step_avg:96.93ms +step:588/1670 train_time:56995ms step_avg:96.93ms +step:589/1670 train_time:57092ms step_avg:96.93ms +step:590/1670 train_time:57190ms step_avg:96.93ms +step:591/1670 train_time:57287ms step_avg:96.93ms +step:592/1670 train_time:57384ms step_avg:96.93ms +step:593/1670 train_time:57480ms step_avg:96.93ms +step:594/1670 train_time:57576ms step_avg:96.93ms +step:595/1670 train_time:57674ms step_avg:96.93ms +step:596/1670 train_time:57771ms step_avg:96.93ms +step:597/1670 train_time:57869ms step_avg:96.93ms +step:598/1670 train_time:57966ms step_avg:96.93ms +step:599/1670 train_time:58062ms step_avg:96.93ms +step:600/1670 train_time:58159ms step_avg:96.93ms +step:601/1670 train_time:58256ms step_avg:96.93ms +step:602/1670 train_time:58355ms step_avg:96.94ms +step:603/1670 train_time:58452ms step_avg:96.94ms +step:604/1670 train_time:58550ms step_avg:96.94ms +step:605/1670 train_time:58646ms step_avg:96.94ms +step:606/1670 train_time:58744ms step_avg:96.94ms +step:607/1670 train_time:58841ms step_avg:96.94ms +step:608/1670 train_time:58937ms step_avg:96.94ms +step:609/1670 train_time:59034ms step_avg:96.94ms +step:610/1670 train_time:59132ms step_avg:96.94ms +step:611/1670 train_time:59229ms step_avg:96.94ms +step:612/1670 train_time:59327ms step_avg:96.94ms +step:613/1670 train_time:59424ms step_avg:96.94ms +step:614/1670 train_time:59521ms step_avg:96.94ms +step:615/1670 train_time:59617ms step_avg:96.94ms +step:616/1670 train_time:59714ms step_avg:96.94ms +step:617/1670 train_time:59812ms step_avg:96.94ms +step:618/1670 train_time:59909ms step_avg:96.94ms +step:619/1670 train_time:60007ms step_avg:96.94ms +step:620/1670 train_time:60103ms step_avg:96.94ms +step:621/1670 train_time:60199ms step_avg:96.94ms +step:622/1670 train_time:60296ms step_avg:96.94ms +step:623/1670 train_time:60393ms step_avg:96.94ms +step:624/1670 train_time:60490ms step_avg:96.94ms +step:625/1670 train_time:60588ms step_avg:96.94ms +step:625/1670 val_loss:3.6144 train_time:60684ms step_avg:97.09ms +step:626/1670 train_time:60705ms step_avg:96.97ms +step:627/1670 train_time:60795ms step_avg:96.96ms +step:628/1670 train_time:60895ms step_avg:96.97ms +step:629/1670 train_time:60991ms step_avg:96.96ms +step:630/1670 train_time:61086ms step_avg:96.96ms +step:631/1670 train_time:61182ms step_avg:96.96ms +step:632/1670 train_time:61277ms 
step_avg:96.96ms +step:633/1670 train_time:61373ms step_avg:96.96ms +step:634/1670 train_time:61469ms step_avg:96.95ms +step:635/1670 train_time:61565ms step_avg:96.95ms +step:636/1670 train_time:61662ms step_avg:96.95ms +step:637/1670 train_time:61763ms step_avg:96.96ms +step:638/1670 train_time:61863ms step_avg:96.96ms +step:639/1670 train_time:62238ms step_avg:97.40ms +step:640/1670 train_time:62338ms step_avg:97.40ms +step:641/1670 train_time:62433ms step_avg:97.40ms +step:642/1670 train_time:62529ms step_avg:97.40ms +step:643/1670 train_time:62625ms step_avg:97.39ms +step:644/1670 train_time:62721ms step_avg:97.39ms +step:645/1670 train_time:62816ms step_avg:97.39ms +step:646/1670 train_time:62912ms step_avg:97.39ms +step:647/1670 train_time:63008ms step_avg:97.38ms +step:648/1670 train_time:63104ms step_avg:97.38ms +step:649/1670 train_time:63201ms step_avg:97.38ms +step:650/1670 train_time:63305ms step_avg:97.39ms +step:651/1670 train_time:63406ms step_avg:97.40ms +step:652/1670 train_time:63505ms step_avg:97.40ms +step:653/1670 train_time:63602ms step_avg:97.40ms +step:654/1670 train_time:63699ms step_avg:97.40ms +step:655/1670 train_time:63795ms step_avg:97.40ms +step:656/1670 train_time:63890ms step_avg:97.39ms +step:657/1670 train_time:63985ms step_avg:97.39ms +step:658/1670 train_time:64082ms step_avg:97.39ms +step:659/1670 train_time:64178ms step_avg:97.39ms +step:660/1670 train_time:64277ms step_avg:97.39ms +step:661/1670 train_time:64374ms step_avg:97.39ms +step:662/1670 train_time:64471ms step_avg:97.39ms +step:663/1670 train_time:64569ms step_avg:97.39ms +step:664/1670 train_time:64667ms step_avg:97.39ms +step:665/1670 train_time:64764ms step_avg:97.39ms +step:666/1670 train_time:64861ms step_avg:97.39ms +step:667/1670 train_time:64957ms step_avg:97.39ms +step:668/1670 train_time:65053ms step_avg:97.38ms +step:669/1670 train_time:65149ms step_avg:97.38ms +step:670/1670 train_time:65247ms step_avg:97.38ms +step:671/1670 train_time:65346ms step_avg:97.39ms +step:672/1670 train_time:65444ms step_avg:97.39ms +step:673/1670 train_time:65542ms step_avg:97.39ms +step:674/1670 train_time:65640ms step_avg:97.39ms +step:675/1670 train_time:65737ms step_avg:97.39ms +step:676/1670 train_time:65833ms step_avg:97.39ms +step:677/1670 train_time:65929ms step_avg:97.38ms +step:678/1670 train_time:66025ms step_avg:97.38ms +step:679/1670 train_time:66123ms step_avg:97.38ms +step:680/1670 train_time:66220ms step_avg:97.38ms +step:681/1670 train_time:66317ms step_avg:97.38ms +step:682/1670 train_time:66414ms step_avg:97.38ms +step:683/1670 train_time:66512ms step_avg:97.38ms +step:684/1670 train_time:66609ms step_avg:97.38ms +step:685/1670 train_time:66709ms step_avg:97.39ms +step:686/1670 train_time:66808ms step_avg:97.39ms +step:687/1670 train_time:66905ms step_avg:97.39ms +step:688/1670 train_time:67001ms step_avg:97.39ms +step:689/1670 train_time:67097ms step_avg:97.38ms +step:690/1670 train_time:67193ms step_avg:97.38ms +step:691/1670 train_time:67290ms step_avg:97.38ms +step:692/1670 train_time:67387ms step_avg:97.38ms +step:693/1670 train_time:67485ms step_avg:97.38ms +step:694/1670 train_time:67583ms step_avg:97.38ms +step:695/1670 train_time:67681ms step_avg:97.38ms +step:696/1670 train_time:67778ms step_avg:97.38ms +step:697/1670 train_time:67874ms step_avg:97.38ms +step:698/1670 train_time:67971ms step_avg:97.38ms +step:699/1670 train_time:68068ms step_avg:97.38ms +step:700/1670 train_time:68165ms step_avg:97.38ms +step:701/1670 train_time:68262ms step_avg:97.38ms +step:702/1670 
train_time:68360ms step_avg:97.38ms +step:703/1670 train_time:68457ms step_avg:97.38ms +step:704/1670 train_time:68554ms step_avg:97.38ms +step:705/1670 train_time:68650ms step_avg:97.38ms +step:706/1670 train_time:68748ms step_avg:97.38ms +step:707/1670 train_time:68847ms step_avg:97.38ms +step:708/1670 train_time:68944ms step_avg:97.38ms +step:709/1670 train_time:69041ms step_avg:97.38ms +step:710/1670 train_time:69137ms step_avg:97.38ms +step:711/1670 train_time:69234ms step_avg:97.38ms +step:712/1670 train_time:69331ms step_avg:97.37ms +step:713/1670 train_time:69428ms step_avg:97.37ms +step:714/1670 train_time:69525ms step_avg:97.37ms +step:715/1670 train_time:69622ms step_avg:97.37ms +step:716/1670 train_time:69720ms step_avg:97.37ms +step:717/1670 train_time:69817ms step_avg:97.37ms +step:718/1670 train_time:69913ms step_avg:97.37ms +step:719/1670 train_time:70010ms step_avg:97.37ms +step:720/1670 train_time:70107ms step_avg:97.37ms +step:721/1670 train_time:70206ms step_avg:97.37ms +step:722/1670 train_time:70303ms step_avg:97.37ms +step:723/1670 train_time:70401ms step_avg:97.37ms +step:724/1670 train_time:70498ms step_avg:97.37ms +step:725/1670 train_time:70595ms step_avg:97.37ms +step:726/1670 train_time:70691ms step_avg:97.37ms +step:727/1670 train_time:70789ms step_avg:97.37ms +step:728/1670 train_time:70887ms step_avg:97.37ms +step:729/1670 train_time:70984ms step_avg:97.37ms +step:730/1670 train_time:71081ms step_avg:97.37ms +step:731/1670 train_time:71177ms step_avg:97.37ms +step:732/1670 train_time:71274ms step_avg:97.37ms +step:733/1670 train_time:71370ms step_avg:97.37ms +step:734/1670 train_time:71469ms step_avg:97.37ms +step:735/1670 train_time:71566ms step_avg:97.37ms +step:736/1670 train_time:71663ms step_avg:97.37ms +step:737/1670 train_time:71761ms step_avg:97.37ms +step:738/1670 train_time:71857ms step_avg:97.37ms +step:739/1670 train_time:71954ms step_avg:97.37ms +step:740/1670 train_time:72051ms step_avg:97.37ms +step:741/1670 train_time:72148ms step_avg:97.37ms +step:742/1670 train_time:72246ms step_avg:97.37ms +step:743/1670 train_time:72343ms step_avg:97.37ms +step:744/1670 train_time:72441ms step_avg:97.37ms +step:745/1670 train_time:72538ms step_avg:97.37ms +step:746/1670 train_time:72634ms step_avg:97.36ms +step:747/1670 train_time:72730ms step_avg:97.36ms +step:748/1670 train_time:72827ms step_avg:97.36ms +step:749/1670 train_time:72925ms step_avg:97.36ms +step:750/1670 train_time:73023ms step_avg:97.36ms +step:750/1670 val_loss:3.5600 train_time:73120ms step_avg:97.49ms +step:751/1670 train_time:73141ms step_avg:97.39ms +step:752/1670 train_time:73225ms step_avg:97.37ms +step:753/1670 train_time:73324ms step_avg:97.38ms +step:754/1670 train_time:73422ms step_avg:97.38ms +step:755/1670 train_time:73518ms step_avg:97.37ms +step:756/1670 train_time:73614ms step_avg:97.37ms +step:757/1670 train_time:73710ms step_avg:97.37ms +step:758/1670 train_time:73806ms step_avg:97.37ms +step:759/1670 train_time:73902ms step_avg:97.37ms +step:760/1670 train_time:73999ms step_avg:97.37ms +step:761/1670 train_time:74097ms step_avg:97.37ms +step:762/1670 train_time:74198ms step_avg:97.37ms +step:763/1670 train_time:74298ms step_avg:97.38ms +step:764/1670 train_time:74397ms step_avg:97.38ms +step:765/1670 train_time:74494ms step_avg:97.38ms +step:766/1670 train_time:74590ms step_avg:97.38ms +step:767/1670 train_time:74686ms step_avg:97.37ms +step:768/1670 train_time:74782ms step_avg:97.37ms +step:769/1670 train_time:74879ms step_avg:97.37ms +step:770/1670 train_time:74976ms 
step_avg:97.37ms +step:771/1670 train_time:75073ms step_avg:97.37ms +step:772/1670 train_time:75171ms step_avg:97.37ms +step:773/1670 train_time:75268ms step_avg:97.37ms +step:774/1670 train_time:75367ms step_avg:97.37ms +step:775/1670 train_time:75464ms step_avg:97.37ms +step:776/1670 train_time:75561ms step_avg:97.37ms +step:777/1670 train_time:75660ms step_avg:97.37ms +step:778/1670 train_time:75756ms step_avg:97.37ms +step:779/1670 train_time:75853ms step_avg:97.37ms +step:780/1670 train_time:75949ms step_avg:97.37ms +step:781/1670 train_time:76045ms step_avg:97.37ms +step:782/1670 train_time:76142ms step_avg:97.37ms +step:783/1670 train_time:76241ms step_avg:97.37ms +step:784/1670 train_time:76339ms step_avg:97.37ms +step:785/1670 train_time:76437ms step_avg:97.37ms +step:786/1670 train_time:76534ms step_avg:97.37ms +step:787/1670 train_time:76631ms step_avg:97.37ms +step:788/1670 train_time:76727ms step_avg:97.37ms +step:789/1670 train_time:76824ms step_avg:97.37ms +step:790/1670 train_time:76921ms step_avg:97.37ms +step:791/1670 train_time:77018ms step_avg:97.37ms +step:792/1670 train_time:77115ms step_avg:97.37ms +step:793/1670 train_time:77214ms step_avg:97.37ms +step:794/1670 train_time:77310ms step_avg:97.37ms +step:795/1670 train_time:77406ms step_avg:97.37ms +step:796/1670 train_time:77504ms step_avg:97.37ms +step:797/1670 train_time:77602ms step_avg:97.37ms +step:798/1670 train_time:77700ms step_avg:97.37ms +step:799/1670 train_time:77796ms step_avg:97.37ms +step:800/1670 train_time:77894ms step_avg:97.37ms +step:801/1670 train_time:77991ms step_avg:97.37ms +step:802/1670 train_time:78087ms step_avg:97.37ms +step:803/1670 train_time:78184ms step_avg:97.37ms +step:804/1670 train_time:78283ms step_avg:97.37ms +step:805/1670 train_time:78381ms step_avg:97.37ms +step:806/1670 train_time:78479ms step_avg:97.37ms +step:807/1670 train_time:78577ms step_avg:97.37ms +step:808/1670 train_time:78674ms step_avg:97.37ms +step:809/1670 train_time:78771ms step_avg:97.37ms +step:810/1670 train_time:78867ms step_avg:97.37ms +step:811/1670 train_time:78963ms step_avg:97.37ms +step:812/1670 train_time:79061ms step_avg:97.37ms +step:813/1670 train_time:79159ms step_avg:97.37ms +step:814/1670 train_time:79256ms step_avg:97.37ms +step:815/1670 train_time:79355ms step_avg:97.37ms +step:816/1670 train_time:79452ms step_avg:97.37ms +step:817/1670 train_time:79548ms step_avg:97.37ms +step:818/1670 train_time:79644ms step_avg:97.36ms +step:819/1670 train_time:79741ms step_avg:97.36ms +step:820/1670 train_time:79838ms step_avg:97.36ms +step:821/1670 train_time:79936ms step_avg:97.36ms +step:822/1670 train_time:80033ms step_avg:97.36ms +step:823/1670 train_time:80129ms step_avg:97.36ms +step:824/1670 train_time:80225ms step_avg:97.36ms +step:825/1670 train_time:80323ms step_avg:97.36ms +step:826/1670 train_time:80423ms step_avg:97.36ms +step:827/1670 train_time:80521ms step_avg:97.37ms +step:828/1670 train_time:80618ms step_avg:97.37ms +step:829/1670 train_time:80715ms step_avg:97.36ms +step:830/1670 train_time:80811ms step_avg:97.36ms +step:831/1670 train_time:80907ms step_avg:97.36ms +step:832/1670 train_time:81004ms step_avg:97.36ms +step:833/1670 train_time:81101ms step_avg:97.36ms +step:834/1670 train_time:81199ms step_avg:97.36ms +step:835/1670 train_time:81297ms step_avg:97.36ms +step:836/1670 train_time:81394ms step_avg:97.36ms +step:837/1670 train_time:81491ms step_avg:97.36ms +step:838/1670 train_time:81588ms step_avg:97.36ms +step:839/1670 train_time:81685ms step_avg:97.36ms +step:840/1670 
train_time:81783ms step_avg:97.36ms +step:841/1670 train_time:81881ms step_avg:97.36ms +step:842/1670 train_time:81978ms step_avg:97.36ms +step:843/1670 train_time:82075ms step_avg:97.36ms +step:844/1670 train_time:82172ms step_avg:97.36ms +step:845/1670 train_time:82268ms step_avg:97.36ms +step:846/1670 train_time:82365ms step_avg:97.36ms +step:847/1670 train_time:82463ms step_avg:97.36ms +step:848/1670 train_time:82560ms step_avg:97.36ms +step:849/1670 train_time:82658ms step_avg:97.36ms +step:850/1670 train_time:82756ms step_avg:97.36ms +step:851/1670 train_time:83061ms step_avg:97.60ms +step:852/1670 train_time:83183ms step_avg:97.63ms +step:853/1670 train_time:83278ms step_avg:97.63ms +step:854/1670 train_time:83373ms step_avg:97.63ms +step:855/1670 train_time:83469ms step_avg:97.62ms +step:856/1670 train_time:83564ms step_avg:97.62ms +step:857/1670 train_time:83660ms step_avg:97.62ms +step:858/1670 train_time:83756ms step_avg:97.62ms +step:859/1670 train_time:83852ms step_avg:97.62ms +step:860/1670 train_time:83948ms step_avg:97.61ms +step:861/1670 train_time:84050ms step_avg:97.62ms +step:862/1670 train_time:84148ms step_avg:97.62ms +step:863/1670 train_time:84248ms step_avg:97.62ms +step:864/1670 train_time:84345ms step_avg:97.62ms +step:865/1670 train_time:84442ms step_avg:97.62ms +step:866/1670 train_time:84539ms step_avg:97.62ms +step:867/1670 train_time:84635ms step_avg:97.62ms +step:868/1670 train_time:84730ms step_avg:97.62ms +step:869/1670 train_time:84826ms step_avg:97.61ms +step:870/1670 train_time:84924ms step_avg:97.61ms +step:871/1670 train_time:85022ms step_avg:97.61ms +step:872/1670 train_time:85122ms step_avg:97.62ms +step:873/1670 train_time:85221ms step_avg:97.62ms +step:874/1670 train_time:85319ms step_avg:97.62ms +step:875/1670 train_time:85417ms step_avg:97.62ms +step:875/1670 val_loss:3.5194 train_time:85513ms step_avg:97.73ms +step:876/1670 train_time:85533ms step_avg:97.64ms +step:877/1670 train_time:85616ms step_avg:97.62ms +step:878/1670 train_time:85715ms step_avg:97.63ms +step:879/1670 train_time:85812ms step_avg:97.62ms +step:880/1670 train_time:85907ms step_avg:97.62ms +step:881/1670 train_time:86003ms step_avg:97.62ms +step:882/1670 train_time:86099ms step_avg:97.62ms +step:883/1670 train_time:86195ms step_avg:97.62ms +step:884/1670 train_time:86292ms step_avg:97.62ms +step:885/1670 train_time:86387ms step_avg:97.61ms +step:886/1670 train_time:86485ms step_avg:97.61ms +step:887/1670 train_time:86585ms step_avg:97.62ms +step:888/1670 train_time:86683ms step_avg:97.62ms +step:889/1670 train_time:86782ms step_avg:97.62ms +step:890/1670 train_time:86880ms step_avg:97.62ms +step:891/1670 train_time:86976ms step_avg:97.62ms +step:892/1670 train_time:87073ms step_avg:97.61ms +step:893/1670 train_time:87168ms step_avg:97.61ms +step:894/1670 train_time:87264ms step_avg:97.61ms +step:895/1670 train_time:87361ms step_avg:97.61ms +step:896/1670 train_time:87459ms step_avg:97.61ms +step:897/1670 train_time:87557ms step_avg:97.61ms +step:898/1670 train_time:87655ms step_avg:97.61ms +step:899/1670 train_time:87752ms step_avg:97.61ms +step:900/1670 train_time:87849ms step_avg:97.61ms +step:901/1670 train_time:87945ms step_avg:97.61ms +step:902/1670 train_time:88044ms step_avg:97.61ms +step:903/1670 train_time:88141ms step_avg:97.61ms +step:904/1670 train_time:88237ms step_avg:97.61ms +step:905/1670 train_time:88334ms step_avg:97.61ms +step:906/1670 train_time:88431ms step_avg:97.61ms +step:907/1670 train_time:88528ms step_avg:97.61ms +step:908/1670 train_time:88625ms 
step_avg:97.60ms +step:909/1670 train_time:88722ms step_avg:97.60ms +step:910/1670 train_time:88820ms step_avg:97.60ms +step:911/1670 train_time:88918ms step_avg:97.60ms +step:912/1670 train_time:89015ms step_avg:97.60ms +step:913/1670 train_time:89112ms step_avg:97.60ms +step:914/1670 train_time:89208ms step_avg:97.60ms +step:915/1670 train_time:89305ms step_avg:97.60ms +step:916/1670 train_time:89403ms step_avg:97.60ms +step:917/1670 train_time:89501ms step_avg:97.60ms +step:918/1670 train_time:89599ms step_avg:97.60ms +step:919/1670 train_time:89698ms step_avg:97.60ms +step:920/1670 train_time:89795ms step_avg:97.60ms +step:921/1670 train_time:89891ms step_avg:97.60ms +step:922/1670 train_time:89988ms step_avg:97.60ms +step:923/1670 train_time:90085ms step_avg:97.60ms +step:924/1670 train_time:90183ms step_avg:97.60ms +step:925/1670 train_time:90280ms step_avg:97.60ms +step:926/1670 train_time:90377ms step_avg:97.60ms +step:927/1670 train_time:90474ms step_avg:97.60ms +step:928/1670 train_time:90572ms step_avg:97.60ms +step:929/1670 train_time:90669ms step_avg:97.60ms +step:930/1670 train_time:90767ms step_avg:97.60ms +step:931/1670 train_time:90864ms step_avg:97.60ms +step:932/1670 train_time:90961ms step_avg:97.60ms +step:933/1670 train_time:91058ms step_avg:97.60ms +step:934/1670 train_time:91155ms step_avg:97.60ms +step:935/1670 train_time:91251ms step_avg:97.60ms +step:936/1670 train_time:91348ms step_avg:97.59ms +step:937/1670 train_time:91445ms step_avg:97.59ms +step:938/1670 train_time:91544ms step_avg:97.59ms +step:939/1670 train_time:91642ms step_avg:97.59ms +step:940/1670 train_time:91739ms step_avg:97.59ms +step:941/1670 train_time:91837ms step_avg:97.59ms +step:942/1670 train_time:91934ms step_avg:97.59ms +step:943/1670 train_time:92030ms step_avg:97.59ms +step:944/1670 train_time:92126ms step_avg:97.59ms +step:945/1670 train_time:92223ms step_avg:97.59ms +step:946/1670 train_time:92320ms step_avg:97.59ms +step:947/1670 train_time:92418ms step_avg:97.59ms +step:948/1670 train_time:92515ms step_avg:97.59ms +step:949/1670 train_time:92613ms step_avg:97.59ms +step:950/1670 train_time:92709ms step_avg:97.59ms +step:951/1670 train_time:92806ms step_avg:97.59ms +step:952/1670 train_time:92903ms step_avg:97.59ms +step:953/1670 train_time:93001ms step_avg:97.59ms +step:954/1670 train_time:93098ms step_avg:97.59ms +step:955/1670 train_time:93195ms step_avg:97.59ms +step:956/1670 train_time:93291ms step_avg:97.59ms +step:957/1670 train_time:93388ms step_avg:97.58ms +step:958/1670 train_time:93484ms step_avg:97.58ms +step:959/1670 train_time:93581ms step_avg:97.58ms +step:960/1670 train_time:93680ms step_avg:97.58ms +step:961/1670 train_time:93779ms step_avg:97.58ms +step:962/1670 train_time:93876ms step_avg:97.58ms +step:963/1670 train_time:93973ms step_avg:97.58ms +step:964/1670 train_time:94070ms step_avg:97.58ms +step:965/1670 train_time:94167ms step_avg:97.58ms +step:966/1670 train_time:94264ms step_avg:97.58ms +step:967/1670 train_time:94360ms step_avg:97.58ms +step:968/1670 train_time:94458ms step_avg:97.58ms +step:969/1670 train_time:94555ms step_avg:97.58ms +step:970/1670 train_time:94652ms step_avg:97.58ms +step:971/1670 train_time:94748ms step_avg:97.58ms +step:972/1670 train_time:94846ms step_avg:97.58ms +step:973/1670 train_time:94944ms step_avg:97.58ms +step:974/1670 train_time:95042ms step_avg:97.58ms +step:975/1670 train_time:95139ms step_avg:97.58ms +step:976/1670 train_time:95236ms step_avg:97.58ms +step:977/1670 train_time:95333ms step_avg:97.58ms +step:978/1670 
train_time:95430ms step_avg:97.58ms +step:979/1670 train_time:95527ms step_avg:97.58ms +step:980/1670 train_time:95624ms step_avg:97.58ms +step:981/1670 train_time:95722ms step_avg:97.58ms +step:982/1670 train_time:95819ms step_avg:97.58ms +step:983/1670 train_time:95917ms step_avg:97.58ms +step:984/1670 train_time:96014ms step_avg:97.58ms +step:985/1670 train_time:96111ms step_avg:97.57ms +step:986/1670 train_time:96207ms step_avg:97.57ms +step:987/1670 train_time:96304ms step_avg:97.57ms +step:988/1670 train_time:96402ms step_avg:97.57ms +step:989/1670 train_time:96501ms step_avg:97.57ms +step:990/1670 train_time:96599ms step_avg:97.57ms +step:991/1670 train_time:96696ms step_avg:97.57ms +step:992/1670 train_time:96793ms step_avg:97.57ms +step:993/1670 train_time:96891ms step_avg:97.57ms +step:994/1670 train_time:96987ms step_avg:97.57ms +step:995/1670 train_time:97085ms step_avg:97.57ms +step:996/1670 train_time:97182ms step_avg:97.57ms +step:997/1670 train_time:97279ms step_avg:97.57ms +step:998/1670 train_time:97377ms step_avg:97.57ms +step:999/1670 train_time:97473ms step_avg:97.57ms +step:1000/1670 train_time:97571ms step_avg:97.57ms +step:1000/1670 val_loss:3.4779 train_time:97667ms step_avg:97.67ms +step:1001/1670 train_time:97688ms step_avg:97.59ms +step:1002/1670 train_time:97768ms step_avg:97.57ms +step:1003/1670 train_time:97868ms step_avg:97.58ms +step:1004/1670 train_time:97964ms step_avg:97.57ms +step:1005/1670 train_time:98060ms step_avg:97.57ms +step:1006/1670 train_time:98156ms step_avg:97.57ms +step:1007/1670 train_time:98252ms step_avg:97.57ms +step:1008/1670 train_time:98348ms step_avg:97.57ms +step:1009/1670 train_time:98445ms step_avg:97.57ms +step:1010/1670 train_time:98541ms step_avg:97.57ms +step:1011/1670 train_time:98640ms step_avg:97.57ms +step:1012/1670 train_time:98740ms step_avg:97.57ms +step:1013/1670 train_time:98841ms step_avg:97.57ms +step:1014/1670 train_time:98939ms step_avg:97.57ms +step:1015/1670 train_time:99036ms step_avg:97.57ms +step:1016/1670 train_time:99133ms step_avg:97.57ms +step:1017/1670 train_time:99229ms step_avg:97.57ms +step:1018/1670 train_time:99325ms step_avg:97.57ms +step:1019/1670 train_time:99420ms step_avg:97.57ms +step:1020/1670 train_time:99517ms step_avg:97.57ms +step:1021/1670 train_time:99616ms step_avg:97.57ms +step:1022/1670 train_time:99714ms step_avg:97.57ms +step:1023/1670 train_time:99813ms step_avg:97.57ms +step:1024/1670 train_time:99912ms step_avg:97.57ms +step:1025/1670 train_time:100010ms step_avg:97.57ms +step:1026/1670 train_time:100106ms step_avg:97.57ms +step:1027/1670 train_time:100202ms step_avg:97.57ms +step:1028/1670 train_time:100299ms step_avg:97.57ms +step:1029/1670 train_time:100395ms step_avg:97.57ms +step:1030/1670 train_time:100492ms step_avg:97.57ms +step:1031/1670 train_time:100589ms step_avg:97.56ms +step:1032/1670 train_time:100686ms step_avg:97.56ms +step:1033/1670 train_time:100784ms step_avg:97.56ms +step:1034/1670 train_time:100882ms step_avg:97.56ms +step:1035/1670 train_time:100979ms step_avg:97.56ms +step:1036/1670 train_time:101077ms step_avg:97.56ms +step:1037/1670 train_time:101175ms step_avg:97.56ms +step:1038/1670 train_time:101271ms step_avg:97.56ms +step:1039/1670 train_time:101368ms step_avg:97.56ms +step:1040/1670 train_time:101463ms step_avg:97.56ms +step:1041/1670 train_time:101561ms step_avg:97.56ms +step:1042/1670 train_time:101658ms step_avg:97.56ms +step:1043/1670 train_time:101757ms step_avg:97.56ms +step:1044/1670 train_time:101855ms step_avg:97.56ms +step:1045/1670 
train_time:101953ms step_avg:97.56ms +step:1046/1670 train_time:102050ms step_avg:97.56ms +step:1047/1670 train_time:102146ms step_avg:97.56ms +step:1048/1670 train_time:102243ms step_avg:97.56ms +step:1049/1670 train_time:102340ms step_avg:97.56ms +step:1050/1670 train_time:102438ms step_avg:97.56ms +step:1051/1670 train_time:102535ms step_avg:97.56ms +step:1052/1670 train_time:102633ms step_avg:97.56ms +step:1053/1670 train_time:102731ms step_avg:97.56ms +step:1054/1670 train_time:102829ms step_avg:97.56ms +step:1055/1670 train_time:102925ms step_avg:97.56ms +step:1056/1670 train_time:103023ms step_avg:97.56ms +step:1057/1670 train_time:103120ms step_avg:97.56ms +step:1058/1670 train_time:103217ms step_avg:97.56ms +step:1059/1670 train_time:103315ms step_avg:97.56ms +step:1060/1670 train_time:103413ms step_avg:97.56ms +step:1061/1670 train_time:103510ms step_avg:97.56ms +step:1062/1670 train_time:103781ms step_avg:97.72ms +step:1063/1670 train_time:103856ms step_avg:97.70ms +step:1064/1670 train_time:103952ms step_avg:97.70ms +step:1065/1670 train_time:104048ms step_avg:97.70ms +step:1066/1670 train_time:104144ms step_avg:97.70ms +step:1067/1670 train_time:104239ms step_avg:97.69ms +step:1068/1670 train_time:104336ms step_avg:97.69ms +step:1069/1670 train_time:104432ms step_avg:97.69ms +step:1070/1670 train_time:104528ms step_avg:97.69ms +step:1071/1670 train_time:104625ms step_avg:97.69ms +step:1072/1670 train_time:104726ms step_avg:97.69ms +step:1073/1670 train_time:104825ms step_avg:97.69ms +step:1074/1670 train_time:104922ms step_avg:97.69ms +step:1075/1670 train_time:105019ms step_avg:97.69ms +step:1076/1670 train_time:105116ms step_avg:97.69ms +step:1077/1670 train_time:105212ms step_avg:97.69ms +step:1078/1670 train_time:105307ms step_avg:97.69ms +step:1079/1670 train_time:105403ms step_avg:97.69ms +step:1080/1670 train_time:105498ms step_avg:97.68ms +step:1081/1670 train_time:105597ms step_avg:97.68ms +step:1082/1670 train_time:105697ms step_avg:97.69ms +step:1083/1670 train_time:105798ms step_avg:97.69ms +step:1084/1670 train_time:105896ms step_avg:97.69ms +step:1085/1670 train_time:105993ms step_avg:97.69ms +step:1086/1670 train_time:106090ms step_avg:97.69ms +step:1087/1670 train_time:106187ms step_avg:97.69ms +step:1088/1670 train_time:106283ms step_avg:97.69ms +step:1089/1670 train_time:106378ms step_avg:97.68ms +step:1090/1670 train_time:106475ms step_avg:97.68ms +step:1091/1670 train_time:106572ms step_avg:97.68ms +step:1092/1670 train_time:106670ms step_avg:97.68ms +step:1093/1670 train_time:106768ms step_avg:97.68ms +step:1094/1670 train_time:106865ms step_avg:97.68ms +step:1095/1670 train_time:106962ms step_avg:97.68ms +step:1096/1670 train_time:107061ms step_avg:97.68ms +step:1097/1670 train_time:107158ms step_avg:97.68ms +step:1098/1670 train_time:107255ms step_avg:97.68ms +step:1099/1670 train_time:107351ms step_avg:97.68ms +step:1100/1670 train_time:107448ms step_avg:97.68ms +step:1101/1670 train_time:107544ms step_avg:97.68ms +step:1102/1670 train_time:107641ms step_avg:97.68ms +step:1103/1670 train_time:107739ms step_avg:97.68ms +step:1104/1670 train_time:107838ms step_avg:97.68ms +step:1105/1670 train_time:107935ms step_avg:97.68ms +step:1106/1670 train_time:108033ms step_avg:97.68ms +step:1107/1670 train_time:108131ms step_avg:97.68ms +step:1108/1670 train_time:108228ms step_avg:97.68ms +step:1109/1670 train_time:108324ms step_avg:97.68ms +step:1110/1670 train_time:108421ms step_avg:97.68ms +step:1111/1670 train_time:108517ms step_avg:97.68ms +step:1112/1670 
train_time:108615ms step_avg:97.68ms +step:1113/1670 train_time:108712ms step_avg:97.67ms +step:1114/1670 train_time:108810ms step_avg:97.67ms +step:1115/1670 train_time:108907ms step_avg:97.67ms +step:1116/1670 train_time:109005ms step_avg:97.68ms +step:1117/1670 train_time:109103ms step_avg:97.68ms +step:1118/1670 train_time:109201ms step_avg:97.68ms +step:1119/1670 train_time:109298ms step_avg:97.67ms +step:1120/1670 train_time:109396ms step_avg:97.68ms +step:1121/1670 train_time:109494ms step_avg:97.68ms +step:1122/1670 train_time:109593ms step_avg:97.68ms +step:1123/1670 train_time:109690ms step_avg:97.68ms +step:1124/1670 train_time:109788ms step_avg:97.68ms +step:1125/1670 train_time:109885ms step_avg:97.68ms +step:1125/1670 val_loss:3.4237 train_time:109982ms step_avg:97.76ms +step:1126/1670 train_time:110003ms step_avg:97.69ms +step:1127/1670 train_time:110090ms step_avg:97.68ms +step:1128/1670 train_time:110187ms step_avg:97.68ms +step:1129/1670 train_time:110284ms step_avg:97.68ms +step:1130/1670 train_time:110380ms step_avg:97.68ms +step:1131/1670 train_time:110476ms step_avg:97.68ms +step:1132/1670 train_time:110574ms step_avg:97.68ms +step:1133/1670 train_time:110670ms step_avg:97.68ms +step:1134/1670 train_time:110767ms step_avg:97.68ms +step:1135/1670 train_time:110864ms step_avg:97.68ms +step:1136/1670 train_time:110964ms step_avg:97.68ms +step:1137/1670 train_time:111064ms step_avg:97.68ms +step:1138/1670 train_time:111163ms step_avg:97.68ms +step:1139/1670 train_time:111261ms step_avg:97.68ms +step:1140/1670 train_time:111358ms step_avg:97.68ms +step:1141/1670 train_time:111455ms step_avg:97.68ms +step:1142/1670 train_time:111552ms step_avg:97.68ms +step:1143/1670 train_time:111649ms step_avg:97.68ms +step:1144/1670 train_time:111745ms step_avg:97.68ms +step:1145/1670 train_time:111841ms step_avg:97.68ms +step:1146/1670 train_time:111940ms step_avg:97.68ms +step:1147/1670 train_time:112040ms step_avg:97.68ms +step:1148/1670 train_time:112139ms step_avg:97.68ms +step:1149/1670 train_time:112238ms step_avg:97.68ms +step:1150/1670 train_time:112336ms step_avg:97.68ms +step:1151/1670 train_time:112434ms step_avg:97.68ms +step:1152/1670 train_time:112531ms step_avg:97.68ms +step:1153/1670 train_time:112628ms step_avg:97.68ms +step:1154/1670 train_time:112724ms step_avg:97.68ms +step:1155/1670 train_time:112821ms step_avg:97.68ms +step:1156/1670 train_time:112918ms step_avg:97.68ms +step:1157/1670 train_time:113017ms step_avg:97.68ms +step:1158/1670 train_time:113117ms step_avg:97.68ms +step:1159/1670 train_time:113217ms step_avg:97.68ms +step:1160/1670 train_time:113316ms step_avg:97.69ms +step:1161/1670 train_time:113415ms step_avg:97.69ms +step:1162/1670 train_time:113513ms step_avg:97.69ms +step:1163/1670 train_time:113609ms step_avg:97.69ms +step:1164/1670 train_time:113707ms step_avg:97.69ms +step:1165/1670 train_time:113804ms step_avg:97.69ms +step:1166/1670 train_time:113901ms step_avg:97.69ms +step:1167/1670 train_time:113999ms step_avg:97.69ms +step:1168/1670 train_time:114097ms step_avg:97.69ms +step:1169/1670 train_time:114195ms step_avg:97.69ms +step:1170/1670 train_time:114294ms step_avg:97.69ms +step:1171/1670 train_time:114393ms step_avg:97.69ms +step:1172/1670 train_time:114490ms step_avg:97.69ms +step:1173/1670 train_time:114587ms step_avg:97.69ms +step:1174/1670 train_time:114683ms step_avg:97.69ms +step:1175/1670 train_time:114780ms step_avg:97.69ms +step:1176/1670 train_time:114878ms step_avg:97.69ms +step:1177/1670 train_time:114977ms step_avg:97.69ms 
+step:1178/1670 train_time:115077ms step_avg:97.69ms +step:1179/1670 train_time:115177ms step_avg:97.69ms +step:1180/1670 train_time:115275ms step_avg:97.69ms +step:1181/1670 train_time:115374ms step_avg:97.69ms +step:1182/1670 train_time:115471ms step_avg:97.69ms +step:1183/1670 train_time:115569ms step_avg:97.69ms +step:1184/1670 train_time:115667ms step_avg:97.69ms +step:1185/1670 train_time:115765ms step_avg:97.69ms +step:1186/1670 train_time:115862ms step_avg:97.69ms +step:1187/1670 train_time:115959ms step_avg:97.69ms +step:1188/1670 train_time:116057ms step_avg:97.69ms +step:1189/1670 train_time:116155ms step_avg:97.69ms +step:1190/1670 train_time:116254ms step_avg:97.69ms +step:1191/1670 train_time:116353ms step_avg:97.69ms +step:1192/1670 train_time:116452ms step_avg:97.69ms +step:1193/1670 train_time:116550ms step_avg:97.69ms +step:1194/1670 train_time:116648ms step_avg:97.70ms +step:1195/1670 train_time:116745ms step_avg:97.69ms +step:1196/1670 train_time:116842ms step_avg:97.69ms +step:1197/1670 train_time:116940ms step_avg:97.69ms +step:1198/1670 train_time:117038ms step_avg:97.69ms +step:1199/1670 train_time:117135ms step_avg:97.69ms +step:1200/1670 train_time:117234ms step_avg:97.69ms +step:1201/1670 train_time:117331ms step_avg:97.69ms +step:1202/1670 train_time:117430ms step_avg:97.70ms +step:1203/1670 train_time:117527ms step_avg:97.70ms +step:1204/1670 train_time:117624ms step_avg:97.69ms +step:1205/1670 train_time:117722ms step_avg:97.69ms +step:1206/1670 train_time:117819ms step_avg:97.69ms +step:1207/1670 train_time:117917ms step_avg:97.69ms +step:1208/1670 train_time:118016ms step_avg:97.70ms +step:1209/1670 train_time:118113ms step_avg:97.70ms +step:1210/1670 train_time:118211ms step_avg:97.70ms +step:1211/1670 train_time:118309ms step_avg:97.70ms +step:1212/1670 train_time:118406ms step_avg:97.70ms +step:1213/1670 train_time:118504ms step_avg:97.70ms +step:1214/1670 train_time:118601ms step_avg:97.69ms +step:1215/1670 train_time:118699ms step_avg:97.69ms +step:1216/1670 train_time:118797ms step_avg:97.69ms +step:1217/1670 train_time:118894ms step_avg:97.69ms +step:1218/1670 train_time:118993ms step_avg:97.70ms +step:1219/1670 train_time:119091ms step_avg:97.70ms +step:1220/1670 train_time:119190ms step_avg:97.70ms +step:1221/1670 train_time:119288ms step_avg:97.70ms +step:1222/1670 train_time:119385ms step_avg:97.70ms +step:1223/1670 train_time:119482ms step_avg:97.70ms +step:1224/1670 train_time:119580ms step_avg:97.70ms +step:1225/1670 train_time:119677ms step_avg:97.70ms +step:1226/1670 train_time:119774ms step_avg:97.70ms +step:1227/1670 train_time:119873ms step_avg:97.70ms +step:1228/1670 train_time:119971ms step_avg:97.70ms +step:1229/1670 train_time:120068ms step_avg:97.70ms +step:1230/1670 train_time:120165ms step_avg:97.70ms +step:1231/1670 train_time:120263ms step_avg:97.70ms +step:1232/1670 train_time:120360ms step_avg:97.69ms +step:1233/1670 train_time:120458ms step_avg:97.69ms +step:1234/1670 train_time:120557ms step_avg:97.70ms +step:1235/1670 train_time:120655ms step_avg:97.70ms +step:1236/1670 train_time:120754ms step_avg:97.70ms +step:1237/1670 train_time:120852ms step_avg:97.70ms +step:1238/1670 train_time:120949ms step_avg:97.70ms +step:1239/1670 train_time:121047ms step_avg:97.70ms +step:1240/1670 train_time:121145ms step_avg:97.70ms +step:1241/1670 train_time:121243ms step_avg:97.70ms +step:1242/1670 train_time:121340ms step_avg:97.70ms +step:1243/1670 train_time:121437ms step_avg:97.70ms +step:1244/1670 train_time:121536ms step_avg:97.70ms 
+step:1245/1670 train_time:121635ms step_avg:97.70ms +step:1246/1670 train_time:121732ms step_avg:97.70ms +step:1247/1670 train_time:121830ms step_avg:97.70ms +step:1248/1670 train_time:121927ms step_avg:97.70ms +step:1249/1670 train_time:122024ms step_avg:97.70ms +step:1250/1670 train_time:122121ms step_avg:97.70ms +step:1250/1670 val_loss:3.3812 train_time:122219ms step_avg:97.77ms +step:1251/1670 train_time:122239ms step_avg:97.71ms +step:1252/1670 train_time:122323ms step_avg:97.70ms +step:1253/1670 train_time:122423ms step_avg:97.70ms +step:1254/1670 train_time:122521ms step_avg:97.70ms +step:1255/1670 train_time:122618ms step_avg:97.70ms +step:1256/1670 train_time:122714ms step_avg:97.70ms +step:1257/1670 train_time:122811ms step_avg:97.70ms +step:1258/1670 train_time:122908ms step_avg:97.70ms +step:1259/1670 train_time:123005ms step_avg:97.70ms +step:1260/1670 train_time:123102ms step_avg:97.70ms +step:1261/1670 train_time:123201ms step_avg:97.70ms +step:1262/1670 train_time:123301ms step_avg:97.70ms +step:1263/1670 train_time:123400ms step_avg:97.70ms +step:1264/1670 train_time:123498ms step_avg:97.70ms +step:1265/1670 train_time:123596ms step_avg:97.70ms +step:1266/1670 train_time:123692ms step_avg:97.70ms +step:1267/1670 train_time:123789ms step_avg:97.70ms +step:1268/1670 train_time:123886ms step_avg:97.70ms +step:1269/1670 train_time:123984ms step_avg:97.70ms +step:1270/1670 train_time:124082ms step_avg:97.70ms +step:1271/1670 train_time:124179ms step_avg:97.70ms +step:1272/1670 train_time:124278ms step_avg:97.70ms +step:1273/1670 train_time:124376ms step_avg:97.70ms +step:1274/1670 train_time:124644ms step_avg:97.84ms +step:1275/1670 train_time:124845ms step_avg:97.92ms +step:1276/1670 train_time:124940ms step_avg:97.92ms +step:1277/1670 train_time:125037ms step_avg:97.91ms +step:1278/1670 train_time:125133ms step_avg:97.91ms +step:1279/1670 train_time:125229ms step_avg:97.91ms +step:1280/1670 train_time:125327ms step_avg:97.91ms +step:1281/1670 train_time:125424ms step_avg:97.91ms +step:1282/1670 train_time:125521ms step_avg:97.91ms +step:1283/1670 train_time:125617ms step_avg:97.91ms +step:1284/1670 train_time:125719ms step_avg:97.91ms +step:1285/1670 train_time:125818ms step_avg:97.91ms +step:1286/1670 train_time:125916ms step_avg:97.91ms +step:1287/1670 train_time:126013ms step_avg:97.91ms +step:1288/1670 train_time:126110ms step_avg:97.91ms +step:1289/1670 train_time:126209ms step_avg:97.91ms +step:1290/1670 train_time:126306ms step_avg:97.91ms +step:1291/1670 train_time:126403ms step_avg:97.91ms +step:1292/1670 train_time:126500ms step_avg:97.91ms +step:1293/1670 train_time:126596ms step_avg:97.91ms +step:1294/1670 train_time:126694ms step_avg:97.91ms +step:1295/1670 train_time:126794ms step_avg:97.91ms +step:1296/1670 train_time:126892ms step_avg:97.91ms +step:1297/1670 train_time:126992ms step_avg:97.91ms +step:1298/1670 train_time:127089ms step_avg:97.91ms +step:1299/1670 train_time:127188ms step_avg:97.91ms +step:1300/1670 train_time:127285ms step_avg:97.91ms +step:1301/1670 train_time:127383ms step_avg:97.91ms +step:1302/1670 train_time:127479ms step_avg:97.91ms +step:1303/1670 train_time:127576ms step_avg:97.91ms +step:1304/1670 train_time:127674ms step_avg:97.91ms +step:1305/1670 train_time:127772ms step_avg:97.91ms +step:1306/1670 train_time:127871ms step_avg:97.91ms +step:1307/1670 train_time:127970ms step_avg:97.91ms +step:1308/1670 train_time:128069ms step_avg:97.91ms +step:1309/1670 train_time:128167ms step_avg:97.91ms +step:1310/1670 train_time:128265ms 
step_avg:97.91ms +step:1311/1670 train_time:128363ms step_avg:97.91ms +step:1312/1670 train_time:128460ms step_avg:97.91ms +step:1313/1670 train_time:128557ms step_avg:97.91ms +step:1314/1670 train_time:128655ms step_avg:97.91ms +step:1315/1670 train_time:128753ms step_avg:97.91ms +step:1316/1670 train_time:128851ms step_avg:97.91ms +step:1317/1670 train_time:128949ms step_avg:97.91ms +step:1318/1670 train_time:129049ms step_avg:97.91ms +step:1319/1670 train_time:129148ms step_avg:97.91ms +step:1320/1670 train_time:129246ms step_avg:97.91ms +step:1321/1670 train_time:129345ms step_avg:97.91ms +step:1322/1670 train_time:129443ms step_avg:97.91ms +step:1323/1670 train_time:129541ms step_avg:97.91ms +step:1324/1670 train_time:129638ms step_avg:97.91ms +step:1325/1670 train_time:129736ms step_avg:97.91ms +step:1326/1670 train_time:129834ms step_avg:97.91ms +step:1327/1670 train_time:129933ms step_avg:97.91ms +step:1328/1670 train_time:130030ms step_avg:97.91ms +step:1329/1670 train_time:130127ms step_avg:97.91ms +step:1330/1670 train_time:130225ms step_avg:97.91ms +step:1331/1670 train_time:130323ms step_avg:97.91ms +step:1332/1670 train_time:130422ms step_avg:97.91ms +step:1333/1670 train_time:130519ms step_avg:97.91ms +step:1334/1670 train_time:130617ms step_avg:97.91ms +step:1335/1670 train_time:130714ms step_avg:97.91ms +step:1336/1670 train_time:130812ms step_avg:97.91ms +step:1337/1670 train_time:130910ms step_avg:97.91ms +step:1338/1670 train_time:131008ms step_avg:97.91ms +step:1339/1670 train_time:131106ms step_avg:97.91ms +step:1340/1670 train_time:131204ms step_avg:97.91ms +step:1341/1670 train_time:131302ms step_avg:97.91ms +step:1342/1670 train_time:131401ms step_avg:97.91ms +step:1343/1670 train_time:131498ms step_avg:97.91ms +step:1344/1670 train_time:131594ms step_avg:97.91ms +step:1345/1670 train_time:131692ms step_avg:97.91ms +step:1346/1670 train_time:131790ms step_avg:97.91ms +step:1347/1670 train_time:131889ms step_avg:97.91ms +step:1348/1670 train_time:131988ms step_avg:97.91ms +step:1349/1670 train_time:132085ms step_avg:97.91ms +step:1350/1670 train_time:132183ms step_avg:97.91ms +step:1351/1670 train_time:132281ms step_avg:97.91ms +step:1352/1670 train_time:132378ms step_avg:97.91ms +step:1353/1670 train_time:132476ms step_avg:97.91ms +step:1354/1670 train_time:132573ms step_avg:97.91ms +step:1355/1670 train_time:132670ms step_avg:97.91ms +step:1356/1670 train_time:132768ms step_avg:97.91ms +step:1357/1670 train_time:132867ms step_avg:97.91ms +step:1358/1670 train_time:132966ms step_avg:97.91ms +step:1359/1670 train_time:133063ms step_avg:97.91ms +step:1360/1670 train_time:133161ms step_avg:97.91ms +step:1361/1670 train_time:133258ms step_avg:97.91ms +step:1362/1670 train_time:133356ms step_avg:97.91ms +step:1363/1670 train_time:133453ms step_avg:97.91ms +step:1364/1670 train_time:133550ms step_avg:97.91ms +step:1365/1670 train_time:133649ms step_avg:97.91ms +step:1366/1670 train_time:133747ms step_avg:97.91ms +step:1367/1670 train_time:133846ms step_avg:97.91ms +step:1368/1670 train_time:133943ms step_avg:97.91ms +step:1369/1670 train_time:134043ms step_avg:97.91ms +step:1370/1670 train_time:134141ms step_avg:97.91ms +step:1371/1670 train_time:134239ms step_avg:97.91ms +step:1372/1670 train_time:134336ms step_avg:97.91ms +step:1373/1670 train_time:134434ms step_avg:97.91ms +step:1374/1670 train_time:134531ms step_avg:97.91ms +step:1375/1670 train_time:134630ms step_avg:97.91ms +step:1375/1670 val_loss:3.3441 train_time:134727ms step_avg:97.98ms +step:1376/1670 
train_time:134749ms step_avg:97.93ms +step:1377/1670 train_time:134833ms step_avg:97.92ms +step:1378/1670 train_time:134933ms step_avg:97.92ms +step:1379/1670 train_time:135030ms step_avg:97.92ms +step:1380/1670 train_time:135127ms step_avg:97.92ms +step:1381/1670 train_time:135223ms step_avg:97.92ms +step:1382/1670 train_time:135320ms step_avg:97.92ms +step:1383/1670 train_time:135416ms step_avg:97.91ms +step:1384/1670 train_time:135514ms step_avg:97.91ms +step:1385/1670 train_time:135612ms step_avg:97.91ms +step:1386/1670 train_time:135712ms step_avg:97.92ms +step:1387/1670 train_time:135813ms step_avg:97.92ms +step:1388/1670 train_time:135913ms step_avg:97.92ms +step:1389/1670 train_time:136011ms step_avg:97.92ms +step:1390/1670 train_time:136108ms step_avg:97.92ms +step:1391/1670 train_time:136205ms step_avg:97.92ms +step:1392/1670 train_time:136301ms step_avg:97.92ms +step:1393/1670 train_time:136398ms step_avg:97.92ms +step:1394/1670 train_time:136495ms step_avg:97.92ms +step:1395/1670 train_time:136592ms step_avg:97.92ms +step:1396/1670 train_time:136691ms step_avg:97.92ms +step:1397/1670 train_time:136790ms step_avg:97.92ms +step:1398/1670 train_time:136890ms step_avg:97.92ms +step:1399/1670 train_time:136989ms step_avg:97.92ms +step:1400/1670 train_time:137087ms step_avg:97.92ms +step:1401/1670 train_time:137186ms step_avg:97.92ms +step:1402/1670 train_time:137283ms step_avg:97.92ms +step:1403/1670 train_time:137380ms step_avg:97.92ms +step:1404/1670 train_time:137476ms step_avg:97.92ms +step:1405/1670 train_time:137573ms step_avg:97.92ms +step:1406/1670 train_time:137671ms step_avg:97.92ms +step:1407/1670 train_time:137770ms step_avg:97.92ms +step:1408/1670 train_time:137869ms step_avg:97.92ms +step:1409/1670 train_time:137967ms step_avg:97.92ms +step:1410/1670 train_time:138064ms step_avg:97.92ms +step:1411/1670 train_time:138162ms step_avg:97.92ms +step:1412/1670 train_time:138260ms step_avg:97.92ms +step:1413/1670 train_time:138358ms step_avg:97.92ms +step:1414/1670 train_time:138455ms step_avg:97.92ms +step:1415/1670 train_time:138552ms step_avg:97.92ms +step:1416/1670 train_time:138649ms step_avg:97.92ms +step:1417/1670 train_time:138747ms step_avg:97.92ms +step:1418/1670 train_time:138844ms step_avg:97.92ms +step:1419/1670 train_time:138942ms step_avg:97.92ms +step:1420/1670 train_time:139040ms step_avg:97.92ms +step:1421/1670 train_time:139139ms step_avg:97.92ms +step:1422/1670 train_time:139238ms step_avg:97.92ms +step:1423/1670 train_time:139336ms step_avg:97.92ms +step:1424/1670 train_time:139433ms step_avg:97.92ms +step:1425/1670 train_time:139530ms step_avg:97.92ms +step:1426/1670 train_time:139628ms step_avg:97.92ms +step:1427/1670 train_time:139725ms step_avg:97.92ms +step:1428/1670 train_time:139823ms step_avg:97.92ms +step:1429/1670 train_time:139920ms step_avg:97.91ms +step:1430/1670 train_time:140018ms step_avg:97.91ms +step:1431/1670 train_time:140117ms step_avg:97.92ms +step:1432/1670 train_time:140215ms step_avg:97.92ms +step:1433/1670 train_time:140314ms step_avg:97.92ms +step:1434/1670 train_time:140411ms step_avg:97.92ms +step:1435/1670 train_time:140509ms step_avg:97.92ms +step:1436/1670 train_time:140605ms step_avg:97.91ms +step:1437/1670 train_time:140703ms step_avg:97.91ms +step:1438/1670 train_time:140800ms step_avg:97.91ms +step:1439/1670 train_time:140899ms step_avg:97.91ms +step:1440/1670 train_time:140997ms step_avg:97.91ms +step:1441/1670 train_time:141096ms step_avg:97.92ms +step:1442/1670 train_time:141195ms step_avg:97.92ms +step:1443/1670 
train_time:141292ms step_avg:97.92ms +step:1444/1670 train_time:141391ms step_avg:97.92ms +step:1445/1670 train_time:141488ms step_avg:97.92ms +step:1446/1670 train_time:141586ms step_avg:97.92ms +step:1447/1670 train_time:141683ms step_avg:97.92ms +step:1448/1670 train_time:141781ms step_avg:97.92ms +step:1449/1670 train_time:141879ms step_avg:97.91ms +step:1450/1670 train_time:141976ms step_avg:97.91ms +step:1451/1670 train_time:142073ms step_avg:97.91ms +step:1452/1670 train_time:142172ms step_avg:97.91ms +step:1453/1670 train_time:142270ms step_avg:97.91ms +step:1454/1670 train_time:142368ms step_avg:97.91ms +step:1455/1670 train_time:142464ms step_avg:97.91ms +step:1456/1670 train_time:142562ms step_avg:97.91ms +step:1457/1670 train_time:142659ms step_avg:97.91ms +step:1458/1670 train_time:142757ms step_avg:97.91ms +step:1459/1670 train_time:142855ms step_avg:97.91ms +step:1460/1670 train_time:142954ms step_avg:97.91ms +step:1461/1670 train_time:143052ms step_avg:97.91ms +step:1462/1670 train_time:143151ms step_avg:97.91ms +step:1463/1670 train_time:143249ms step_avg:97.91ms +step:1464/1670 train_time:143347ms step_avg:97.91ms +step:1465/1670 train_time:143444ms step_avg:97.91ms +step:1466/1670 train_time:143541ms step_avg:97.91ms +step:1467/1670 train_time:143639ms step_avg:97.91ms +step:1468/1670 train_time:143737ms step_avg:97.91ms +step:1469/1670 train_time:143837ms step_avg:97.91ms +step:1470/1670 train_time:143935ms step_avg:97.91ms +step:1471/1670 train_time:144033ms step_avg:97.92ms +step:1472/1670 train_time:144133ms step_avg:97.92ms +step:1473/1670 train_time:144231ms step_avg:97.92ms +step:1474/1670 train_time:144329ms step_avg:97.92ms +step:1475/1670 train_time:144427ms step_avg:97.92ms +step:1476/1670 train_time:144525ms step_avg:97.92ms +step:1477/1670 train_time:144623ms step_avg:97.92ms +step:1478/1670 train_time:144720ms step_avg:97.92ms +step:1479/1670 train_time:144817ms step_avg:97.92ms +step:1480/1670 train_time:144915ms step_avg:97.92ms +step:1481/1670 train_time:145012ms step_avg:97.92ms +step:1482/1670 train_time:145110ms step_avg:97.91ms +step:1483/1670 train_time:145208ms step_avg:97.91ms +step:1484/1670 train_time:145305ms step_avg:97.91ms +step:1485/1670 train_time:145583ms step_avg:98.04ms +step:1486/1670 train_time:145760ms step_avg:98.09ms +step:1487/1670 train_time:145855ms step_avg:98.09ms +step:1488/1670 train_time:145951ms step_avg:98.09ms +step:1489/1670 train_time:146048ms step_avg:98.08ms +step:1490/1670 train_time:146144ms step_avg:98.08ms +step:1491/1670 train_time:146241ms step_avg:98.08ms +step:1492/1670 train_time:146337ms step_avg:98.08ms +step:1493/1670 train_time:146434ms step_avg:98.08ms +step:1494/1670 train_time:146531ms step_avg:98.08ms +step:1495/1670 train_time:146632ms step_avg:98.08ms +step:1496/1670 train_time:146734ms step_avg:98.08ms +step:1497/1670 train_time:146835ms step_avg:98.09ms +step:1498/1670 train_time:146934ms step_avg:98.09ms +step:1499/1670 train_time:147033ms step_avg:98.09ms +step:1500/1670 train_time:147132ms step_avg:98.09ms +step:1500/1670 val_loss:3.3122 train_time:147230ms step_avg:98.15ms +step:1501/1670 train_time:147251ms step_avg:98.10ms +step:1502/1670 train_time:147334ms step_avg:98.09ms +step:1503/1670 train_time:147433ms step_avg:98.09ms +step:1504/1670 train_time:147530ms step_avg:98.09ms +step:1505/1670 train_time:147628ms step_avg:98.09ms +step:1506/1670 train_time:147725ms step_avg:98.09ms +step:1507/1670 train_time:147822ms step_avg:98.09ms +step:1508/1670 train_time:147920ms step_avg:98.09ms 
+step:1509/1670 train_time:148017ms step_avg:98.09ms +step:1510/1670 train_time:148114ms step_avg:98.09ms +step:1511/1670 train_time:148213ms step_avg:98.09ms +step:1512/1670 train_time:148311ms step_avg:98.09ms +step:1513/1670 train_time:148409ms step_avg:98.09ms +step:1514/1670 train_time:148508ms step_avg:98.09ms +step:1515/1670 train_time:148605ms step_avg:98.09ms +step:1516/1670 train_time:148702ms step_avg:98.09ms +step:1517/1670 train_time:148801ms step_avg:98.09ms +step:1518/1670 train_time:148899ms step_avg:98.09ms +step:1519/1670 train_time:148995ms step_avg:98.09ms +step:1520/1670 train_time:149092ms step_avg:98.09ms +step:1521/1670 train_time:149190ms step_avg:98.09ms +step:1522/1670 train_time:149289ms step_avg:98.09ms +step:1523/1670 train_time:149387ms step_avg:98.09ms +step:1524/1670 train_time:149485ms step_avg:98.09ms +step:1525/1670 train_time:149583ms step_avg:98.09ms +step:1526/1670 train_time:149681ms step_avg:98.09ms +step:1527/1670 train_time:149778ms step_avg:98.09ms +step:1528/1670 train_time:149875ms step_avg:98.09ms +step:1529/1670 train_time:149972ms step_avg:98.09ms +step:1530/1670 train_time:150069ms step_avg:98.08ms +step:1531/1670 train_time:150166ms step_avg:98.08ms +step:1532/1670 train_time:150266ms step_avg:98.08ms +step:1533/1670 train_time:150365ms step_avg:98.09ms +step:1534/1670 train_time:150465ms step_avg:98.09ms +step:1535/1670 train_time:150563ms step_avg:98.09ms +step:1536/1670 train_time:150661ms step_avg:98.09ms +step:1537/1670 train_time:150758ms step_avg:98.09ms +step:1538/1670 train_time:150855ms step_avg:98.09ms +step:1539/1670 train_time:150952ms step_avg:98.08ms +step:1540/1670 train_time:151049ms step_avg:98.08ms +step:1541/1670 train_time:151146ms step_avg:98.08ms +step:1542/1670 train_time:151244ms step_avg:98.08ms +step:1543/1670 train_time:151342ms step_avg:98.08ms +step:1544/1670 train_time:151441ms step_avg:98.08ms +step:1545/1670 train_time:151540ms step_avg:98.08ms +step:1546/1670 train_time:151638ms step_avg:98.08ms +step:1547/1670 train_time:151736ms step_avg:98.08ms +step:1548/1670 train_time:151833ms step_avg:98.08ms +step:1549/1670 train_time:151931ms step_avg:98.08ms +step:1550/1670 train_time:152027ms step_avg:98.08ms +step:1551/1670 train_time:152126ms step_avg:98.08ms +step:1552/1670 train_time:152224ms step_avg:98.08ms +step:1553/1670 train_time:152323ms step_avg:98.08ms +step:1554/1670 train_time:152421ms step_avg:98.08ms +step:1555/1670 train_time:152519ms step_avg:98.08ms +step:1556/1670 train_time:152617ms step_avg:98.08ms +step:1557/1670 train_time:152714ms step_avg:98.08ms +step:1558/1670 train_time:152812ms step_avg:98.08ms +step:1559/1670 train_time:152908ms step_avg:98.08ms +step:1560/1670 train_time:153007ms step_avg:98.08ms +step:1561/1670 train_time:153104ms step_avg:98.08ms +step:1562/1670 train_time:153202ms step_avg:98.08ms +step:1563/1670 train_time:153302ms step_avg:98.08ms +step:1564/1670 train_time:153401ms step_avg:98.08ms +step:1565/1670 train_time:153499ms step_avg:98.08ms +step:1566/1670 train_time:153597ms step_avg:98.08ms +step:1567/1670 train_time:153694ms step_avg:98.08ms +step:1568/1670 train_time:153792ms step_avg:98.08ms +step:1569/1670 train_time:153889ms step_avg:98.08ms +step:1570/1670 train_time:153987ms step_avg:98.08ms +step:1571/1670 train_time:154084ms step_avg:98.08ms +step:1572/1670 train_time:154182ms step_avg:98.08ms +step:1573/1670 train_time:154279ms step_avg:98.08ms +step:1574/1670 train_time:154377ms step_avg:98.08ms +step:1575/1670 train_time:154474ms step_avg:98.08ms 
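These records are regular enough to check mechanically: step_avg is just train_time divided by the step number (e.g. the val line above, 147230 ms / 1500 steps ≈ 98.15 ms), and with the train_batch_size = 2048 * 24 * 8 = 393,216 tokens per step set in the script later in this diff, ~98 ms/step works out to roughly 4.0M tokens/s on the 8xH100 node. A minimal parsing sketch, assuming only the step:/val_loss:/train_time: format visible in this record (parse_records is an illustrative name, not part of the repo):

import re

def parse_records(text: str):
    # Yields (step, train_time_ms, val_loss or None), matching both
    # "step:N/M train_time:Xms" and "step:N/M val_loss:V train_time:Xms" entries.
    pat = re.compile(r"step:(\d+)/\d+(?: val_loss:([\d.]+))? train_time:(\d+)ms")
    for m in pat.finditer(text):
        step, val_loss, train_time = int(m.group(1)), m.group(2), int(m.group(3))
        yield step, train_time, float(val_loss) if val_loss else None

# sanity check against the val line above: 147230 / 1500 rounds to 98.15 ms/step
assert round(147230 / 1500, 2) == 98.15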
+step:1576/1670 train_time:154571ms step_avg:98.08ms +step:1577/1670 train_time:154668ms step_avg:98.08ms +step:1578/1670 train_time:154767ms step_avg:98.08ms +step:1579/1670 train_time:154865ms step_avg:98.08ms +step:1580/1670 train_time:154964ms step_avg:98.08ms +step:1581/1670 train_time:155062ms step_avg:98.08ms +step:1582/1670 train_time:155161ms step_avg:98.08ms +step:1583/1670 train_time:155258ms step_avg:98.08ms +step:1584/1670 train_time:155355ms step_avg:98.08ms +step:1585/1670 train_time:155453ms step_avg:98.08ms +step:1586/1670 train_time:155550ms step_avg:98.08ms +step:1587/1670 train_time:155648ms step_avg:98.08ms +step:1588/1670 train_time:155745ms step_avg:98.08ms +step:1589/1670 train_time:155843ms step_avg:98.08ms +step:1590/1670 train_time:155942ms step_avg:98.08ms +step:1591/1670 train_time:156040ms step_avg:98.08ms +step:1592/1670 train_time:156138ms step_avg:98.08ms +step:1593/1670 train_time:156235ms step_avg:98.08ms +step:1594/1670 train_time:156332ms step_avg:98.08ms +step:1595/1670 train_time:156431ms step_avg:98.08ms +step:1596/1670 train_time:156528ms step_avg:98.08ms +step:1597/1670 train_time:156626ms step_avg:98.07ms +step:1598/1670 train_time:156724ms step_avg:98.08ms +step:1599/1670 train_time:156823ms step_avg:98.08ms +step:1600/1670 train_time:156921ms step_avg:98.08ms +step:1601/1670 train_time:157019ms step_avg:98.08ms +step:1602/1670 train_time:157116ms step_avg:98.07ms +step:1603/1670 train_time:157214ms step_avg:98.07ms +step:1604/1670 train_time:157311ms step_avg:98.07ms +step:1605/1670 train_time:157408ms step_avg:98.07ms +step:1606/1670 train_time:157505ms step_avg:98.07ms +step:1607/1670 train_time:157604ms step_avg:98.07ms +step:1608/1670 train_time:157702ms step_avg:98.07ms +step:1609/1670 train_time:157800ms step_avg:98.07ms +step:1610/1670 train_time:157898ms step_avg:98.07ms +step:1611/1670 train_time:157996ms step_avg:98.07ms +step:1612/1670 train_time:158093ms step_avg:98.07ms +step:1613/1670 train_time:158191ms step_avg:98.07ms +step:1614/1670 train_time:158288ms step_avg:98.07ms +step:1615/1670 train_time:158387ms step_avg:98.07ms +step:1616/1670 train_time:158485ms step_avg:98.07ms +step:1617/1670 train_time:158582ms step_avg:98.07ms +step:1618/1670 train_time:158680ms step_avg:98.07ms +step:1619/1670 train_time:158779ms step_avg:98.07ms +step:1620/1670 train_time:158876ms step_avg:98.07ms +step:1621/1670 train_time:158973ms step_avg:98.07ms +step:1622/1670 train_time:159071ms step_avg:98.07ms +step:1623/1670 train_time:159168ms step_avg:98.07ms +step:1624/1670 train_time:159266ms step_avg:98.07ms +step:1625/1670 train_time:159364ms step_avg:98.07ms +step:1625/1670 val_loss:3.2853 train_time:159462ms step_avg:98.13ms +step:1626/1670 train_time:159483ms step_avg:98.08ms +step:1627/1670 train_time:159567ms step_avg:98.07ms +step:1628/1670 train_time:159668ms step_avg:98.08ms +step:1629/1670 train_time:159766ms step_avg:98.08ms +step:1630/1670 train_time:159863ms step_avg:98.08ms +step:1631/1670 train_time:159959ms step_avg:98.07ms +step:1632/1670 train_time:160056ms step_avg:98.07ms +step:1633/1670 train_time:160153ms step_avg:98.07ms +step:1634/1670 train_time:160251ms step_avg:98.07ms +step:1635/1670 train_time:160348ms step_avg:98.07ms +step:1636/1670 train_time:160447ms step_avg:98.07ms +step:1637/1670 train_time:160547ms step_avg:98.07ms +step:1638/1670 train_time:160647ms step_avg:98.08ms +step:1639/1670 train_time:160746ms step_avg:98.08ms +step:1640/1670 train_time:160844ms step_avg:98.08ms +step:1641/1670 train_time:160941ms 
step_avg:98.07ms +step:1642/1670 train_time:161038ms step_avg:98.07ms +step:1643/1670 train_time:161135ms step_avg:98.07ms +step:1644/1670 train_time:161232ms step_avg:98.07ms +step:1645/1670 train_time:161329ms step_avg:98.07ms +step:1646/1670 train_time:161427ms step_avg:98.07ms +step:1647/1670 train_time:161526ms step_avg:98.07ms +step:1648/1670 train_time:161625ms step_avg:98.07ms +step:1649/1670 train_time:161723ms step_avg:98.07ms +step:1650/1670 train_time:161821ms step_avg:98.07ms +step:1651/1670 train_time:161918ms step_avg:98.07ms +step:1652/1670 train_time:162015ms step_avg:98.07ms +step:1653/1670 train_time:162113ms step_avg:98.07ms +step:1654/1670 train_time:162210ms step_avg:98.07ms +step:1655/1670 train_time:162308ms step_avg:98.07ms +step:1656/1670 train_time:162406ms step_avg:98.07ms +step:1657/1670 train_time:162504ms step_avg:98.07ms +step:1658/1670 train_time:162603ms step_avg:98.07ms +step:1659/1670 train_time:162701ms step_avg:98.07ms +step:1660/1670 train_time:162798ms step_avg:98.07ms +step:1661/1670 train_time:162896ms step_avg:98.07ms +step:1662/1670 train_time:162994ms step_avg:98.07ms +step:1663/1670 train_time:163091ms step_avg:98.07ms +step:1664/1670 train_time:163189ms step_avg:98.07ms +step:1665/1670 train_time:163286ms step_avg:98.07ms +step:1666/1670 train_time:163384ms step_avg:98.07ms +step:1667/1670 train_time:163482ms step_avg:98.07ms +step:1668/1670 train_time:163579ms step_avg:98.07ms +step:1669/1670 train_time:163676ms step_avg:98.07ms +step:1670/1670 train_time:163774ms step_avg:98.07ms +step:1670/1670 val_loss:3.2771 train_time:163871ms step_avg:98.13ms +peak memory allocated: 34361 MiB reserved: 49276 MiB diff --git a/records/090325_FA3/65b0d9c0-3089-40eb-a1bc-45b15f897462.txt b/records/090325_FA3/65b0d9c0-3089-40eb-a1bc-45b15f897462.txt new file mode 100644 index 000000000..80e318320 --- /dev/null +++ b/records/090325_FA3/65b0d9c0-3089-40eb-a1bc-45b15f897462.txt @@ -0,0 +1,2814 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return 
impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + 
GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) 
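    # Note on the skip/mirror scheme: the A passed to this kernel is symmetric (it is
    # X @ X.mT produced by ns_line_1), so C = alpha * A @ A.T + beta * A is symmetric too.
    # Each program therefore computes only one triangle of C; the two tl.store calls at
    # the end write both the block and its mirror across the diagonal, roughly halving
    # the matmul work compared to a dense kernel.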
+ + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], 
X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
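For reference, a minimal pure-PyTorch version of the same quintic Newton-Schulz iteration can serve as a numerical check on the Triton path above (newton_schulz_reference is an illustrative name; the coefficients and the three-line split mirror newton_schulz_triton):

import torch

@torch.no_grad()
def newton_schulz_reference(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Quintic iteration X <- a*X + (b*A + c*A@A) @ X with A = X @ X.mT.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # ensure spectral norm <= 1
    for _ in range(steps):
        A = X @ X.mT               # ns_line_1
        B = b * A + c * (A @ A)    # ns_line_2
        X = a * X + B @ X          # ns_line_3 (addmm/baddbmm)
    return X.mT if transposed else X

On random bfloat16 inputs this should agree with newton_schulz_triton up to bf16 rounding, which makes it a quick way to validate kernel changes.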
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.") # cur can be unbound on the first iteration, so report idx instead + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 19:44:22 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 37C P0 119W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 31C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 37C P0 128W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 36C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 36C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 49666 C /usr/bin/python 0MiB | +| 0 N/A N/A 49667 C /usr/bin/python 0MiB | +| 0 N/A N/A 49668 C /usr/bin/python 0MiB | +| 0 N/A N/A 49669 C /usr/bin/python 0MiB | +| 0 N/A N/A 49670 C /usr/bin/python 0MiB | +| 0 N/A N/A 49671 C /usr/bin/python 0MiB | +| 0 N/A N/A 49672 C /usr/bin/python 0MiB | +| 0 N/A N/A 49673 C /usr/bin/python 0MiB | +| 1 N/A N/A 49667 C /usr/bin/python 0MiB | +| 2 N/A N/A 49668 C /usr/bin/python 0MiB | +| 3 N/A N/A 49669 C /usr/bin/python 0MiB | +| 4 N/A N/A 49670 C /usr/bin/python 0MiB | +| 5 N/A N/A 49671 C /usr/bin/python 0MiB | +| 6 N/A N/A 49672 C /usr/bin/python 0MiB | +| 7 N/A N/A 49673 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:380ms step_avg:379.58ms +step:2/1670 train_time:400ms step_avg:200.05ms +step:3/1670 train_time:473ms step_avg:157.71ms +step:4/1670 train_time:567ms step_avg:141.68ms +step:5/1670 train_time:661ms step_avg:132.24ms +step:6/1670 train_time:756ms step_avg:125.95ms +step:7/1670 train_time:850ms step_avg:121.42ms +step:8/1670 train_time:945ms step_avg:118.11ms +step:9/1670 train_time:1040ms 
+step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms
+step:1/1670 train_time:380ms step_avg:379.58ms
+step:2/1670 train_time:400ms step_avg:200.05ms
+[remaining per-step train_time lines omitted; step_avg settles near 97-98ms and the val_loss checkpoints below are kept]
+step:125/1670 val_loss:4.3037 train_time:12257ms step_avg:98.06ms
+step:250/1670 val_loss:3.9734 train_time:24414ms step_avg:97.66ms
+step:375/1670 val_loss:3.8171 train_time:36390ms step_avg:97.04ms
+step:500/1670 val_loss:3.7150 train_time:48624ms step_avg:97.25ms
+step:625/1670 val_loss:3.6149 train_time:60701ms step_avg:97.12ms
+step:750/1670 val_loss:3.5624 train_time:73128ms step_avg:97.50ms
+step:875/1670 val_loss:3.5190 train_time:85515ms step_avg:97.73ms
+step:1000/1670 val_loss:3.4774 train_time:97679ms step_avg:97.68ms
+step:1125/1670 val_loss:3.4232 train_time:110017ms step_avg:97.79ms
+step:1250/1670 val_loss:3.3799 train_time:122255ms step_avg:97.80ms
+step:1310/1670 train_time:128326ms
step_avg:97.96ms +step:1311/1670 train_time:128423ms step_avg:97.96ms +step:1312/1670 train_time:128520ms step_avg:97.96ms +step:1313/1670 train_time:128617ms step_avg:97.96ms +step:1314/1670 train_time:128715ms step_avg:97.96ms +step:1315/1670 train_time:128815ms step_avg:97.96ms +step:1316/1670 train_time:128914ms step_avg:97.96ms +step:1317/1670 train_time:129012ms step_avg:97.96ms +step:1318/1670 train_time:129110ms step_avg:97.96ms +step:1319/1670 train_time:129208ms step_avg:97.96ms +step:1320/1670 train_time:129305ms step_avg:97.96ms +step:1321/1670 train_time:129403ms step_avg:97.96ms +step:1322/1670 train_time:129499ms step_avg:97.96ms +step:1323/1670 train_time:129597ms step_avg:97.96ms +step:1324/1670 train_time:129695ms step_avg:97.96ms +step:1325/1670 train_time:129793ms step_avg:97.96ms +step:1326/1670 train_time:129892ms step_avg:97.96ms +step:1327/1670 train_time:129990ms step_avg:97.96ms +step:1328/1670 train_time:130088ms step_avg:97.96ms +step:1329/1670 train_time:130185ms step_avg:97.96ms +step:1330/1670 train_time:130282ms step_avg:97.96ms +step:1331/1670 train_time:130380ms step_avg:97.96ms +step:1332/1670 train_time:130477ms step_avg:97.96ms +step:1333/1670 train_time:130575ms step_avg:97.96ms +step:1334/1670 train_time:130673ms step_avg:97.96ms +step:1335/1670 train_time:130772ms step_avg:97.96ms +step:1336/1670 train_time:130870ms step_avg:97.96ms +step:1337/1670 train_time:130968ms step_avg:97.96ms +step:1338/1670 train_time:131066ms step_avg:97.96ms +step:1339/1670 train_time:131164ms step_avg:97.96ms +step:1340/1670 train_time:131261ms step_avg:97.96ms +step:1341/1670 train_time:131359ms step_avg:97.96ms +step:1342/1670 train_time:131457ms step_avg:97.96ms +step:1343/1670 train_time:131555ms step_avg:97.96ms +step:1344/1670 train_time:131653ms step_avg:97.96ms +step:1345/1670 train_time:131751ms step_avg:97.96ms +step:1346/1670 train_time:131849ms step_avg:97.96ms +step:1347/1670 train_time:131947ms step_avg:97.96ms +step:1348/1670 train_time:132045ms step_avg:97.96ms +step:1349/1670 train_time:132143ms step_avg:97.96ms +step:1350/1670 train_time:132240ms step_avg:97.96ms +step:1351/1670 train_time:132337ms step_avg:97.95ms +step:1352/1670 train_time:132435ms step_avg:97.95ms +step:1353/1670 train_time:132533ms step_avg:97.96ms +step:1354/1670 train_time:132632ms step_avg:97.96ms +step:1355/1670 train_time:132730ms step_avg:97.96ms +step:1356/1670 train_time:132828ms step_avg:97.96ms +step:1357/1670 train_time:132925ms step_avg:97.96ms +step:1358/1670 train_time:133023ms step_avg:97.96ms +step:1359/1670 train_time:133120ms step_avg:97.95ms +step:1360/1670 train_time:133218ms step_avg:97.95ms +step:1361/1670 train_time:133316ms step_avg:97.95ms +step:1362/1670 train_time:133414ms step_avg:97.95ms +step:1363/1670 train_time:133513ms step_avg:97.96ms +step:1364/1670 train_time:133612ms step_avg:97.96ms +step:1365/1670 train_time:133709ms step_avg:97.96ms +step:1366/1670 train_time:133806ms step_avg:97.95ms +step:1367/1670 train_time:133903ms step_avg:97.95ms +step:1368/1670 train_time:134000ms step_avg:97.95ms +step:1369/1670 train_time:134098ms step_avg:97.95ms +step:1370/1670 train_time:134196ms step_avg:97.95ms +step:1371/1670 train_time:134294ms step_avg:97.95ms +step:1372/1670 train_time:134392ms step_avg:97.95ms +step:1373/1670 train_time:134490ms step_avg:97.95ms +step:1374/1670 train_time:134589ms step_avg:97.95ms +step:1375/1670 train_time:134687ms step_avg:97.95ms +step:1375/1670 val_loss:3.3433 train_time:134784ms step_avg:98.02ms +step:1376/1670 
train_time:134806ms step_avg:97.97ms +step:1377/1670 train_time:134891ms step_avg:97.96ms +step:1378/1670 train_time:134991ms step_avg:97.96ms +step:1379/1670 train_time:135087ms step_avg:97.96ms +step:1380/1670 train_time:135185ms step_avg:97.96ms +step:1381/1670 train_time:135283ms step_avg:97.96ms +step:1382/1670 train_time:135382ms step_avg:97.96ms +step:1383/1670 train_time:135478ms step_avg:97.96ms +step:1384/1670 train_time:135575ms step_avg:97.96ms +step:1385/1670 train_time:135671ms step_avg:97.96ms +step:1386/1670 train_time:135770ms step_avg:97.96ms +step:1387/1670 train_time:135870ms step_avg:97.96ms +step:1388/1670 train_time:135970ms step_avg:97.96ms +step:1389/1670 train_time:136068ms step_avg:97.96ms +step:1390/1670 train_time:136166ms step_avg:97.96ms +step:1391/1670 train_time:136264ms step_avg:97.96ms +step:1392/1670 train_time:136362ms step_avg:97.96ms +step:1393/1670 train_time:136459ms step_avg:97.96ms +step:1394/1670 train_time:136556ms step_avg:97.96ms +step:1395/1670 train_time:136653ms step_avg:97.96ms +step:1396/1670 train_time:136750ms step_avg:97.96ms +step:1397/1670 train_time:136849ms step_avg:97.96ms +step:1398/1670 train_time:136947ms step_avg:97.96ms +step:1399/1670 train_time:137046ms step_avg:97.96ms +step:1400/1670 train_time:137143ms step_avg:97.96ms +step:1401/1670 train_time:137240ms step_avg:97.96ms +step:1402/1670 train_time:137339ms step_avg:97.96ms +step:1403/1670 train_time:137436ms step_avg:97.96ms +step:1404/1670 train_time:137533ms step_avg:97.96ms +step:1405/1670 train_time:137630ms step_avg:97.96ms +step:1406/1670 train_time:137728ms step_avg:97.96ms +step:1407/1670 train_time:137827ms step_avg:97.96ms +step:1408/1670 train_time:137927ms step_avg:97.96ms +step:1409/1670 train_time:138026ms step_avg:97.96ms +step:1410/1670 train_time:138124ms step_avg:97.96ms +step:1411/1670 train_time:138222ms step_avg:97.96ms +step:1412/1670 train_time:138320ms step_avg:97.96ms +step:1413/1670 train_time:138418ms step_avg:97.96ms +step:1414/1670 train_time:138516ms step_avg:97.96ms +step:1415/1670 train_time:138614ms step_avg:97.96ms +step:1416/1670 train_time:138710ms step_avg:97.96ms +step:1417/1670 train_time:138808ms step_avg:97.96ms +step:1418/1670 train_time:138906ms step_avg:97.96ms +step:1419/1670 train_time:139006ms step_avg:97.96ms +step:1420/1670 train_time:139106ms step_avg:97.96ms +step:1421/1670 train_time:139205ms step_avg:97.96ms +step:1422/1670 train_time:139303ms step_avg:97.96ms +step:1423/1670 train_time:139401ms step_avg:97.96ms +step:1424/1670 train_time:139500ms step_avg:97.96ms +step:1425/1670 train_time:139599ms step_avg:97.96ms +step:1426/1670 train_time:139698ms step_avg:97.96ms +step:1427/1670 train_time:139796ms step_avg:97.97ms +step:1428/1670 train_time:139894ms step_avg:97.96ms +step:1429/1670 train_time:139992ms step_avg:97.96ms +step:1430/1670 train_time:140090ms step_avg:97.96ms +step:1431/1670 train_time:140188ms step_avg:97.96ms +step:1432/1670 train_time:140286ms step_avg:97.96ms +step:1433/1670 train_time:140383ms step_avg:97.96ms +step:1434/1670 train_time:140480ms step_avg:97.96ms +step:1435/1670 train_time:140578ms step_avg:97.96ms +step:1436/1670 train_time:140676ms step_avg:97.96ms +step:1437/1670 train_time:140774ms step_avg:97.96ms +step:1438/1670 train_time:140871ms step_avg:97.96ms +step:1439/1670 train_time:140968ms step_avg:97.96ms +step:1440/1670 train_time:141066ms step_avg:97.96ms +step:1441/1670 train_time:141165ms step_avg:97.96ms +step:1442/1670 train_time:141263ms step_avg:97.96ms +step:1443/1670 
train_time:141360ms step_avg:97.96ms +step:1444/1670 train_time:141458ms step_avg:97.96ms +step:1445/1670 train_time:141556ms step_avg:97.96ms +step:1446/1670 train_time:141653ms step_avg:97.96ms +step:1447/1670 train_time:141751ms step_avg:97.96ms +step:1448/1670 train_time:141849ms step_avg:97.96ms +step:1449/1670 train_time:141947ms step_avg:97.96ms +step:1450/1670 train_time:142046ms step_avg:97.96ms +step:1451/1670 train_time:142143ms step_avg:97.96ms +step:1452/1670 train_time:142241ms step_avg:97.96ms +step:1453/1670 train_time:142339ms step_avg:97.96ms +step:1454/1670 train_time:142436ms step_avg:97.96ms +step:1455/1670 train_time:142533ms step_avg:97.96ms +step:1456/1670 train_time:142630ms step_avg:97.96ms +step:1457/1670 train_time:142728ms step_avg:97.96ms +step:1458/1670 train_time:142827ms step_avg:97.96ms +step:1459/1670 train_time:142925ms step_avg:97.96ms +step:1460/1670 train_time:143024ms step_avg:97.96ms +step:1461/1670 train_time:143124ms step_avg:97.96ms +step:1462/1670 train_time:143222ms step_avg:97.96ms +step:1463/1670 train_time:143319ms step_avg:97.96ms +step:1464/1670 train_time:143417ms step_avg:97.96ms +step:1465/1670 train_time:143515ms step_avg:97.96ms +step:1466/1670 train_time:143613ms step_avg:97.96ms +step:1467/1670 train_time:143710ms step_avg:97.96ms +step:1468/1670 train_time:143808ms step_avg:97.96ms +step:1469/1670 train_time:143906ms step_avg:97.96ms +step:1470/1670 train_time:144004ms step_avg:97.96ms +step:1471/1670 train_time:144101ms step_avg:97.96ms +step:1472/1670 train_time:144199ms step_avg:97.96ms +step:1473/1670 train_time:144297ms step_avg:97.96ms +step:1474/1670 train_time:144394ms step_avg:97.96ms +step:1475/1670 train_time:144492ms step_avg:97.96ms +step:1476/1670 train_time:144589ms step_avg:97.96ms +step:1477/1670 train_time:144687ms step_avg:97.96ms +step:1478/1670 train_time:144786ms step_avg:97.96ms +step:1479/1670 train_time:144885ms step_avg:97.96ms +step:1480/1670 train_time:144983ms step_avg:97.96ms +step:1481/1670 train_time:145080ms step_avg:97.96ms +step:1482/1670 train_time:145178ms step_avg:97.96ms +step:1483/1670 train_time:145276ms step_avg:97.96ms +step:1484/1670 train_time:145373ms step_avg:97.96ms +step:1485/1670 train_time:145639ms step_avg:98.07ms +step:1486/1670 train_time:145720ms step_avg:98.06ms +step:1487/1670 train_time:145817ms step_avg:98.06ms +step:1488/1670 train_time:145913ms step_avg:98.06ms +step:1489/1670 train_time:146010ms step_avg:98.06ms +step:1490/1670 train_time:146107ms step_avg:98.06ms +step:1491/1670 train_time:146204ms step_avg:98.06ms +step:1492/1670 train_time:146301ms step_avg:98.06ms +step:1493/1670 train_time:146398ms step_avg:98.06ms +step:1494/1670 train_time:146496ms step_avg:98.06ms +step:1495/1670 train_time:146598ms step_avg:98.06ms +step:1496/1670 train_time:146698ms step_avg:98.06ms +step:1497/1670 train_time:146796ms step_avg:98.06ms +step:1498/1670 train_time:146893ms step_avg:98.06ms +step:1499/1670 train_time:146990ms step_avg:98.06ms +step:1500/1670 train_time:147088ms step_avg:98.06ms +step:1500/1670 val_loss:3.3107 train_time:147184ms step_avg:98.12ms +step:1501/1670 train_time:147205ms step_avg:98.07ms +step:1502/1670 train_time:147289ms step_avg:98.06ms +step:1503/1670 train_time:147389ms step_avg:98.06ms +step:1504/1670 train_time:147486ms step_avg:98.06ms +step:1505/1670 train_time:147583ms step_avg:98.06ms +step:1506/1670 train_time:147680ms step_avg:98.06ms +step:1507/1670 train_time:147776ms step_avg:98.06ms +step:1508/1670 train_time:147873ms step_avg:98.06ms 
+step:1509/1670 train_time:147971ms step_avg:98.06ms +step:1510/1670 train_time:148069ms step_avg:98.06ms +step:1511/1670 train_time:148169ms step_avg:98.06ms +step:1512/1670 train_time:148269ms step_avg:98.06ms +step:1513/1670 train_time:148369ms step_avg:98.06ms +step:1514/1670 train_time:148468ms step_avg:98.06ms +step:1515/1670 train_time:148566ms step_avg:98.06ms +step:1516/1670 train_time:148665ms step_avg:98.06ms +step:1517/1670 train_time:148763ms step_avg:98.06ms +step:1518/1670 train_time:148860ms step_avg:98.06ms +step:1519/1670 train_time:148958ms step_avg:98.06ms +step:1520/1670 train_time:149055ms step_avg:98.06ms +step:1521/1670 train_time:149153ms step_avg:98.06ms +step:1522/1670 train_time:149251ms step_avg:98.06ms +step:1523/1670 train_time:149349ms step_avg:98.06ms +step:1524/1670 train_time:149446ms step_avg:98.06ms +step:1525/1670 train_time:149545ms step_avg:98.06ms +step:1526/1670 train_time:149643ms step_avg:98.06ms +step:1527/1670 train_time:149740ms step_avg:98.06ms +step:1528/1670 train_time:149838ms step_avg:98.06ms +step:1529/1670 train_time:149935ms step_avg:98.06ms +step:1530/1670 train_time:150032ms step_avg:98.06ms +step:1531/1670 train_time:150130ms step_avg:98.06ms +step:1532/1670 train_time:150229ms step_avg:98.06ms +step:1533/1670 train_time:150328ms step_avg:98.06ms +step:1534/1670 train_time:150427ms step_avg:98.06ms +step:1535/1670 train_time:150525ms step_avg:98.06ms +step:1536/1670 train_time:150623ms step_avg:98.06ms +step:1537/1670 train_time:150721ms step_avg:98.06ms +step:1538/1670 train_time:150819ms step_avg:98.06ms +step:1539/1670 train_time:150916ms step_avg:98.06ms +step:1540/1670 train_time:151013ms step_avg:98.06ms +step:1541/1670 train_time:151110ms step_avg:98.06ms +step:1542/1670 train_time:151208ms step_avg:98.06ms +step:1543/1670 train_time:151308ms step_avg:98.06ms +step:1544/1670 train_time:151406ms step_avg:98.06ms +step:1545/1670 train_time:151504ms step_avg:98.06ms +step:1546/1670 train_time:151602ms step_avg:98.06ms +step:1547/1670 train_time:151701ms step_avg:98.06ms +step:1548/1670 train_time:151800ms step_avg:98.06ms +step:1549/1670 train_time:151898ms step_avg:98.06ms +step:1550/1670 train_time:151995ms step_avg:98.06ms +step:1551/1670 train_time:152093ms step_avg:98.06ms +step:1552/1670 train_time:152190ms step_avg:98.06ms +step:1553/1670 train_time:152289ms step_avg:98.06ms +step:1554/1670 train_time:152387ms step_avg:98.06ms +step:1555/1670 train_time:152484ms step_avg:98.06ms +step:1556/1670 train_time:152582ms step_avg:98.06ms +step:1557/1670 train_time:152679ms step_avg:98.06ms +step:1558/1670 train_time:152777ms step_avg:98.06ms +step:1559/1670 train_time:152875ms step_avg:98.06ms +step:1560/1670 train_time:152973ms step_avg:98.06ms +step:1561/1670 train_time:153070ms step_avg:98.06ms +step:1562/1670 train_time:153168ms step_avg:98.06ms +step:1563/1670 train_time:153267ms step_avg:98.06ms +step:1564/1670 train_time:153365ms step_avg:98.06ms +step:1565/1670 train_time:153464ms step_avg:98.06ms +step:1566/1670 train_time:153562ms step_avg:98.06ms +step:1567/1670 train_time:153659ms step_avg:98.06ms +step:1568/1670 train_time:153756ms step_avg:98.06ms +step:1569/1670 train_time:153853ms step_avg:98.06ms +step:1570/1670 train_time:153951ms step_avg:98.06ms +step:1571/1670 train_time:154049ms step_avg:98.06ms +step:1572/1670 train_time:154148ms step_avg:98.06ms +step:1573/1670 train_time:154246ms step_avg:98.06ms +step:1574/1670 train_time:154344ms step_avg:98.06ms +step:1575/1670 train_time:154441ms step_avg:98.06ms 
+step:1576/1670 train_time:154539ms step_avg:98.06ms +step:1577/1670 train_time:154636ms step_avg:98.06ms +step:1578/1670 train_time:154733ms step_avg:98.06ms +step:1579/1670 train_time:154831ms step_avg:98.06ms +step:1580/1670 train_time:154929ms step_avg:98.06ms +step:1581/1670 train_time:155029ms step_avg:98.06ms +step:1582/1670 train_time:155128ms step_avg:98.06ms +step:1583/1670 train_time:155228ms step_avg:98.06ms +step:1584/1670 train_time:155326ms step_avg:98.06ms +step:1585/1670 train_time:155425ms step_avg:98.06ms +step:1586/1670 train_time:155523ms step_avg:98.06ms +step:1587/1670 train_time:155622ms step_avg:98.06ms +step:1588/1670 train_time:155720ms step_avg:98.06ms +step:1589/1670 train_time:155818ms step_avg:98.06ms +step:1590/1670 train_time:155915ms step_avg:98.06ms +step:1591/1670 train_time:156013ms step_avg:98.06ms +step:1592/1670 train_time:156111ms step_avg:98.06ms +step:1593/1670 train_time:156208ms step_avg:98.06ms +step:1594/1670 train_time:156307ms step_avg:98.06ms +step:1595/1670 train_time:156405ms step_avg:98.06ms +step:1596/1670 train_time:156504ms step_avg:98.06ms +step:1597/1670 train_time:156602ms step_avg:98.06ms +step:1598/1670 train_time:156699ms step_avg:98.06ms +step:1599/1670 train_time:156796ms step_avg:98.06ms +step:1600/1670 train_time:156894ms step_avg:98.06ms +step:1601/1670 train_time:156991ms step_avg:98.06ms +step:1602/1670 train_time:157089ms step_avg:98.06ms +step:1603/1670 train_time:157187ms step_avg:98.06ms +step:1604/1670 train_time:157285ms step_avg:98.06ms +step:1605/1670 train_time:157384ms step_avg:98.06ms +step:1606/1670 train_time:157482ms step_avg:98.06ms +step:1607/1670 train_time:157579ms step_avg:98.06ms +step:1608/1670 train_time:157676ms step_avg:98.06ms +step:1609/1670 train_time:157774ms step_avg:98.06ms +step:1610/1670 train_time:157872ms step_avg:98.06ms +step:1611/1670 train_time:157970ms step_avg:98.06ms +step:1612/1670 train_time:158068ms step_avg:98.06ms +step:1613/1670 train_time:158166ms step_avg:98.06ms +step:1614/1670 train_time:158265ms step_avg:98.06ms +step:1615/1670 train_time:158363ms step_avg:98.06ms +step:1616/1670 train_time:158461ms step_avg:98.06ms +step:1617/1670 train_time:158559ms step_avg:98.06ms +step:1618/1670 train_time:158656ms step_avg:98.06ms +step:1619/1670 train_time:158754ms step_avg:98.06ms +step:1620/1670 train_time:158851ms step_avg:98.06ms +step:1621/1670 train_time:158949ms step_avg:98.06ms +step:1622/1670 train_time:159047ms step_avg:98.06ms +step:1623/1670 train_time:159145ms step_avg:98.06ms +step:1624/1670 train_time:159243ms step_avg:98.06ms +step:1625/1670 train_time:159342ms step_avg:98.06ms +step:1625/1670 val_loss:3.2839 train_time:159438ms step_avg:98.12ms +step:1626/1670 train_time:159459ms step_avg:98.07ms +step:1627/1670 train_time:159543ms step_avg:98.06ms +step:1628/1670 train_time:159642ms step_avg:98.06ms +step:1629/1670 train_time:159739ms step_avg:98.06ms +step:1630/1670 train_time:159836ms step_avg:98.06ms +step:1631/1670 train_time:159933ms step_avg:98.06ms +step:1632/1670 train_time:160031ms step_avg:98.06ms +step:1633/1670 train_time:160128ms step_avg:98.06ms +step:1634/1670 train_time:160226ms step_avg:98.06ms +step:1635/1670 train_time:160323ms step_avg:98.06ms +step:1636/1670 train_time:160422ms step_avg:98.06ms +step:1637/1670 train_time:160523ms step_avg:98.06ms +step:1638/1670 train_time:160622ms step_avg:98.06ms +step:1639/1670 train_time:160720ms step_avg:98.06ms +step:1640/1670 train_time:160817ms step_avg:98.06ms +step:1641/1670 train_time:160914ms 
step_avg:98.06ms +step:1642/1670 train_time:161011ms step_avg:98.06ms +step:1643/1670 train_time:161109ms step_avg:98.06ms +step:1644/1670 train_time:161206ms step_avg:98.06ms +step:1645/1670 train_time:161303ms step_avg:98.06ms +step:1646/1670 train_time:161402ms step_avg:98.06ms +step:1647/1670 train_time:161501ms step_avg:98.06ms +step:1648/1670 train_time:161599ms step_avg:98.06ms +step:1649/1670 train_time:161696ms step_avg:98.06ms +step:1650/1670 train_time:161794ms step_avg:98.06ms +step:1651/1670 train_time:161891ms step_avg:98.06ms +step:1652/1670 train_time:161990ms step_avg:98.06ms +step:1653/1670 train_time:162087ms step_avg:98.06ms +step:1654/1670 train_time:162184ms step_avg:98.06ms +step:1655/1670 train_time:162282ms step_avg:98.06ms +step:1656/1670 train_time:162379ms step_avg:98.06ms +step:1657/1670 train_time:162477ms step_avg:98.06ms +step:1658/1670 train_time:162576ms step_avg:98.06ms +step:1659/1670 train_time:162675ms step_avg:98.06ms +step:1660/1670 train_time:162772ms step_avg:98.06ms +step:1661/1670 train_time:162870ms step_avg:98.06ms +step:1662/1670 train_time:162968ms step_avg:98.06ms +step:1663/1670 train_time:163067ms step_avg:98.06ms +step:1664/1670 train_time:163164ms step_avg:98.06ms +step:1665/1670 train_time:163262ms step_avg:98.06ms +step:1666/1670 train_time:163360ms step_avg:98.06ms +step:1667/1670 train_time:163458ms step_avg:98.06ms +step:1668/1670 train_time:163556ms step_avg:98.06ms +step:1669/1670 train_time:163654ms step_avg:98.06ms +step:1670/1670 train_time:163751ms step_avg:98.05ms +step:1670/1670 val_loss:3.2760 train_time:163848ms step_avg:98.11ms +peak memory allocated: 34001 MiB reserved: 49316 MiB diff --git a/records/090325_FA3/831dade9-9b29-43ff-9106-80fc680b3e57.txt b/records/090325_FA3/831dade9-9b29-43ff-9106-80fc680b3e57.txt new file mode 100644 index 000000000..34fc9dc52 --- /dev/null +++ b/records/090325_FA3/831dade9-9b29-43ff-9106-80fc680b3e57.txt @@ -0,0 +1,2814 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return 
impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + 
GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) 
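+    # A is symmetric here (it is produced as X @ X.T by ns_line_1), so
+    # C = alpha * A @ A.T + beta * A is symmetric as well; compute only the
+    # triangle selected by LOWER_UPPER and mirror each block across the
+    # diagonal at store time.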
+ + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], 
X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
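+        # Layout: __init__ grouped parameters by identical shape, so within each
+        # group rank r owns parameters base_i + r. Gradients are averaged across
+        # ranks with async reduce_scatter, each rank runs momentum + Newton-Schulz
+        # only on the parameters it owns, and the updated parameters are
+        # redistributed with async all_gather. E.g. with world_size=8 and the 24
+        # identically-shaped MLP matrices, each rank orthogonalizes 3 per step.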
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each 
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a Beginning of Sequence token, and sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # use itertools.cycle(files) instead for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0 # for the unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # this shard is exhausted; load the next one on the next loop iteration
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # the last document was one token too long, to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "Num tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens // new_grad_accum_steps # keep the per-micro-batch semantics of the division above
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
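+
+# Resize-protocol sketch for the generator above (this run builds its loaders once and only
+# calls next(); .send() is the optional path for changing the batch shape mid-run):
+#   loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len)
+#   inputs, targets, cum_seqlens = next(loader)                                   # ordinary step
+#   inputs, targets, cum_seqlens = loader.send((num_tokens, max_seq_len, accum))  # step + new params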
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
+
+# begin logging
+logfile = None
+if master_process:
+    run_id = args.run_id
+    os.makedirs("logs", exist_ok=True)
+    logfile = f"logs/{run_id}.txt"
+    print(logfile)
+def print0(s, console=False):
+    if master_process:
+        with open(logfile, "a") as f:
+            if console:
+                print(s)
+            print(s, file=f)
+
+# begin by printing this file (the Python code)
+print0(code)
+print0("="*100)
+# log information about the hardware/software environment this is running on
+print0(f"Running Python {sys.version}")
+print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
+print0(f"Running Triton version {triton.__version__}")
+
+def nvidia_smi():
+    import subprocess # avoid top level import
+    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
+print0(nvidia_smi())
+print0("="*100)
+
+model: nn.Module = GPT(
+    vocab_size=50257,
+    num_layers=12,
+    num_heads=6,
+    model_dim=768,
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
+).cuda()
+for m in model.modules():
+    if isinstance(m, nn.Embedding):
+        m.bfloat16()
+for param in model.parameters():
+    dist.broadcast(param.detach(), 0)
+
+# collect the parameters to optimize
+hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+scalar_params = [p for p in model.parameters() if p.ndim < 2]
+head_params = [model.lm_head.weight]
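+
+# Batch geometry implied by the hyperparameters above (worked out for the world_size=8 run logged below):
+#   grad_accum_steps = 8 // 8 = 1, i.e. one micro-batch per optimizer step
+#   train: 2048*24*8 = 393,216 tokens/step -> 49,152 tokens per rank per step, sequences capped at 128*16 = 2048
+#   val:   4*64*1024*8 = 2,097,152 tokens/batch, so 10,485,760 val tokens = 5 val batches
+#   model max_seq_len = max(393216, 2097152) // 8 = 262,144 tokens buffered per rank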
+
+# init the optimizer(s)
+# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
+# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
+optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
+optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
+optimizers = [optimizer1, optimizer2]
+for opt in optimizers:
+    for group in opt.param_groups:
+        group["initial_lr"] = group["lr"]
+
+# learning rate schedule: stable then decay
+def get_lr(step: int):
+    x = step / args.num_iterations
+    assert 0 <= x < 1
+    lr = 1.0
+    if x >= 1 - args.cooldown_frac:
+        w = (1 - x) / args.cooldown_frac
+        lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+# block-mask window size schedule: steps through ws_schedule in equal thirds of training
+def get_ws(step: int):
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
+    ws_idx = int(len(args.ws_schedule) * x)
+    return args.ws_schedule[ws_idx]
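+
+# What the two schedules above do for this run (num_iterations=1670, cooldown_frac=0.45,
+# ws_schedule=(3, 7, 11)); a worked trace, not extra configuration:
+#   get_lr: 1.0 for steps 0-918 (the first 55%), then linear decay toward 0.1,
+#           e.g. get_lr(1294) ~= 0.55 halfway through the cooldown
+#   get_ws: window size 3 for steps 0-556, 7 for steps 557-1113, 11 for steps 1114-1670
+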
+model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
+
+########################################
+#            Warmup kernels            #
+########################################
+
+# Warmup the training kernels, then re-initialize the state so we aren't cheating
+warmup_steps = 30
+initial_state = dict(model=copy.deepcopy(model.state_dict()),
+                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+for step in range(warmup_steps):
+    inputs, targets, cum_seqlens = next(train_loader)
+    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    model(inputs, targets, cum_seqlens, ws).backward()
+    for opt in optimizers:
+        opt.step()
+    model.zero_grad(set_to_none=True)
+model.load_state_dict(initial_state["model"])
+for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
+    opt.load_state_dict(opt_state)
+del train_loader, initial_state
+
+########################################
+#       Training and validation        #
+########################################
+
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
+training_time_ms = 0
+# start the clock
+torch.cuda.synchronize()
+t0 = time.perf_counter()
+# begin training
+train_steps = args.num_iterations
+for step in range(train_steps + 1):
+    last_step = (step == train_steps)
+    ws = get_ws(step)
+
+    # --------------- VALIDATION SECTION -----------------
+    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
+        # stop the clock
+        torch.cuda.synchronize()
+        training_time_ms += 1000 * (time.perf_counter() - t0)
+        model.eval()
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
+        val_loss = 0
+        with torch.no_grad():
+            for _ in range(val_steps):
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws)
+        val_loss /= val_steps
+        del val_loader
+        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
+        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
+        model.train()
+        # start the clock again
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+    if last_step:
+        if master_process and args.save_checkpoint:
+            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
+            os.makedirs(f"logs/{run_id}", exist_ok=True)
+            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
+        # the last step only has the validation loop, so break to avoid training
+        break
+
+    # --------------- TRAINING SECTION -----------------
+    for _ in range(grad_accum_steps):
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws).backward()
+    # set optimization hyperparameters
+    for opt in optimizers:
+        for group in opt.param_groups:
+            group["lr"] = group["initial_lr"] * get_lr(step)
+    for group in optimizer2.param_groups:
+        frac = min(step / 300, 1) # momentum warmup for muon
+        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
+    # step the optimizers
+    for opt in optimizers:
+        opt.step()
+    # null the gradients
+    model.zero_grad(set_to_none=True)
+    # logging
+    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
+    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
+
+print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
+dist.destroy_process_group()
+==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 19:48:32 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M.
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 41C P0 121W / 700W | 5858MiB / 81559MiB | 1% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 33C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 40C P0 131W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 40C P0 121W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 32C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 39C P0 122W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 31C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 50803 C /usr/bin/python 0MiB | +| 0 N/A N/A 50804 C /usr/bin/python 0MiB | +| 0 N/A N/A 50805 C /usr/bin/python 0MiB | +| 0 N/A N/A 50806 C /usr/bin/python 0MiB | +| 0 N/A N/A 50807 C /usr/bin/python 0MiB | +| 0 N/A N/A 50808 C /usr/bin/python 0MiB | +| 0 N/A N/A 50809 C /usr/bin/python 0MiB | +| 0 N/A N/A 50810 C /usr/bin/python 0MiB | +| 1 N/A N/A 50804 C /usr/bin/python 0MiB | +| 2 N/A N/A 50805 C /usr/bin/python 0MiB | +| 3 N/A N/A 50806 C /usr/bin/python 0MiB | +| 4 N/A N/A 50807 C /usr/bin/python 0MiB | +| 5 N/A N/A 50808 C /usr/bin/python 0MiB | +| 6 N/A N/A 50809 C /usr/bin/python 0MiB | +| 7 N/A N/A 50810 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:366ms step_avg:365.67ms +step:2/1670 train_time:386ms step_avg:193.15ms +step:3/1670 train_time:459ms step_avg:153.08ms +step:4/1670 train_time:553ms step_avg:138.26ms +step:5/1670 train_time:648ms step_avg:129.54ms +step:6/1670 train_time:742ms step_avg:123.73ms +step:7/1670 train_time:838ms step_avg:119.67ms +step:8/1670 train_time:933ms step_avg:116.57ms +step:9/1670 train_time:1028ms 
step_avg:114.20ms +step:10/1670 train_time:1123ms step_avg:112.27ms +step:11/1670 train_time:1219ms step_avg:110.77ms +step:12/1670 train_time:1316ms step_avg:109.64ms +step:13/1670 train_time:1414ms step_avg:108.77ms +step:14/1670 train_time:1512ms step_avg:108.01ms +step:15/1670 train_time:1608ms step_avg:107.19ms +step:16/1670 train_time:1704ms step_avg:106.47ms +step:17/1670 train_time:1799ms step_avg:105.83ms +step:18/1670 train_time:1894ms step_avg:105.25ms +step:19/1670 train_time:1990ms step_avg:104.72ms +step:20/1670 train_time:2085ms step_avg:104.27ms +step:21/1670 train_time:2181ms step_avg:103.88ms +step:22/1670 train_time:2277ms step_avg:103.50ms +step:23/1670 train_time:2373ms step_avg:103.18ms +step:24/1670 train_time:2470ms step_avg:102.92ms +step:25/1670 train_time:2567ms step_avg:102.68ms +step:26/1670 train_time:2664ms step_avg:102.46ms +step:27/1670 train_time:2761ms step_avg:102.27ms +step:28/1670 train_time:2856ms step_avg:102.02ms +step:29/1670 train_time:2952ms step_avg:101.78ms +step:30/1670 train_time:3048ms step_avg:101.60ms +step:31/1670 train_time:3143ms step_avg:101.40ms +step:32/1670 train_time:3240ms step_avg:101.25ms +step:33/1670 train_time:3335ms step_avg:101.07ms +step:34/1670 train_time:3432ms step_avg:100.93ms +step:35/1670 train_time:3530ms step_avg:100.86ms +step:36/1670 train_time:3627ms step_avg:100.75ms +step:37/1670 train_time:3724ms step_avg:100.65ms +step:38/1670 train_time:3821ms step_avg:100.55ms +step:39/1670 train_time:3917ms step_avg:100.42ms +step:40/1670 train_time:4012ms step_avg:100.31ms +step:41/1670 train_time:4108ms step_avg:100.19ms +step:42/1670 train_time:4204ms step_avg:100.09ms +step:43/1670 train_time:4300ms step_avg:100.01ms +step:44/1670 train_time:4397ms step_avg:99.92ms +step:45/1670 train_time:4493ms step_avg:99.84ms +step:46/1670 train_time:4589ms step_avg:99.75ms +step:47/1670 train_time:4686ms step_avg:99.70ms +step:48/1670 train_time:4783ms step_avg:99.65ms +step:49/1670 train_time:4880ms step_avg:99.59ms +step:50/1670 train_time:4976ms step_avg:99.52ms +step:51/1670 train_time:5071ms step_avg:99.43ms +step:52/1670 train_time:5166ms step_avg:99.35ms +step:53/1670 train_time:5262ms step_avg:99.29ms +step:54/1670 train_time:5359ms step_avg:99.23ms +step:55/1670 train_time:5454ms step_avg:99.16ms +step:56/1670 train_time:5549ms step_avg:99.09ms +step:57/1670 train_time:5645ms step_avg:99.04ms +step:58/1670 train_time:5741ms step_avg:98.98ms +step:59/1670 train_time:5837ms step_avg:98.93ms +step:60/1670 train_time:5933ms step_avg:98.88ms +step:61/1670 train_time:6029ms step_avg:98.83ms +step:62/1670 train_time:6125ms step_avg:98.79ms +step:63/1670 train_time:6221ms step_avg:98.74ms +step:64/1670 train_time:6316ms step_avg:98.69ms +step:65/1670 train_time:6412ms step_avg:98.64ms +step:66/1670 train_time:6508ms step_avg:98.60ms +step:67/1670 train_time:6603ms step_avg:98.56ms +step:68/1670 train_time:6700ms step_avg:98.53ms +step:69/1670 train_time:6795ms step_avg:98.48ms +step:70/1670 train_time:6891ms step_avg:98.45ms +step:71/1670 train_time:6987ms step_avg:98.41ms +step:72/1670 train_time:7083ms step_avg:98.37ms +step:73/1670 train_time:7179ms step_avg:98.34ms +step:74/1670 train_time:7275ms step_avg:98.30ms +step:75/1670 train_time:7370ms step_avg:98.27ms +step:76/1670 train_time:7466ms step_avg:98.23ms +step:77/1670 train_time:7561ms step_avg:98.20ms +step:78/1670 train_time:7657ms step_avg:98.17ms +step:79/1670 train_time:7753ms step_avg:98.13ms +step:80/1670 train_time:7849ms step_avg:98.11ms +step:81/1670 
train_time:7946ms step_avg:98.09ms +step:82/1670 train_time:8041ms step_avg:98.06ms +step:83/1670 train_time:8137ms step_avg:98.03ms +step:84/1670 train_time:8232ms step_avg:98.00ms +step:85/1670 train_time:8328ms step_avg:97.97ms +step:86/1670 train_time:8424ms step_avg:97.96ms +step:87/1670 train_time:8520ms step_avg:97.93ms +step:88/1670 train_time:8616ms step_avg:97.91ms +step:89/1670 train_time:8711ms step_avg:97.88ms +step:90/1670 train_time:8807ms step_avg:97.85ms +step:91/1670 train_time:8903ms step_avg:97.84ms +step:92/1670 train_time:8999ms step_avg:97.82ms +step:93/1670 train_time:9095ms step_avg:97.80ms +step:94/1670 train_time:9190ms step_avg:97.77ms +step:95/1670 train_time:9287ms step_avg:97.75ms +step:96/1670 train_time:9383ms step_avg:97.74ms +step:97/1670 train_time:9479ms step_avg:97.73ms +step:98/1670 train_time:9575ms step_avg:97.70ms +step:99/1670 train_time:9670ms step_avg:97.68ms +step:100/1670 train_time:9766ms step_avg:97.66ms +step:101/1670 train_time:9863ms step_avg:97.65ms +step:102/1670 train_time:9959ms step_avg:97.64ms +step:103/1670 train_time:10055ms step_avg:97.62ms +step:104/1670 train_time:10150ms step_avg:97.59ms +step:105/1670 train_time:10245ms step_avg:97.57ms +step:106/1670 train_time:10340ms step_avg:97.55ms +step:107/1670 train_time:10436ms step_avg:97.53ms +step:108/1670 train_time:10531ms step_avg:97.51ms +step:109/1670 train_time:10628ms step_avg:97.50ms +step:110/1670 train_time:10724ms step_avg:97.49ms +step:111/1670 train_time:10819ms step_avg:97.47ms +step:112/1670 train_time:10915ms step_avg:97.45ms +step:113/1670 train_time:11010ms step_avg:97.44ms +step:114/1670 train_time:11107ms step_avg:97.43ms +step:115/1670 train_time:11202ms step_avg:97.41ms +step:116/1670 train_time:11298ms step_avg:97.39ms +step:117/1670 train_time:11393ms step_avg:97.38ms +step:118/1670 train_time:11488ms step_avg:97.36ms +step:119/1670 train_time:11584ms step_avg:97.34ms +step:120/1670 train_time:11680ms step_avg:97.33ms +step:121/1670 train_time:11776ms step_avg:97.32ms +step:122/1670 train_time:11872ms step_avg:97.31ms +step:123/1670 train_time:11967ms step_avg:97.29ms +step:124/1670 train_time:12063ms step_avg:97.28ms +step:125/1670 train_time:12159ms step_avg:97.27ms +step:125/1670 val_loss:4.3294 train_time:12253ms step_avg:98.03ms +step:126/1670 train_time:12275ms step_avg:97.42ms +step:127/1670 train_time:12357ms step_avg:97.30ms +step:128/1670 train_time:12464ms step_avg:97.37ms +step:129/1670 train_time:12560ms step_avg:97.36ms +step:130/1670 train_time:12655ms step_avg:97.34ms +step:131/1670 train_time:12749ms step_avg:97.32ms +step:132/1670 train_time:12843ms step_avg:97.30ms +step:133/1670 train_time:12938ms step_avg:97.28ms +step:134/1670 train_time:13033ms step_avg:97.26ms +step:135/1670 train_time:13128ms step_avg:97.24ms +step:136/1670 train_time:13222ms step_avg:97.22ms +step:137/1670 train_time:13319ms step_avg:97.22ms +step:138/1670 train_time:13420ms step_avg:97.25ms +step:139/1670 train_time:13518ms step_avg:97.25ms +step:140/1670 train_time:13615ms step_avg:97.25ms +step:141/1670 train_time:13711ms step_avg:97.24ms +step:142/1670 train_time:13807ms step_avg:97.23ms +step:143/1670 train_time:13901ms step_avg:97.21ms +step:144/1670 train_time:13996ms step_avg:97.19ms +step:145/1670 train_time:14091ms step_avg:97.18ms +step:146/1670 train_time:14187ms step_avg:97.17ms +step:147/1670 train_time:14282ms step_avg:97.15ms +step:148/1670 train_time:14378ms step_avg:97.15ms +step:149/1670 train_time:14475ms step_avg:97.15ms +step:150/1670 
train_time:14571ms step_avg:97.14ms +step:151/1670 train_time:14667ms step_avg:97.13ms +step:152/1670 train_time:14763ms step_avg:97.12ms +step:153/1670 train_time:14858ms step_avg:97.11ms +step:154/1670 train_time:14954ms step_avg:97.10ms +step:155/1670 train_time:15049ms step_avg:97.09ms +step:156/1670 train_time:15144ms step_avg:97.08ms +step:157/1670 train_time:15239ms step_avg:97.06ms +step:158/1670 train_time:15335ms step_avg:97.06ms +step:159/1670 train_time:15432ms step_avg:97.06ms +step:160/1670 train_time:15529ms step_avg:97.05ms +step:161/1670 train_time:15625ms step_avg:97.05ms +step:162/1670 train_time:15721ms step_avg:97.05ms +step:163/1670 train_time:15817ms step_avg:97.04ms +step:164/1670 train_time:15912ms step_avg:97.02ms +step:165/1670 train_time:16007ms step_avg:97.01ms +step:166/1670 train_time:16102ms step_avg:97.00ms +step:167/1670 train_time:16197ms step_avg:96.99ms +step:168/1670 train_time:16293ms step_avg:96.98ms +step:169/1670 train_time:16389ms step_avg:96.98ms +step:170/1670 train_time:16485ms step_avg:96.97ms +step:171/1670 train_time:16580ms step_avg:96.96ms +step:172/1670 train_time:16678ms step_avg:96.96ms +step:173/1670 train_time:16774ms step_avg:96.96ms +step:174/1670 train_time:16869ms step_avg:96.95ms +step:175/1670 train_time:16964ms step_avg:96.94ms +step:176/1670 train_time:17059ms step_avg:96.93ms +step:177/1670 train_time:17154ms step_avg:96.92ms +step:178/1670 train_time:17250ms step_avg:96.91ms +step:179/1670 train_time:17345ms step_avg:96.90ms +step:180/1670 train_time:17441ms step_avg:96.89ms +step:181/1670 train_time:17537ms step_avg:96.89ms +step:182/1670 train_time:17634ms step_avg:96.89ms +step:183/1670 train_time:17730ms step_avg:96.89ms +step:184/1670 train_time:17826ms step_avg:96.88ms +step:185/1670 train_time:17922ms step_avg:96.87ms +step:186/1670 train_time:18017ms step_avg:96.86ms +step:187/1670 train_time:18112ms step_avg:96.86ms +step:188/1670 train_time:18207ms step_avg:96.85ms +step:189/1670 train_time:18303ms step_avg:96.84ms +step:190/1670 train_time:18399ms step_avg:96.84ms +step:191/1670 train_time:18495ms step_avg:96.83ms +step:192/1670 train_time:18591ms step_avg:96.83ms +step:193/1670 train_time:18688ms step_avg:96.83ms +step:194/1670 train_time:18783ms step_avg:96.82ms +step:195/1670 train_time:18879ms step_avg:96.81ms +step:196/1670 train_time:18975ms step_avg:96.81ms +step:197/1670 train_time:19070ms step_avg:96.80ms +step:198/1670 train_time:19165ms step_avg:96.79ms +step:199/1670 train_time:19261ms step_avg:96.79ms +step:200/1670 train_time:19356ms step_avg:96.78ms +step:201/1670 train_time:19453ms step_avg:96.78ms +step:202/1670 train_time:19549ms step_avg:96.78ms +step:203/1670 train_time:19645ms step_avg:96.77ms +step:204/1670 train_time:19741ms step_avg:96.77ms +step:205/1670 train_time:19837ms step_avg:96.77ms +step:206/1670 train_time:19933ms step_avg:96.76ms +step:207/1670 train_time:20029ms step_avg:96.76ms +step:208/1670 train_time:20125ms step_avg:96.75ms +step:209/1670 train_time:20219ms step_avg:96.74ms +step:210/1670 train_time:20315ms step_avg:96.74ms +step:211/1670 train_time:20412ms step_avg:96.74ms +step:212/1670 train_time:20508ms step_avg:96.74ms +step:213/1670 train_time:20795ms step_avg:97.63ms +step:214/1670 train_time:20891ms step_avg:97.62ms +step:215/1670 train_time:20984ms step_avg:97.60ms +step:216/1670 train_time:21078ms step_avg:97.59ms +step:217/1670 train_time:21173ms step_avg:97.57ms +step:218/1670 train_time:21269ms step_avg:97.56ms +step:219/1670 train_time:21363ms step_avg:97.55ms 
+step:220/1670 train_time:21457ms step_avg:97.53ms +step:221/1670 train_time:21552ms step_avg:97.52ms +step:222/1670 train_time:21647ms step_avg:97.51ms +step:223/1670 train_time:21743ms step_avg:97.50ms +step:224/1670 train_time:21841ms step_avg:97.50ms +step:225/1670 train_time:21939ms step_avg:97.51ms +step:226/1670 train_time:22036ms step_avg:97.50ms +step:227/1670 train_time:22131ms step_avg:97.49ms +step:228/1670 train_time:22226ms step_avg:97.48ms +step:229/1670 train_time:22321ms step_avg:97.47ms +step:230/1670 train_time:22416ms step_avg:97.46ms +step:231/1670 train_time:22511ms step_avg:97.45ms +step:232/1670 train_time:22607ms step_avg:97.44ms +step:233/1670 train_time:22702ms step_avg:97.43ms +step:234/1670 train_time:22798ms step_avg:97.43ms +step:235/1670 train_time:22896ms step_avg:97.43ms +step:236/1670 train_time:22993ms step_avg:97.43ms +step:237/1670 train_time:23089ms step_avg:97.42ms +step:238/1670 train_time:23184ms step_avg:97.41ms +step:239/1670 train_time:23279ms step_avg:97.40ms +step:240/1670 train_time:23375ms step_avg:97.39ms +step:241/1670 train_time:23470ms step_avg:97.39ms +step:242/1670 train_time:23565ms step_avg:97.37ms +step:243/1670 train_time:23660ms step_avg:97.37ms +step:244/1670 train_time:23756ms step_avg:97.36ms +step:245/1670 train_time:23852ms step_avg:97.35ms +step:246/1670 train_time:23949ms step_avg:97.35ms +step:247/1670 train_time:24044ms step_avg:97.35ms +step:248/1670 train_time:24140ms step_avg:97.34ms +step:249/1670 train_time:24236ms step_avg:97.33ms +step:250/1670 train_time:24333ms step_avg:97.33ms +step:250/1670 val_loss:3.9718 train_time:24427ms step_avg:97.71ms +step:251/1670 train_time:24450ms step_avg:97.41ms +step:252/1670 train_time:24530ms step_avg:97.34ms +step:253/1670 train_time:24633ms step_avg:97.36ms +step:254/1670 train_time:24730ms step_avg:97.36ms +step:255/1670 train_time:24825ms step_avg:97.35ms +step:256/1670 train_time:24919ms step_avg:97.34ms +step:257/1670 train_time:25014ms step_avg:97.33ms +step:258/1670 train_time:25109ms step_avg:97.32ms +step:259/1670 train_time:25204ms step_avg:97.31ms +step:260/1670 train_time:25298ms step_avg:97.30ms +step:261/1670 train_time:25393ms step_avg:97.29ms +step:262/1670 train_time:25491ms step_avg:97.30ms +step:263/1670 train_time:25591ms step_avg:97.31ms +step:264/1670 train_time:25690ms step_avg:97.31ms +step:265/1670 train_time:25787ms step_avg:97.31ms +step:266/1670 train_time:25883ms step_avg:97.30ms +step:267/1670 train_time:25978ms step_avg:97.30ms +step:268/1670 train_time:26073ms step_avg:97.29ms +step:269/1670 train_time:26168ms step_avg:97.28ms +step:270/1670 train_time:26263ms step_avg:97.27ms +step:271/1670 train_time:26358ms step_avg:97.26ms +step:272/1670 train_time:26453ms step_avg:97.25ms +step:273/1670 train_time:26551ms step_avg:97.25ms +step:274/1670 train_time:26648ms step_avg:97.26ms +step:275/1670 train_time:26745ms step_avg:97.26ms +step:276/1670 train_time:26841ms step_avg:97.25ms +step:277/1670 train_time:26936ms step_avg:97.24ms +step:278/1670 train_time:27031ms step_avg:97.23ms +step:279/1670 train_time:27126ms step_avg:97.23ms +step:280/1670 train_time:27221ms step_avg:97.22ms +step:281/1670 train_time:27316ms step_avg:97.21ms +step:282/1670 train_time:27411ms step_avg:97.20ms +step:283/1670 train_time:27508ms step_avg:97.20ms +step:284/1670 train_time:27604ms step_avg:97.20ms +step:285/1670 train_time:27700ms step_avg:97.19ms +step:286/1670 train_time:27796ms step_avg:97.19ms +step:287/1670 train_time:27892ms step_avg:97.18ms +step:288/1670 
train_time:27988ms step_avg:97.18ms +step:289/1670 train_time:28083ms step_avg:97.17ms +step:290/1670 train_time:28179ms step_avg:97.17ms +step:291/1670 train_time:28273ms step_avg:97.16ms +step:292/1670 train_time:28369ms step_avg:97.15ms +step:293/1670 train_time:28464ms step_avg:97.15ms +step:294/1670 train_time:28560ms step_avg:97.14ms +step:295/1670 train_time:28655ms step_avg:97.14ms +step:296/1670 train_time:28752ms step_avg:97.13ms +step:297/1670 train_time:28849ms step_avg:97.13ms +step:298/1670 train_time:28945ms step_avg:97.13ms +step:299/1670 train_time:29040ms step_avg:97.12ms +step:300/1670 train_time:29135ms step_avg:97.12ms +step:301/1670 train_time:29230ms step_avg:97.11ms +step:302/1670 train_time:29325ms step_avg:97.10ms +step:303/1670 train_time:29421ms step_avg:97.10ms +step:304/1670 train_time:29516ms step_avg:97.09ms +step:305/1670 train_time:29612ms step_avg:97.09ms +step:306/1670 train_time:29708ms step_avg:97.08ms +step:307/1670 train_time:29804ms step_avg:97.08ms +step:308/1670 train_time:29900ms step_avg:97.08ms +step:309/1670 train_time:29995ms step_avg:97.07ms +step:310/1670 train_time:30091ms step_avg:97.07ms +step:311/1670 train_time:30187ms step_avg:97.06ms +step:312/1670 train_time:30283ms step_avg:97.06ms +step:313/1670 train_time:30379ms step_avg:97.06ms +step:314/1670 train_time:30474ms step_avg:97.05ms +step:315/1670 train_time:30570ms step_avg:97.05ms +step:316/1670 train_time:30666ms step_avg:97.04ms +step:317/1670 train_time:30762ms step_avg:97.04ms +step:318/1670 train_time:30857ms step_avg:97.04ms +step:319/1670 train_time:30953ms step_avg:97.03ms +step:320/1670 train_time:31050ms step_avg:97.03ms +step:321/1670 train_time:31145ms step_avg:97.03ms +step:322/1670 train_time:31241ms step_avg:97.02ms +step:323/1670 train_time:31337ms step_avg:97.02ms +step:324/1670 train_time:31432ms step_avg:97.01ms +step:325/1670 train_time:31527ms step_avg:97.01ms +step:326/1670 train_time:31623ms step_avg:97.00ms +step:327/1670 train_time:31719ms step_avg:97.00ms +step:328/1670 train_time:31814ms step_avg:96.99ms +step:329/1670 train_time:31910ms step_avg:96.99ms +step:330/1670 train_time:32006ms step_avg:96.99ms +step:331/1670 train_time:32102ms step_avg:96.98ms +step:332/1670 train_time:32197ms step_avg:96.98ms +step:333/1670 train_time:32292ms step_avg:96.97ms +step:334/1670 train_time:32389ms step_avg:96.97ms +step:335/1670 train_time:32485ms step_avg:96.97ms +step:336/1670 train_time:32580ms step_avg:96.96ms +step:337/1670 train_time:32675ms step_avg:96.96ms +step:338/1670 train_time:32772ms step_avg:96.96ms +step:339/1670 train_time:32868ms step_avg:96.96ms +step:340/1670 train_time:32964ms step_avg:96.95ms +step:341/1670 train_time:33060ms step_avg:96.95ms +step:342/1670 train_time:33155ms step_avg:96.94ms +step:343/1670 train_time:33251ms step_avg:96.94ms +step:344/1670 train_time:33347ms step_avg:96.94ms +step:345/1670 train_time:33443ms step_avg:96.94ms +step:346/1670 train_time:33538ms step_avg:96.93ms +step:347/1670 train_time:33634ms step_avg:96.93ms +step:348/1670 train_time:33729ms step_avg:96.92ms +step:349/1670 train_time:33825ms step_avg:96.92ms +step:350/1670 train_time:33921ms step_avg:96.92ms +step:351/1670 train_time:34016ms step_avg:96.91ms +step:352/1670 train_time:34112ms step_avg:96.91ms +step:353/1670 train_time:34208ms step_avg:96.91ms +step:354/1670 train_time:34304ms step_avg:96.90ms +step:355/1670 train_time:34399ms step_avg:96.90ms +step:356/1670 train_time:34495ms step_avg:96.89ms +step:357/1670 train_time:34591ms step_avg:96.89ms 
+step:358/1670 train_time:34686ms step_avg:96.89ms +step:359/1670 train_time:34781ms step_avg:96.88ms +step:360/1670 train_time:34876ms step_avg:96.88ms +step:361/1670 train_time:34972ms step_avg:96.88ms +step:362/1670 train_time:35068ms step_avg:96.87ms +step:363/1670 train_time:35165ms step_avg:96.87ms +step:364/1670 train_time:35260ms step_avg:96.87ms +step:365/1670 train_time:35356ms step_avg:96.87ms +step:366/1670 train_time:35451ms step_avg:96.86ms +step:367/1670 train_time:35547ms step_avg:96.86ms +step:368/1670 train_time:35643ms step_avg:96.86ms +step:369/1670 train_time:35738ms step_avg:96.85ms +step:370/1670 train_time:35833ms step_avg:96.85ms +step:371/1670 train_time:35929ms step_avg:96.84ms +step:372/1670 train_time:36025ms step_avg:96.84ms +step:373/1670 train_time:36121ms step_avg:96.84ms +step:374/1670 train_time:36216ms step_avg:96.83ms +step:375/1670 train_time:36311ms step_avg:96.83ms +step:375/1670 val_loss:3.8133 train_time:36407ms step_avg:97.08ms +step:376/1670 train_time:36427ms step_avg:96.88ms +step:377/1670 train_time:36509ms step_avg:96.84ms +step:378/1670 train_time:36610ms step_avg:96.85ms +step:379/1670 train_time:36705ms step_avg:96.85ms +step:380/1670 train_time:36801ms step_avg:96.84ms +step:381/1670 train_time:36896ms step_avg:96.84ms +step:382/1670 train_time:36991ms step_avg:96.83ms +step:383/1670 train_time:37086ms step_avg:96.83ms +step:384/1670 train_time:37181ms step_avg:96.83ms +step:385/1670 train_time:37276ms step_avg:96.82ms +step:386/1670 train_time:37370ms step_avg:96.81ms +step:387/1670 train_time:37467ms step_avg:96.81ms +step:388/1670 train_time:37566ms step_avg:96.82ms +step:389/1670 train_time:37663ms step_avg:96.82ms +step:390/1670 train_time:37760ms step_avg:96.82ms +step:391/1670 train_time:37856ms step_avg:96.82ms +step:392/1670 train_time:37951ms step_avg:96.81ms +step:393/1670 train_time:38046ms step_avg:96.81ms +step:394/1670 train_time:38140ms step_avg:96.80ms +step:395/1670 train_time:38235ms step_avg:96.80ms +step:396/1670 train_time:38330ms step_avg:96.79ms +step:397/1670 train_time:38425ms step_avg:96.79ms +step:398/1670 train_time:38522ms step_avg:96.79ms +step:399/1670 train_time:38620ms step_avg:96.79ms +step:400/1670 train_time:38716ms step_avg:96.79ms +step:401/1670 train_time:38812ms step_avg:96.79ms +step:402/1670 train_time:38907ms step_avg:96.78ms +step:403/1670 train_time:39003ms step_avg:96.78ms +step:404/1670 train_time:39098ms step_avg:96.78ms +step:405/1670 train_time:39194ms step_avg:96.77ms +step:406/1670 train_time:39288ms step_avg:96.77ms +step:407/1670 train_time:39384ms step_avg:96.77ms +step:408/1670 train_time:39480ms step_avg:96.76ms +step:409/1670 train_time:39576ms step_avg:96.76ms +step:410/1670 train_time:39672ms step_avg:96.76ms +step:411/1670 train_time:39768ms step_avg:96.76ms +step:412/1670 train_time:39864ms step_avg:96.76ms +step:413/1670 train_time:39961ms step_avg:96.76ms +step:414/1670 train_time:40057ms step_avg:96.76ms +step:415/1670 train_time:40152ms step_avg:96.75ms +step:416/1670 train_time:40247ms step_avg:96.75ms +step:417/1670 train_time:40342ms step_avg:96.74ms +step:418/1670 train_time:40438ms step_avg:96.74ms +step:419/1670 train_time:40534ms step_avg:96.74ms +step:420/1670 train_time:40629ms step_avg:96.73ms +step:421/1670 train_time:40725ms step_avg:96.73ms +step:422/1670 train_time:40823ms step_avg:96.74ms +step:423/1670 train_time:40920ms step_avg:96.74ms +step:424/1670 train_time:41016ms step_avg:96.74ms +step:425/1670 train_time:41309ms step_avg:97.20ms +step:426/1670 
train_time:41449ms step_avg:97.30ms +step:427/1670 train_time:41542ms step_avg:97.29ms +step:428/1670 train_time:41637ms step_avg:97.28ms +step:429/1670 train_time:41732ms step_avg:97.28ms +step:430/1670 train_time:41826ms step_avg:97.27ms +step:431/1670 train_time:41921ms step_avg:97.26ms +step:432/1670 train_time:42015ms step_avg:97.26ms +step:433/1670 train_time:42110ms step_avg:97.25ms +step:434/1670 train_time:42204ms step_avg:97.24ms +step:435/1670 train_time:42301ms step_avg:97.24ms +step:436/1670 train_time:42402ms step_avg:97.25ms +step:437/1670 train_time:42501ms step_avg:97.26ms +step:438/1670 train_time:42597ms step_avg:97.25ms +step:439/1670 train_time:42693ms step_avg:97.25ms +step:440/1670 train_time:42788ms step_avg:97.24ms +step:441/1670 train_time:42882ms step_avg:97.24ms +step:442/1670 train_time:42978ms step_avg:97.23ms +step:443/1670 train_time:43072ms step_avg:97.23ms +step:444/1670 train_time:43167ms step_avg:97.22ms +step:445/1670 train_time:43262ms step_avg:97.22ms +step:446/1670 train_time:43358ms step_avg:97.22ms +step:447/1670 train_time:43456ms step_avg:97.22ms +step:448/1670 train_time:43552ms step_avg:97.21ms +step:449/1670 train_time:43648ms step_avg:97.21ms +step:450/1670 train_time:43744ms step_avg:97.21ms +step:451/1670 train_time:43839ms step_avg:97.20ms +step:452/1670 train_time:43934ms step_avg:97.20ms +step:453/1670 train_time:44029ms step_avg:97.19ms +step:454/1670 train_time:44123ms step_avg:97.19ms +step:455/1670 train_time:44219ms step_avg:97.18ms +step:456/1670 train_time:44315ms step_avg:97.18ms +step:457/1670 train_time:44411ms step_avg:97.18ms +step:458/1670 train_time:44508ms step_avg:97.18ms +step:459/1670 train_time:44604ms step_avg:97.18ms +step:460/1670 train_time:44700ms step_avg:97.17ms +step:461/1670 train_time:44795ms step_avg:97.17ms +step:462/1670 train_time:44890ms step_avg:97.17ms +step:463/1670 train_time:44985ms step_avg:97.16ms +step:464/1670 train_time:45081ms step_avg:97.16ms +step:465/1670 train_time:45176ms step_avg:97.15ms +step:466/1670 train_time:45271ms step_avg:97.15ms +step:467/1670 train_time:45367ms step_avg:97.15ms +step:468/1670 train_time:45464ms step_avg:97.14ms +step:469/1670 train_time:45561ms step_avg:97.14ms +step:470/1670 train_time:45657ms step_avg:97.14ms +step:471/1670 train_time:45753ms step_avg:97.14ms +step:472/1670 train_time:45849ms step_avg:97.14ms +step:473/1670 train_time:45943ms step_avg:97.13ms +step:474/1670 train_time:46039ms step_avg:97.13ms +step:475/1670 train_time:46134ms step_avg:97.12ms +step:476/1670 train_time:46229ms step_avg:97.12ms +step:477/1670 train_time:46324ms step_avg:97.12ms +step:478/1670 train_time:46420ms step_avg:97.11ms +step:479/1670 train_time:46517ms step_avg:97.11ms +step:480/1670 train_time:46614ms step_avg:97.11ms +step:481/1670 train_time:46710ms step_avg:97.11ms +step:482/1670 train_time:46805ms step_avg:97.11ms +step:483/1670 train_time:46901ms step_avg:97.10ms +step:484/1670 train_time:46997ms step_avg:97.10ms +step:485/1670 train_time:47092ms step_avg:97.10ms +step:486/1670 train_time:47187ms step_avg:97.09ms +step:487/1670 train_time:47283ms step_avg:97.09ms +step:488/1670 train_time:47378ms step_avg:97.09ms +step:489/1670 train_time:47474ms step_avg:97.08ms +step:490/1670 train_time:47569ms step_avg:97.08ms +step:491/1670 train_time:47664ms step_avg:97.08ms +step:492/1670 train_time:47761ms step_avg:97.07ms +step:493/1670 train_time:47857ms step_avg:97.07ms +step:494/1670 train_time:47952ms step_avg:97.07ms +step:495/1670 train_time:48048ms step_avg:97.07ms 
+step:496/1670 train_time:48143ms step_avg:97.06ms +step:497/1670 train_time:48238ms step_avg:97.06ms +step:498/1670 train_time:48334ms step_avg:97.06ms +step:499/1670 train_time:48429ms step_avg:97.05ms +step:500/1670 train_time:48525ms step_avg:97.05ms +step:500/1670 val_loss:3.7125 train_time:48621ms step_avg:97.24ms +step:501/1670 train_time:48642ms step_avg:97.09ms +step:502/1670 train_time:48723ms step_avg:97.06ms +step:503/1670 train_time:48827ms step_avg:97.07ms +step:504/1670 train_time:48922ms step_avg:97.07ms +step:505/1670 train_time:49016ms step_avg:97.06ms +step:506/1670 train_time:49111ms step_avg:97.06ms +step:507/1670 train_time:49206ms step_avg:97.05ms +step:508/1670 train_time:49300ms step_avg:97.05ms +step:509/1670 train_time:49395ms step_avg:97.04ms +step:510/1670 train_time:49489ms step_avg:97.04ms +step:511/1670 train_time:49585ms step_avg:97.03ms +step:512/1670 train_time:49681ms step_avg:97.03ms +step:513/1670 train_time:49780ms step_avg:97.04ms +step:514/1670 train_time:49877ms step_avg:97.04ms +step:515/1670 train_time:49973ms step_avg:97.04ms +step:516/1670 train_time:50069ms step_avg:97.03ms +step:517/1670 train_time:50164ms step_avg:97.03ms +step:518/1670 train_time:50258ms step_avg:97.02ms +step:519/1670 train_time:50353ms step_avg:97.02ms +step:520/1670 train_time:50448ms step_avg:97.01ms +step:521/1670 train_time:50542ms step_avg:97.01ms +step:522/1670 train_time:50638ms step_avg:97.01ms +step:523/1670 train_time:50736ms step_avg:97.01ms +step:524/1670 train_time:50834ms step_avg:97.01ms +step:525/1670 train_time:50931ms step_avg:97.01ms +step:526/1670 train_time:51027ms step_avg:97.01ms +step:527/1670 train_time:51122ms step_avg:97.01ms +step:528/1670 train_time:51217ms step_avg:97.00ms +step:529/1670 train_time:51312ms step_avg:97.00ms +step:530/1670 train_time:51407ms step_avg:96.99ms +step:531/1670 train_time:51502ms step_avg:96.99ms +step:532/1670 train_time:51597ms step_avg:96.99ms +step:533/1670 train_time:51694ms step_avg:96.99ms +step:534/1670 train_time:51791ms step_avg:96.99ms +step:535/1670 train_time:51887ms step_avg:96.99ms +step:536/1670 train_time:51983ms step_avg:96.98ms +step:537/1670 train_time:52079ms step_avg:96.98ms +step:538/1670 train_time:52174ms step_avg:96.98ms +step:539/1670 train_time:52270ms step_avg:96.98ms +step:540/1670 train_time:52366ms step_avg:96.97ms +step:541/1670 train_time:52460ms step_avg:96.97ms +step:542/1670 train_time:52556ms step_avg:96.97ms +step:543/1670 train_time:52651ms step_avg:96.96ms +step:544/1670 train_time:52747ms step_avg:96.96ms +step:545/1670 train_time:52843ms step_avg:96.96ms +step:546/1670 train_time:52939ms step_avg:96.96ms +step:547/1670 train_time:53035ms step_avg:96.96ms +step:548/1670 train_time:53131ms step_avg:96.95ms +step:549/1670 train_time:53226ms step_avg:96.95ms +step:550/1670 train_time:53321ms step_avg:96.95ms +step:551/1670 train_time:53416ms step_avg:96.94ms +step:552/1670 train_time:53512ms step_avg:96.94ms +step:553/1670 train_time:53608ms step_avg:96.94ms +step:554/1670 train_time:53703ms step_avg:96.94ms +step:555/1670 train_time:53799ms step_avg:96.93ms +step:556/1670 train_time:53895ms step_avg:96.93ms +step:557/1670 train_time:53991ms step_avg:96.93ms +step:558/1670 train_time:54087ms step_avg:96.93ms +step:559/1670 train_time:54183ms step_avg:96.93ms +step:560/1670 train_time:54280ms step_avg:96.93ms +step:561/1670 train_time:54376ms step_avg:96.93ms +step:562/1670 train_time:54474ms step_avg:96.93ms +step:563/1670 train_time:54571ms step_avg:96.93ms +step:564/1670 
train_time:54668ms step_avg:96.93ms +step:565/1670 train_time:54764ms step_avg:96.93ms +step:566/1670 train_time:54861ms step_avg:96.93ms +step:567/1670 train_time:54958ms step_avg:96.93ms +step:568/1670 train_time:55056ms step_avg:96.93ms +step:569/1670 train_time:55154ms step_avg:96.93ms +step:570/1670 train_time:55253ms step_avg:96.93ms +step:571/1670 train_time:55350ms step_avg:96.94ms +step:572/1670 train_time:55447ms step_avg:96.94ms +step:573/1670 train_time:55543ms step_avg:96.93ms +step:574/1670 train_time:55639ms step_avg:96.93ms +step:575/1670 train_time:55738ms step_avg:96.93ms +step:576/1670 train_time:55836ms step_avg:96.94ms +step:577/1670 train_time:55933ms step_avg:96.94ms +step:578/1670 train_time:56030ms step_avg:96.94ms +step:579/1670 train_time:56128ms step_avg:96.94ms +step:580/1670 train_time:56225ms step_avg:96.94ms +step:581/1670 train_time:56322ms step_avg:96.94ms +step:582/1670 train_time:56418ms step_avg:96.94ms +step:583/1670 train_time:56515ms step_avg:96.94ms +step:584/1670 train_time:56612ms step_avg:96.94ms +step:585/1670 train_time:56710ms step_avg:96.94ms +step:586/1670 train_time:56808ms step_avg:96.94ms +step:587/1670 train_time:56905ms step_avg:96.94ms +step:588/1670 train_time:57001ms step_avg:96.94ms +step:589/1670 train_time:57098ms step_avg:96.94ms +step:590/1670 train_time:57195ms step_avg:96.94ms +step:591/1670 train_time:57293ms step_avg:96.94ms +step:592/1670 train_time:57391ms step_avg:96.94ms +step:593/1670 train_time:57489ms step_avg:96.95ms +step:594/1670 train_time:57586ms step_avg:96.95ms +step:595/1670 train_time:57682ms step_avg:96.94ms +step:596/1670 train_time:57779ms step_avg:96.94ms +step:597/1670 train_time:57876ms step_avg:96.95ms +step:598/1670 train_time:57974ms step_avg:96.95ms +step:599/1670 train_time:58072ms step_avg:96.95ms +step:600/1670 train_time:58169ms step_avg:96.95ms +step:601/1670 train_time:58266ms step_avg:96.95ms +step:602/1670 train_time:58362ms step_avg:96.95ms +step:603/1670 train_time:58459ms step_avg:96.95ms +step:604/1670 train_time:58556ms step_avg:96.95ms +step:605/1670 train_time:58654ms step_avg:96.95ms +step:606/1670 train_time:58752ms step_avg:96.95ms +step:607/1670 train_time:58850ms step_avg:96.95ms +step:608/1670 train_time:58947ms step_avg:96.95ms +step:609/1670 train_time:59043ms step_avg:96.95ms +step:610/1670 train_time:59140ms step_avg:96.95ms +step:611/1670 train_time:59237ms step_avg:96.95ms +step:612/1670 train_time:59335ms step_avg:96.95ms +step:613/1670 train_time:59432ms step_avg:96.95ms +step:614/1670 train_time:59530ms step_avg:96.95ms +step:615/1670 train_time:59627ms step_avg:96.95ms +step:616/1670 train_time:59723ms step_avg:96.95ms +step:617/1670 train_time:59820ms step_avg:96.95ms +step:618/1670 train_time:59918ms step_avg:96.95ms +step:619/1670 train_time:60016ms step_avg:96.96ms +step:620/1670 train_time:60114ms step_avg:96.96ms +step:621/1670 train_time:60212ms step_avg:96.96ms +step:622/1670 train_time:60309ms step_avg:96.96ms +step:623/1670 train_time:60405ms step_avg:96.96ms +step:624/1670 train_time:60502ms step_avg:96.96ms +step:625/1670 train_time:60598ms step_avg:96.96ms +step:625/1670 val_loss:3.6145 train_time:60696ms step_avg:97.11ms +step:626/1670 train_time:60718ms step_avg:96.99ms +step:627/1670 train_time:60807ms step_avg:96.98ms +step:628/1670 train_time:60904ms step_avg:96.98ms +step:629/1670 train_time:61000ms step_avg:96.98ms +step:630/1670 train_time:61096ms step_avg:96.98ms +step:631/1670 train_time:61192ms step_avg:96.98ms +step:632/1670 train_time:61288ms 
step_avg:96.97ms +step:633/1670 train_time:61383ms step_avg:96.97ms +step:634/1670 train_time:61478ms step_avg:96.97ms +step:635/1670 train_time:61575ms step_avg:96.97ms +step:636/1670 train_time:61674ms step_avg:96.97ms +step:637/1670 train_time:61778ms step_avg:96.98ms +step:638/1670 train_time:61878ms step_avg:96.99ms +step:639/1670 train_time:62212ms step_avg:97.36ms +step:640/1670 train_time:62370ms step_avg:97.45ms +step:641/1670 train_time:62465ms step_avg:97.45ms +step:642/1670 train_time:62561ms step_avg:97.45ms +step:643/1670 train_time:62657ms step_avg:97.44ms +step:644/1670 train_time:62753ms step_avg:97.44ms +step:645/1670 train_time:62848ms step_avg:97.44ms +step:646/1670 train_time:62944ms step_avg:97.44ms +step:647/1670 train_time:63040ms step_avg:97.43ms +step:648/1670 train_time:63137ms step_avg:97.43ms +step:649/1670 train_time:63236ms step_avg:97.44ms +step:650/1670 train_time:63337ms step_avg:97.44ms +step:651/1670 train_time:63438ms step_avg:97.45ms +step:652/1670 train_time:63536ms step_avg:97.45ms +step:653/1670 train_time:63633ms step_avg:97.45ms +step:654/1670 train_time:63729ms step_avg:97.45ms +step:655/1670 train_time:63825ms step_avg:97.44ms +step:656/1670 train_time:63921ms step_avg:97.44ms +step:657/1670 train_time:64017ms step_avg:97.44ms +step:658/1670 train_time:64114ms step_avg:97.44ms +step:659/1670 train_time:64212ms step_avg:97.44ms +step:660/1670 train_time:64311ms step_avg:97.44ms +step:661/1670 train_time:64409ms step_avg:97.44ms +step:662/1670 train_time:64506ms step_avg:97.44ms +step:663/1670 train_time:64603ms step_avg:97.44ms +step:664/1670 train_time:64700ms step_avg:97.44ms +step:665/1670 train_time:64797ms step_avg:97.44ms +step:666/1670 train_time:64894ms step_avg:97.44ms +step:667/1670 train_time:64990ms step_avg:97.44ms +step:668/1670 train_time:65087ms step_avg:97.44ms +step:669/1670 train_time:65184ms step_avg:97.43ms +step:670/1670 train_time:65280ms step_avg:97.43ms +step:671/1670 train_time:65378ms step_avg:97.43ms +step:672/1670 train_time:65478ms step_avg:97.44ms +step:673/1670 train_time:65576ms step_avg:97.44ms +step:674/1670 train_time:65673ms step_avg:97.44ms +step:675/1670 train_time:65770ms step_avg:97.44ms +step:676/1670 train_time:65866ms step_avg:97.44ms +step:677/1670 train_time:65962ms step_avg:97.43ms +step:678/1670 train_time:66059ms step_avg:97.43ms +step:679/1670 train_time:66156ms step_avg:97.43ms +step:680/1670 train_time:66253ms step_avg:97.43ms +step:681/1670 train_time:66350ms step_avg:97.43ms +step:682/1670 train_time:66447ms step_avg:97.43ms +step:683/1670 train_time:66543ms step_avg:97.43ms +step:684/1670 train_time:66640ms step_avg:97.43ms +step:685/1670 train_time:66738ms step_avg:97.43ms +step:686/1670 train_time:66835ms step_avg:97.43ms +step:687/1670 train_time:66932ms step_avg:97.43ms +step:688/1670 train_time:67029ms step_avg:97.43ms +step:689/1670 train_time:67125ms step_avg:97.42ms +step:690/1670 train_time:67221ms step_avg:97.42ms +step:691/1670 train_time:67319ms step_avg:97.42ms +step:692/1670 train_time:67416ms step_avg:97.42ms +step:693/1670 train_time:67514ms step_avg:97.42ms +step:694/1670 train_time:67612ms step_avg:97.42ms +step:695/1670 train_time:67709ms step_avg:97.42ms +step:696/1670 train_time:67806ms step_avg:97.42ms +step:697/1670 train_time:67903ms step_avg:97.42ms +step:698/1670 train_time:68000ms step_avg:97.42ms +step:699/1670 train_time:68098ms step_avg:97.42ms +step:700/1670 train_time:68196ms step_avg:97.42ms +step:701/1670 train_time:68292ms step_avg:97.42ms +step:702/1670 
train_time:68389ms step_avg:97.42ms +step:703/1670 train_time:68488ms step_avg:97.42ms +step:704/1670 train_time:68584ms step_avg:97.42ms +step:705/1670 train_time:68681ms step_avg:97.42ms +step:706/1670 train_time:68778ms step_avg:97.42ms +step:707/1670 train_time:68875ms step_avg:97.42ms +step:708/1670 train_time:68972ms step_avg:97.42ms +step:709/1670 train_time:69069ms step_avg:97.42ms +step:710/1670 train_time:69167ms step_avg:97.42ms +step:711/1670 train_time:69265ms step_avg:97.42ms +step:712/1670 train_time:69361ms step_avg:97.42ms +step:713/1670 train_time:69459ms step_avg:97.42ms +step:714/1670 train_time:69556ms step_avg:97.42ms +step:715/1670 train_time:69654ms step_avg:97.42ms +step:716/1670 train_time:69751ms step_avg:97.42ms +step:717/1670 train_time:69847ms step_avg:97.42ms +step:718/1670 train_time:69943ms step_avg:97.41ms +step:719/1670 train_time:70040ms step_avg:97.41ms +step:720/1670 train_time:70138ms step_avg:97.41ms +step:721/1670 train_time:70235ms step_avg:97.41ms +step:722/1670 train_time:70333ms step_avg:97.41ms +step:723/1670 train_time:70431ms step_avg:97.42ms +step:724/1670 train_time:70528ms step_avg:97.42ms +step:725/1670 train_time:70626ms step_avg:97.41ms +step:726/1670 train_time:70722ms step_avg:97.41ms +step:727/1670 train_time:70820ms step_avg:97.41ms +step:728/1670 train_time:70918ms step_avg:97.41ms +step:729/1670 train_time:71016ms step_avg:97.42ms +step:730/1670 train_time:71113ms step_avg:97.41ms +step:731/1670 train_time:71211ms step_avg:97.42ms +step:732/1670 train_time:71307ms step_avg:97.41ms +step:733/1670 train_time:71404ms step_avg:97.41ms +step:734/1670 train_time:71502ms step_avg:97.41ms +step:735/1670 train_time:71600ms step_avg:97.41ms +step:736/1670 train_time:71697ms step_avg:97.41ms +step:737/1670 train_time:71795ms step_avg:97.41ms +step:738/1670 train_time:71891ms step_avg:97.41ms +step:739/1670 train_time:71989ms step_avg:97.41ms +step:740/1670 train_time:72085ms step_avg:97.41ms +step:741/1670 train_time:72181ms step_avg:97.41ms +step:742/1670 train_time:72279ms step_avg:97.41ms +step:743/1670 train_time:72377ms step_avg:97.41ms +step:744/1670 train_time:72474ms step_avg:97.41ms +step:745/1670 train_time:72572ms step_avg:97.41ms +step:746/1670 train_time:72668ms step_avg:97.41ms +step:747/1670 train_time:72764ms step_avg:97.41ms +step:748/1670 train_time:72860ms step_avg:97.41ms +step:749/1670 train_time:72958ms step_avg:97.41ms +step:750/1670 train_time:73057ms step_avg:97.41ms +step:750/1670 val_loss:3.5614 train_time:73154ms step_avg:97.54ms +step:751/1670 train_time:73175ms step_avg:97.44ms +step:752/1670 train_time:73257ms step_avg:97.42ms +step:753/1670 train_time:73358ms step_avg:97.42ms +step:754/1670 train_time:73456ms step_avg:97.42ms +step:755/1670 train_time:73553ms step_avg:97.42ms +step:756/1670 train_time:73650ms step_avg:97.42ms +step:757/1670 train_time:73746ms step_avg:97.42ms +step:758/1670 train_time:73842ms step_avg:97.42ms +step:759/1670 train_time:73938ms step_avg:97.42ms +step:760/1670 train_time:74035ms step_avg:97.41ms +step:761/1670 train_time:74133ms step_avg:97.41ms +step:762/1670 train_time:74232ms step_avg:97.42ms +step:763/1670 train_time:74332ms step_avg:97.42ms +step:764/1670 train_time:74429ms step_avg:97.42ms +step:765/1670 train_time:74526ms step_avg:97.42ms +step:766/1670 train_time:74623ms step_avg:97.42ms +step:767/1670 train_time:74719ms step_avg:97.42ms +step:768/1670 train_time:74816ms step_avg:97.42ms +step:769/1670 train_time:74913ms step_avg:97.42ms +step:770/1670 train_time:75009ms 
step_avg:97.41ms +step:771/1670 train_time:75106ms step_avg:97.41ms +step:772/1670 train_time:75204ms step_avg:97.41ms +step:773/1670 train_time:75301ms step_avg:97.41ms +step:774/1670 train_time:75398ms step_avg:97.41ms +step:775/1670 train_time:75497ms step_avg:97.42ms +step:776/1670 train_time:75595ms step_avg:97.42ms +step:777/1670 train_time:75693ms step_avg:97.42ms +step:778/1670 train_time:75790ms step_avg:97.42ms +step:779/1670 train_time:75886ms step_avg:97.41ms +step:780/1670 train_time:75982ms step_avg:97.41ms +step:781/1670 train_time:76078ms step_avg:97.41ms +step:782/1670 train_time:76177ms step_avg:97.41ms +step:783/1670 train_time:76275ms step_avg:97.41ms +step:784/1670 train_time:76372ms step_avg:97.41ms +step:785/1670 train_time:76470ms step_avg:97.41ms +step:786/1670 train_time:76567ms step_avg:97.41ms +step:787/1670 train_time:76664ms step_avg:97.41ms +step:788/1670 train_time:76761ms step_avg:97.41ms +step:789/1670 train_time:76858ms step_avg:97.41ms +step:790/1670 train_time:76955ms step_avg:97.41ms +step:791/1670 train_time:77052ms step_avg:97.41ms +step:792/1670 train_time:77149ms step_avg:97.41ms +step:793/1670 train_time:77246ms step_avg:97.41ms +step:794/1670 train_time:77343ms step_avg:97.41ms +step:795/1670 train_time:77440ms step_avg:97.41ms +step:796/1670 train_time:77539ms step_avg:97.41ms +step:797/1670 train_time:77637ms step_avg:97.41ms +step:798/1670 train_time:77734ms step_avg:97.41ms +step:799/1670 train_time:77831ms step_avg:97.41ms +step:800/1670 train_time:77929ms step_avg:97.41ms +step:801/1670 train_time:78026ms step_avg:97.41ms +step:802/1670 train_time:78122ms step_avg:97.41ms +step:803/1670 train_time:78219ms step_avg:97.41ms +step:804/1670 train_time:78316ms step_avg:97.41ms +step:805/1670 train_time:78413ms step_avg:97.41ms +step:806/1670 train_time:78510ms step_avg:97.41ms +step:807/1670 train_time:78608ms step_avg:97.41ms +step:808/1670 train_time:78704ms step_avg:97.41ms +step:809/1670 train_time:78801ms step_avg:97.40ms +step:810/1670 train_time:78898ms step_avg:97.40ms +step:811/1670 train_time:78995ms step_avg:97.40ms +step:812/1670 train_time:79093ms step_avg:97.40ms +step:813/1670 train_time:79190ms step_avg:97.40ms +step:814/1670 train_time:79287ms step_avg:97.40ms +step:815/1670 train_time:79384ms step_avg:97.40ms +step:816/1670 train_time:79480ms step_avg:97.40ms +step:817/1670 train_time:79578ms step_avg:97.40ms +step:818/1670 train_time:79677ms step_avg:97.40ms +step:819/1670 train_time:79775ms step_avg:97.41ms +step:820/1670 train_time:79872ms step_avg:97.41ms +step:821/1670 train_time:79970ms step_avg:97.41ms +step:822/1670 train_time:80067ms step_avg:97.40ms +step:823/1670 train_time:80163ms step_avg:97.40ms +step:824/1670 train_time:80259ms step_avg:97.40ms +step:825/1670 train_time:80356ms step_avg:97.40ms +step:826/1670 train_time:80454ms step_avg:97.40ms +step:827/1670 train_time:80551ms step_avg:97.40ms +step:828/1670 train_time:80649ms step_avg:97.40ms +step:829/1670 train_time:80747ms step_avg:97.40ms +step:830/1670 train_time:80843ms step_avg:97.40ms +step:831/1670 train_time:80940ms step_avg:97.40ms +step:832/1670 train_time:81038ms step_avg:97.40ms +step:833/1670 train_time:81136ms step_avg:97.40ms +step:834/1670 train_time:81233ms step_avg:97.40ms +step:835/1670 train_time:81330ms step_avg:97.40ms +step:836/1670 train_time:81427ms step_avg:97.40ms +step:837/1670 train_time:81524ms step_avg:97.40ms +step:838/1670 train_time:81621ms step_avg:97.40ms +step:839/1670 train_time:81718ms step_avg:97.40ms +step:840/1670 
train_time:81817ms step_avg:97.40ms +step:841/1670 train_time:81914ms step_avg:97.40ms +step:842/1670 train_time:82011ms step_avg:97.40ms +step:843/1670 train_time:82109ms step_avg:97.40ms +step:844/1670 train_time:82205ms step_avg:97.40ms +step:845/1670 train_time:82301ms step_avg:97.40ms +step:846/1670 train_time:82398ms step_avg:97.40ms +step:847/1670 train_time:82496ms step_avg:97.40ms +step:848/1670 train_time:82594ms step_avg:97.40ms +step:849/1670 train_time:82690ms step_avg:97.40ms +step:850/1670 train_time:82787ms step_avg:97.40ms +step:851/1670 train_time:83055ms step_avg:97.60ms +step:852/1670 train_time:83197ms step_avg:97.65ms +step:853/1670 train_time:83292ms step_avg:97.65ms +step:854/1670 train_time:83388ms step_avg:97.64ms +step:855/1670 train_time:83483ms step_avg:97.64ms +step:856/1670 train_time:83579ms step_avg:97.64ms +step:857/1670 train_time:83674ms step_avg:97.64ms +step:858/1670 train_time:83770ms step_avg:97.63ms +step:859/1670 train_time:83865ms step_avg:97.63ms +step:860/1670 train_time:83961ms step_avg:97.63ms +step:861/1670 train_time:84058ms step_avg:97.63ms +step:862/1670 train_time:84161ms step_avg:97.63ms +step:863/1670 train_time:84262ms step_avg:97.64ms +step:864/1670 train_time:84358ms step_avg:97.64ms +step:865/1670 train_time:84455ms step_avg:97.64ms +step:866/1670 train_time:84551ms step_avg:97.63ms +step:867/1670 train_time:84647ms step_avg:97.63ms +step:868/1670 train_time:84743ms step_avg:97.63ms +step:869/1670 train_time:84838ms step_avg:97.63ms +step:870/1670 train_time:84935ms step_avg:97.63ms +step:871/1670 train_time:85032ms step_avg:97.63ms +step:872/1670 train_time:85131ms step_avg:97.63ms +step:873/1670 train_time:85231ms step_avg:97.63ms +step:874/1670 train_time:85329ms step_avg:97.63ms +step:875/1670 train_time:85425ms step_avg:97.63ms +step:875/1670 val_loss:3.5194 train_time:85521ms step_avg:97.74ms +step:876/1670 train_time:85542ms step_avg:97.65ms +step:877/1670 train_time:85623ms step_avg:97.63ms +step:878/1670 train_time:85720ms step_avg:97.63ms +step:879/1670 train_time:85817ms step_avg:97.63ms +step:880/1670 train_time:85913ms step_avg:97.63ms +step:881/1670 train_time:86009ms step_avg:97.63ms +step:882/1670 train_time:86104ms step_avg:97.62ms +step:883/1670 train_time:86200ms step_avg:97.62ms +step:884/1670 train_time:86296ms step_avg:97.62ms +step:885/1670 train_time:86392ms step_avg:97.62ms +step:886/1670 train_time:86493ms step_avg:97.62ms +step:887/1670 train_time:86593ms step_avg:97.62ms +step:888/1670 train_time:86694ms step_avg:97.63ms +step:889/1670 train_time:86791ms step_avg:97.63ms +step:890/1670 train_time:86889ms step_avg:97.63ms +step:891/1670 train_time:86986ms step_avg:97.63ms +step:892/1670 train_time:87081ms step_avg:97.62ms +step:893/1670 train_time:87177ms step_avg:97.62ms +step:894/1670 train_time:87273ms step_avg:97.62ms +step:895/1670 train_time:87370ms step_avg:97.62ms +step:896/1670 train_time:87468ms step_avg:97.62ms +step:897/1670 train_time:87567ms step_avg:97.62ms +step:898/1670 train_time:87665ms step_avg:97.62ms +step:899/1670 train_time:87762ms step_avg:97.62ms +step:900/1670 train_time:87859ms step_avg:97.62ms +step:901/1670 train_time:87957ms step_avg:97.62ms +step:902/1670 train_time:88054ms step_avg:97.62ms +step:903/1670 train_time:88151ms step_avg:97.62ms +step:904/1670 train_time:88248ms step_avg:97.62ms +step:905/1670 train_time:88344ms step_avg:97.62ms +step:906/1670 train_time:88441ms step_avg:97.62ms +step:907/1670 train_time:88538ms step_avg:97.62ms +step:908/1670 train_time:88635ms 
step_avg:97.62ms +step:909/1670 train_time:88733ms step_avg:97.62ms +step:910/1670 train_time:88831ms step_avg:97.62ms +step:911/1670 train_time:88929ms step_avg:97.62ms +step:912/1670 train_time:89026ms step_avg:97.62ms +step:913/1670 train_time:89122ms step_avg:97.61ms +step:914/1670 train_time:89218ms step_avg:97.61ms +step:915/1670 train_time:89315ms step_avg:97.61ms +step:916/1670 train_time:89412ms step_avg:97.61ms +step:917/1670 train_time:89510ms step_avg:97.61ms +step:918/1670 train_time:89607ms step_avg:97.61ms +step:919/1670 train_time:89704ms step_avg:97.61ms +step:920/1670 train_time:89801ms step_avg:97.61ms +step:921/1670 train_time:89898ms step_avg:97.61ms +step:922/1670 train_time:89995ms step_avg:97.61ms +step:923/1670 train_time:90093ms step_avg:97.61ms +step:924/1670 train_time:90190ms step_avg:97.61ms +step:925/1670 train_time:90287ms step_avg:97.61ms +step:926/1670 train_time:90384ms step_avg:97.61ms +step:927/1670 train_time:90480ms step_avg:97.61ms +step:928/1670 train_time:90577ms step_avg:97.61ms +step:929/1670 train_time:90675ms step_avg:97.60ms +step:930/1670 train_time:90773ms step_avg:97.61ms +step:931/1670 train_time:90871ms step_avg:97.61ms +step:932/1670 train_time:90967ms step_avg:97.60ms +step:933/1670 train_time:91064ms step_avg:97.60ms +step:934/1670 train_time:91160ms step_avg:97.60ms +step:935/1670 train_time:91256ms step_avg:97.60ms +step:936/1670 train_time:91356ms step_avg:97.60ms +step:937/1670 train_time:91454ms step_avg:97.60ms +step:938/1670 train_time:91552ms step_avg:97.60ms +step:939/1670 train_time:91649ms step_avg:97.60ms +step:940/1670 train_time:91746ms step_avg:97.60ms +step:941/1670 train_time:91842ms step_avg:97.60ms +step:942/1670 train_time:91940ms step_avg:97.60ms +step:943/1670 train_time:92037ms step_avg:97.60ms +step:944/1670 train_time:92134ms step_avg:97.60ms +step:945/1670 train_time:92231ms step_avg:97.60ms +step:946/1670 train_time:92329ms step_avg:97.60ms +step:947/1670 train_time:92426ms step_avg:97.60ms +step:948/1670 train_time:92522ms step_avg:97.60ms +step:949/1670 train_time:92619ms step_avg:97.60ms +step:950/1670 train_time:92716ms step_avg:97.60ms +step:951/1670 train_time:92814ms step_avg:97.60ms +step:952/1670 train_time:92912ms step_avg:97.60ms +step:953/1670 train_time:93009ms step_avg:97.60ms +step:954/1670 train_time:93107ms step_avg:97.60ms +step:955/1670 train_time:93203ms step_avg:97.60ms +step:956/1670 train_time:93300ms step_avg:97.59ms +step:957/1670 train_time:93396ms step_avg:97.59ms +step:958/1670 train_time:93494ms step_avg:97.59ms +step:959/1670 train_time:93592ms step_avg:97.59ms +step:960/1670 train_time:93691ms step_avg:97.59ms +step:961/1670 train_time:93789ms step_avg:97.59ms +step:962/1670 train_time:93886ms step_avg:97.59ms +step:963/1670 train_time:93982ms step_avg:97.59ms +step:964/1670 train_time:94079ms step_avg:97.59ms +step:965/1670 train_time:94177ms step_avg:97.59ms +step:966/1670 train_time:94273ms step_avg:97.59ms +step:967/1670 train_time:94370ms step_avg:97.59ms +step:968/1670 train_time:94468ms step_avg:97.59ms +step:969/1670 train_time:94565ms step_avg:97.59ms +step:970/1670 train_time:94662ms step_avg:97.59ms +step:971/1670 train_time:94759ms step_avg:97.59ms +step:972/1670 train_time:94857ms step_avg:97.59ms +step:973/1670 train_time:94955ms step_avg:97.59ms +step:974/1670 train_time:95054ms step_avg:97.59ms +step:975/1670 train_time:95151ms step_avg:97.59ms +step:976/1670 train_time:95249ms step_avg:97.59ms +step:977/1670 train_time:95345ms step_avg:97.59ms +step:978/1670 
train_time:95442ms step_avg:97.59ms +step:979/1670 train_time:95539ms step_avg:97.59ms +step:980/1670 train_time:95636ms step_avg:97.59ms +step:981/1670 train_time:95734ms step_avg:97.59ms +step:982/1670 train_time:95833ms step_avg:97.59ms +step:983/1670 train_time:95931ms step_avg:97.59ms +step:984/1670 train_time:96029ms step_avg:97.59ms +step:985/1670 train_time:96126ms step_avg:97.59ms +step:986/1670 train_time:96222ms step_avg:97.59ms +step:987/1670 train_time:96319ms step_avg:97.59ms +step:988/1670 train_time:96416ms step_avg:97.59ms +step:989/1670 train_time:96514ms step_avg:97.59ms +step:990/1670 train_time:96611ms step_avg:97.59ms +step:991/1670 train_time:96708ms step_avg:97.59ms +step:992/1670 train_time:96805ms step_avg:97.59ms +step:993/1670 train_time:96901ms step_avg:97.58ms +step:994/1670 train_time:96998ms step_avg:97.58ms +step:995/1670 train_time:97096ms step_avg:97.58ms +step:996/1670 train_time:97195ms step_avg:97.59ms +step:997/1670 train_time:97292ms step_avg:97.58ms +step:998/1670 train_time:97390ms step_avg:97.59ms +step:999/1670 train_time:97488ms step_avg:97.59ms +step:1000/1670 train_time:97585ms step_avg:97.58ms +step:1000/1670 val_loss:3.4778 train_time:97680ms step_avg:97.68ms +step:1001/1670 train_time:97701ms step_avg:97.60ms +step:1002/1670 train_time:97783ms step_avg:97.59ms +step:1003/1670 train_time:97882ms step_avg:97.59ms +step:1004/1670 train_time:97980ms step_avg:97.59ms +step:1005/1670 train_time:98076ms step_avg:97.59ms +step:1006/1670 train_time:98172ms step_avg:97.59ms +step:1007/1670 train_time:98268ms step_avg:97.58ms +step:1008/1670 train_time:98364ms step_avg:97.58ms +step:1009/1670 train_time:98462ms step_avg:97.58ms +step:1010/1670 train_time:98559ms step_avg:97.58ms +step:1011/1670 train_time:98658ms step_avg:97.58ms +step:1012/1670 train_time:98757ms step_avg:97.59ms +step:1013/1670 train_time:98856ms step_avg:97.59ms +step:1014/1670 train_time:98954ms step_avg:97.59ms +step:1015/1670 train_time:99050ms step_avg:97.59ms +step:1016/1670 train_time:99146ms step_avg:97.58ms +step:1017/1670 train_time:99242ms step_avg:97.58ms +step:1018/1670 train_time:99339ms step_avg:97.58ms +step:1019/1670 train_time:99435ms step_avg:97.58ms +step:1020/1670 train_time:99531ms step_avg:97.58ms +step:1021/1670 train_time:99628ms step_avg:97.58ms +step:1022/1670 train_time:99725ms step_avg:97.58ms +step:1023/1670 train_time:99823ms step_avg:97.58ms +step:1024/1670 train_time:99922ms step_avg:97.58ms +step:1025/1670 train_time:100020ms step_avg:97.58ms +step:1026/1670 train_time:100117ms step_avg:97.58ms +step:1027/1670 train_time:100214ms step_avg:97.58ms +step:1028/1670 train_time:100311ms step_avg:97.58ms +step:1029/1670 train_time:100406ms step_avg:97.58ms +step:1030/1670 train_time:100503ms step_avg:97.58ms +step:1031/1670 train_time:100601ms step_avg:97.58ms +step:1032/1670 train_time:100698ms step_avg:97.58ms +step:1033/1670 train_time:100796ms step_avg:97.58ms +step:1034/1670 train_time:100894ms step_avg:97.58ms +step:1035/1670 train_time:100991ms step_avg:97.58ms +step:1036/1670 train_time:101088ms step_avg:97.58ms +step:1037/1670 train_time:101186ms step_avg:97.58ms +step:1038/1670 train_time:101282ms step_avg:97.57ms +step:1039/1670 train_time:101380ms step_avg:97.57ms +step:1040/1670 train_time:101477ms step_avg:97.57ms +step:1041/1670 train_time:101574ms step_avg:97.57ms +step:1042/1670 train_time:101671ms step_avg:97.57ms +step:1043/1670 train_time:101767ms step_avg:97.57ms +step:1044/1670 train_time:101865ms step_avg:97.57ms +step:1045/1670 
train_time:101964ms step_avg:97.57ms +step:1046/1670 train_time:102062ms step_avg:97.57ms +step:1047/1670 train_time:102160ms step_avg:97.57ms +step:1048/1670 train_time:102257ms step_avg:97.57ms +step:1049/1670 train_time:102353ms step_avg:97.57ms +step:1050/1670 train_time:102451ms step_avg:97.57ms +step:1051/1670 train_time:102547ms step_avg:97.57ms +step:1052/1670 train_time:102643ms step_avg:97.57ms +step:1053/1670 train_time:102741ms step_avg:97.57ms +step:1054/1670 train_time:102838ms step_avg:97.57ms +step:1055/1670 train_time:102936ms step_avg:97.57ms +step:1056/1670 train_time:103034ms step_avg:97.57ms +step:1057/1670 train_time:103132ms step_avg:97.57ms +step:1058/1670 train_time:103228ms step_avg:97.57ms +step:1059/1670 train_time:103325ms step_avg:97.57ms +step:1060/1670 train_time:103423ms step_avg:97.57ms +step:1061/1670 train_time:103520ms step_avg:97.57ms +step:1062/1670 train_time:103794ms step_avg:97.73ms +step:1063/1670 train_time:103894ms step_avg:97.74ms +step:1064/1670 train_time:103989ms step_avg:97.73ms +step:1065/1670 train_time:104086ms step_avg:97.73ms +step:1066/1670 train_time:104182ms step_avg:97.73ms +step:1067/1670 train_time:104278ms step_avg:97.73ms +step:1068/1670 train_time:104374ms step_avg:97.73ms +step:1069/1670 train_time:104470ms step_avg:97.73ms +step:1070/1670 train_time:104566ms step_avg:97.73ms +step:1071/1670 train_time:104662ms step_avg:97.72ms +step:1072/1670 train_time:104765ms step_avg:97.73ms +step:1073/1670 train_time:104865ms step_avg:97.73ms +step:1074/1670 train_time:104964ms step_avg:97.73ms +step:1075/1670 train_time:105062ms step_avg:97.73ms +step:1076/1670 train_time:105160ms step_avg:97.73ms +step:1077/1670 train_time:105256ms step_avg:97.73ms +step:1078/1670 train_time:105353ms step_avg:97.73ms +step:1079/1670 train_time:105449ms step_avg:97.73ms +step:1080/1670 train_time:105545ms step_avg:97.73ms +step:1081/1670 train_time:105643ms step_avg:97.73ms +step:1082/1670 train_time:105743ms step_avg:97.73ms +step:1083/1670 train_time:105843ms step_avg:97.73ms +step:1084/1670 train_time:105941ms step_avg:97.73ms +step:1085/1670 train_time:106039ms step_avg:97.73ms +step:1086/1670 train_time:106135ms step_avg:97.73ms +step:1087/1670 train_time:106231ms step_avg:97.73ms +step:1088/1670 train_time:106327ms step_avg:97.73ms +step:1089/1670 train_time:106423ms step_avg:97.73ms +step:1090/1670 train_time:106520ms step_avg:97.73ms +step:1091/1670 train_time:106619ms step_avg:97.73ms +step:1092/1670 train_time:106718ms step_avg:97.73ms +step:1093/1670 train_time:106816ms step_avg:97.73ms +step:1094/1670 train_time:106915ms step_avg:97.73ms +step:1095/1670 train_time:107014ms step_avg:97.73ms +step:1096/1670 train_time:107110ms step_avg:97.73ms +step:1097/1670 train_time:107206ms step_avg:97.73ms +step:1098/1670 train_time:107302ms step_avg:97.73ms +step:1099/1670 train_time:107399ms step_avg:97.72ms +step:1100/1670 train_time:107495ms step_avg:97.72ms +step:1101/1670 train_time:107592ms step_avg:97.72ms +step:1102/1670 train_time:107689ms step_avg:97.72ms +step:1103/1670 train_time:107786ms step_avg:97.72ms +step:1104/1670 train_time:107884ms step_avg:97.72ms +step:1105/1670 train_time:107982ms step_avg:97.72ms +step:1106/1670 train_time:108080ms step_avg:97.72ms +step:1107/1670 train_time:108178ms step_avg:97.72ms +step:1108/1670 train_time:108274ms step_avg:97.72ms +step:1109/1670 train_time:108371ms step_avg:97.72ms +step:1110/1670 train_time:108466ms step_avg:97.72ms +step:1111/1670 train_time:108563ms step_avg:97.72ms +step:1112/1670 
train_time:108661ms step_avg:97.72ms +step:1113/1670 train_time:108759ms step_avg:97.72ms +step:1114/1670 train_time:108857ms step_avg:97.72ms +step:1115/1670 train_time:108955ms step_avg:97.72ms +step:1116/1670 train_time:109054ms step_avg:97.72ms +step:1117/1670 train_time:109151ms step_avg:97.72ms +step:1118/1670 train_time:109249ms step_avg:97.72ms +step:1119/1670 train_time:109346ms step_avg:97.72ms +step:1120/1670 train_time:109443ms step_avg:97.72ms +step:1121/1670 train_time:109542ms step_avg:97.72ms +step:1122/1670 train_time:109640ms step_avg:97.72ms +step:1123/1670 train_time:109738ms step_avg:97.72ms +step:1124/1670 train_time:109837ms step_avg:97.72ms +step:1125/1670 train_time:109935ms step_avg:97.72ms +step:1125/1670 val_loss:3.4238 train_time:110032ms step_avg:97.81ms +step:1126/1670 train_time:110054ms step_avg:97.74ms +step:1127/1670 train_time:110142ms step_avg:97.73ms +step:1128/1670 train_time:110240ms step_avg:97.73ms +step:1129/1670 train_time:110337ms step_avg:97.73ms +step:1130/1670 train_time:110433ms step_avg:97.73ms +step:1131/1670 train_time:110529ms step_avg:97.73ms +step:1132/1670 train_time:110625ms step_avg:97.73ms +step:1133/1670 train_time:110722ms step_avg:97.72ms +step:1134/1670 train_time:110819ms step_avg:97.72ms +step:1135/1670 train_time:110918ms step_avg:97.73ms +step:1136/1670 train_time:111022ms step_avg:97.73ms +step:1137/1670 train_time:111122ms step_avg:97.73ms +step:1138/1670 train_time:111222ms step_avg:97.73ms +step:1139/1670 train_time:111321ms step_avg:97.74ms +step:1140/1670 train_time:111420ms step_avg:97.74ms +step:1141/1670 train_time:111517ms step_avg:97.74ms +step:1142/1670 train_time:111614ms step_avg:97.74ms +step:1143/1670 train_time:111712ms step_avg:97.74ms +step:1144/1670 train_time:111808ms step_avg:97.73ms +step:1145/1670 train_time:111905ms step_avg:97.73ms +step:1146/1670 train_time:112006ms step_avg:97.74ms +step:1147/1670 train_time:112104ms step_avg:97.74ms +step:1148/1670 train_time:112202ms step_avg:97.74ms +step:1149/1670 train_time:112302ms step_avg:97.74ms +step:1150/1670 train_time:112400ms step_avg:97.74ms +step:1151/1670 train_time:112498ms step_avg:97.74ms +step:1152/1670 train_time:112595ms step_avg:97.74ms +step:1153/1670 train_time:112691ms step_avg:97.74ms +step:1154/1670 train_time:112788ms step_avg:97.74ms +step:1155/1670 train_time:112885ms step_avg:97.74ms +step:1156/1670 train_time:112985ms step_avg:97.74ms +step:1157/1670 train_time:113085ms step_avg:97.74ms +step:1158/1670 train_time:113183ms step_avg:97.74ms +step:1159/1670 train_time:113282ms step_avg:97.74ms +step:1160/1670 train_time:113380ms step_avg:97.74ms +step:1161/1670 train_time:113479ms step_avg:97.74ms +step:1162/1670 train_time:113576ms step_avg:97.74ms +step:1163/1670 train_time:113673ms step_avg:97.74ms +step:1164/1670 train_time:113770ms step_avg:97.74ms +step:1165/1670 train_time:113867ms step_avg:97.74ms +step:1166/1670 train_time:113964ms step_avg:97.74ms +step:1167/1670 train_time:114062ms step_avg:97.74ms +step:1168/1670 train_time:114162ms step_avg:97.74ms +step:1169/1670 train_time:114261ms step_avg:97.74ms +step:1170/1670 train_time:114359ms step_avg:97.74ms +step:1171/1670 train_time:114457ms step_avg:97.74ms +step:1172/1670 train_time:114555ms step_avg:97.74ms +step:1173/1670 train_time:114652ms step_avg:97.74ms +step:1174/1670 train_time:114750ms step_avg:97.74ms +step:1175/1670 train_time:114847ms step_avg:97.74ms +step:1176/1670 train_time:114944ms step_avg:97.74ms +step:1177/1670 train_time:115042ms step_avg:97.74ms 
+step:1178/1670 train_time:115141ms step_avg:97.74ms +step:1179/1670 train_time:115240ms step_avg:97.74ms +step:1180/1670 train_time:115339ms step_avg:97.74ms +step:1181/1670 train_time:115437ms step_avg:97.75ms +step:1182/1670 train_time:115535ms step_avg:97.75ms +step:1183/1670 train_time:115633ms step_avg:97.75ms +step:1184/1670 train_time:115730ms step_avg:97.75ms +step:1185/1670 train_time:115827ms step_avg:97.74ms +step:1186/1670 train_time:115925ms step_avg:97.74ms +step:1187/1670 train_time:116022ms step_avg:97.74ms +step:1188/1670 train_time:116119ms step_avg:97.74ms +step:1189/1670 train_time:116217ms step_avg:97.74ms +step:1190/1670 train_time:116315ms step_avg:97.74ms +step:1191/1670 train_time:116412ms step_avg:97.74ms +step:1192/1670 train_time:116510ms step_avg:97.74ms +step:1193/1670 train_time:116607ms step_avg:97.74ms +step:1194/1670 train_time:116705ms step_avg:97.74ms +step:1195/1670 train_time:116803ms step_avg:97.74ms +step:1196/1670 train_time:116901ms step_avg:97.74ms +step:1197/1670 train_time:117000ms step_avg:97.74ms +step:1198/1670 train_time:117097ms step_avg:97.74ms +step:1199/1670 train_time:117195ms step_avg:97.74ms +step:1200/1670 train_time:117293ms step_avg:97.74ms +step:1201/1670 train_time:117390ms step_avg:97.74ms +step:1202/1670 train_time:117487ms step_avg:97.74ms +step:1203/1670 train_time:117585ms step_avg:97.74ms +step:1204/1670 train_time:117683ms step_avg:97.74ms +step:1205/1670 train_time:117783ms step_avg:97.75ms +step:1206/1670 train_time:117881ms step_avg:97.75ms +step:1207/1670 train_time:117979ms step_avg:97.75ms +step:1208/1670 train_time:118076ms step_avg:97.75ms +step:1209/1670 train_time:118173ms step_avg:97.74ms +step:1210/1670 train_time:118270ms step_avg:97.74ms +step:1211/1670 train_time:118367ms step_avg:97.74ms +step:1212/1670 train_time:118465ms step_avg:97.74ms +step:1213/1670 train_time:118563ms step_avg:97.74ms +step:1214/1670 train_time:118662ms step_avg:97.74ms +step:1215/1670 train_time:118760ms step_avg:97.75ms +step:1216/1670 train_time:118859ms step_avg:97.75ms +step:1217/1670 train_time:118957ms step_avg:97.75ms +step:1218/1670 train_time:119055ms step_avg:97.75ms +step:1219/1670 train_time:119152ms step_avg:97.75ms +step:1220/1670 train_time:119250ms step_avg:97.75ms +step:1221/1670 train_time:119347ms step_avg:97.75ms +step:1222/1670 train_time:119445ms step_avg:97.75ms +step:1223/1670 train_time:119543ms step_avg:97.75ms +step:1224/1670 train_time:119640ms step_avg:97.74ms +step:1225/1670 train_time:119737ms step_avg:97.74ms +step:1226/1670 train_time:119835ms step_avg:97.74ms +step:1227/1670 train_time:119932ms step_avg:97.74ms +step:1228/1670 train_time:120030ms step_avg:97.74ms +step:1229/1670 train_time:120128ms step_avg:97.74ms +step:1230/1670 train_time:120226ms step_avg:97.74ms +step:1231/1670 train_time:120324ms step_avg:97.74ms +step:1232/1670 train_time:120422ms step_avg:97.75ms +step:1233/1670 train_time:120521ms step_avg:97.75ms +step:1234/1670 train_time:120619ms step_avg:97.75ms +step:1235/1670 train_time:120717ms step_avg:97.75ms +step:1236/1670 train_time:120815ms step_avg:97.75ms +step:1237/1670 train_time:120913ms step_avg:97.75ms +step:1238/1670 train_time:121010ms step_avg:97.75ms +step:1239/1670 train_time:121107ms step_avg:97.75ms +step:1240/1670 train_time:121205ms step_avg:97.75ms +step:1241/1670 train_time:121304ms step_avg:97.75ms +step:1242/1670 train_time:121403ms step_avg:97.75ms +step:1243/1670 train_time:121501ms step_avg:97.75ms +step:1244/1670 train_time:121599ms step_avg:97.75ms 
+step:1245/1670 train_time:121697ms step_avg:97.75ms +step:1246/1670 train_time:121795ms step_avg:97.75ms +step:1247/1670 train_time:121892ms step_avg:97.75ms +step:1248/1670 train_time:121989ms step_avg:97.75ms +step:1249/1670 train_time:122087ms step_avg:97.75ms +step:1250/1670 train_time:122185ms step_avg:97.75ms +step:1250/1670 val_loss:3.3817 train_time:122282ms step_avg:97.83ms +step:1251/1670 train_time:122303ms step_avg:97.76ms +step:1252/1670 train_time:122386ms step_avg:97.75ms +step:1253/1670 train_time:122487ms step_avg:97.75ms +step:1254/1670 train_time:122585ms step_avg:97.75ms +step:1255/1670 train_time:122681ms step_avg:97.75ms +step:1256/1670 train_time:122778ms step_avg:97.75ms +step:1257/1670 train_time:122875ms step_avg:97.75ms +step:1258/1670 train_time:122972ms step_avg:97.75ms +step:1259/1670 train_time:123068ms step_avg:97.75ms +step:1260/1670 train_time:123165ms step_avg:97.75ms +step:1261/1670 train_time:123264ms step_avg:97.75ms +step:1262/1670 train_time:123362ms step_avg:97.75ms +step:1263/1670 train_time:123461ms step_avg:97.75ms +step:1264/1670 train_time:123560ms step_avg:97.75ms +step:1265/1670 train_time:123658ms step_avg:97.75ms +step:1266/1670 train_time:123755ms step_avg:97.75ms +step:1267/1670 train_time:123852ms step_avg:97.75ms +step:1268/1670 train_time:123949ms step_avg:97.75ms +step:1269/1670 train_time:124046ms step_avg:97.75ms +step:1270/1670 train_time:124143ms step_avg:97.75ms +step:1271/1670 train_time:124240ms step_avg:97.75ms +step:1272/1670 train_time:124339ms step_avg:97.75ms +step:1273/1670 train_time:124438ms step_avg:97.75ms +step:1274/1670 train_time:124725ms step_avg:97.90ms +step:1275/1670 train_time:124927ms step_avg:97.98ms +step:1276/1670 train_time:125024ms step_avg:97.98ms +step:1277/1670 train_time:125120ms step_avg:97.98ms +step:1278/1670 train_time:125217ms step_avg:97.98ms +step:1279/1670 train_time:125313ms step_avg:97.98ms +step:1280/1670 train_time:125411ms step_avg:97.98ms +step:1281/1670 train_time:125507ms step_avg:97.98ms +step:1282/1670 train_time:125604ms step_avg:97.97ms +step:1283/1670 train_time:125700ms step_avg:97.97ms +step:1284/1670 train_time:125800ms step_avg:97.97ms +step:1285/1670 train_time:125902ms step_avg:97.98ms +step:1286/1670 train_time:126001ms step_avg:97.98ms +step:1287/1670 train_time:126099ms step_avg:97.98ms +step:1288/1670 train_time:126197ms step_avg:97.98ms +step:1289/1670 train_time:126295ms step_avg:97.98ms +step:1290/1670 train_time:126393ms step_avg:97.98ms +step:1291/1670 train_time:126490ms step_avg:97.98ms +step:1292/1670 train_time:126588ms step_avg:97.98ms +step:1293/1670 train_time:126685ms step_avg:97.98ms +step:1294/1670 train_time:126783ms step_avg:97.98ms +step:1295/1670 train_time:126880ms step_avg:97.98ms +step:1296/1670 train_time:126980ms step_avg:97.98ms +step:1297/1670 train_time:127078ms step_avg:97.98ms +step:1298/1670 train_time:127175ms step_avg:97.98ms +step:1299/1670 train_time:127273ms step_avg:97.98ms +step:1300/1670 train_time:127370ms step_avg:97.98ms +step:1301/1670 train_time:127468ms step_avg:97.98ms +step:1302/1670 train_time:127565ms step_avg:97.98ms +step:1303/1670 train_time:127662ms step_avg:97.98ms +step:1304/1670 train_time:127760ms step_avg:97.98ms +step:1305/1670 train_time:127857ms step_avg:97.97ms +step:1306/1670 train_time:127956ms step_avg:97.98ms +step:1307/1670 train_time:128056ms step_avg:97.98ms +step:1308/1670 train_time:128154ms step_avg:97.98ms +step:1309/1670 train_time:128252ms step_avg:97.98ms +step:1310/1670 train_time:128350ms 
step_avg:97.98ms +step:1311/1670 train_time:128448ms step_avg:97.98ms +step:1312/1670 train_time:128544ms step_avg:97.98ms +step:1313/1670 train_time:128642ms step_avg:97.98ms +step:1314/1670 train_time:128739ms step_avg:97.98ms +step:1315/1670 train_time:128837ms step_avg:97.97ms +step:1316/1670 train_time:128934ms step_avg:97.97ms +step:1317/1670 train_time:129033ms step_avg:97.97ms +step:1318/1670 train_time:129133ms step_avg:97.98ms +step:1319/1670 train_time:129231ms step_avg:97.98ms +step:1320/1670 train_time:129329ms step_avg:97.98ms +step:1321/1670 train_time:129426ms step_avg:97.98ms +step:1322/1670 train_time:129523ms step_avg:97.98ms +step:1323/1670 train_time:129622ms step_avg:97.98ms +step:1324/1670 train_time:129718ms step_avg:97.97ms +step:1325/1670 train_time:129816ms step_avg:97.97ms +step:1326/1670 train_time:129914ms step_avg:97.97ms +step:1327/1670 train_time:130013ms step_avg:97.98ms +step:1328/1670 train_time:130111ms step_avg:97.97ms +step:1329/1670 train_time:130209ms step_avg:97.97ms +step:1330/1670 train_time:130306ms step_avg:97.97ms +step:1331/1670 train_time:130403ms step_avg:97.97ms +step:1332/1670 train_time:130500ms step_avg:97.97ms +step:1333/1670 train_time:130597ms step_avg:97.97ms +step:1334/1670 train_time:130695ms step_avg:97.97ms +step:1335/1670 train_time:130794ms step_avg:97.97ms +step:1336/1670 train_time:130892ms step_avg:97.97ms +step:1337/1670 train_time:130989ms step_avg:97.97ms +step:1338/1670 train_time:131087ms step_avg:97.97ms +step:1339/1670 train_time:131184ms step_avg:97.97ms +step:1340/1670 train_time:131281ms step_avg:97.97ms +step:1341/1670 train_time:131379ms step_avg:97.97ms +step:1342/1670 train_time:131477ms step_avg:97.97ms +step:1343/1670 train_time:131575ms step_avg:97.97ms +step:1344/1670 train_time:131674ms step_avg:97.97ms +step:1345/1670 train_time:131772ms step_avg:97.97ms +step:1346/1670 train_time:131871ms step_avg:97.97ms +step:1347/1670 train_time:131968ms step_avg:97.97ms +step:1348/1670 train_time:132066ms step_avg:97.97ms +step:1349/1670 train_time:132164ms step_avg:97.97ms +step:1350/1670 train_time:132261ms step_avg:97.97ms +step:1351/1670 train_time:132359ms step_avg:97.97ms +step:1352/1670 train_time:132457ms step_avg:97.97ms +step:1353/1670 train_time:132556ms step_avg:97.97ms +step:1354/1670 train_time:132654ms step_avg:97.97ms +step:1355/1670 train_time:132753ms step_avg:97.97ms +step:1356/1670 train_time:132852ms step_avg:97.97ms +step:1357/1670 train_time:132950ms step_avg:97.97ms +step:1358/1670 train_time:133048ms step_avg:97.97ms +step:1359/1670 train_time:133146ms step_avg:97.97ms +step:1360/1670 train_time:133245ms step_avg:97.97ms +step:1361/1670 train_time:133341ms step_avg:97.97ms +step:1362/1670 train_time:133439ms step_avg:97.97ms +step:1363/1670 train_time:133537ms step_avg:97.97ms +step:1364/1670 train_time:133635ms step_avg:97.97ms +step:1365/1670 train_time:133733ms step_avg:97.97ms +step:1366/1670 train_time:133831ms step_avg:97.97ms +step:1367/1670 train_time:133929ms step_avg:97.97ms +step:1368/1670 train_time:134027ms step_avg:97.97ms +step:1369/1670 train_time:134126ms step_avg:97.97ms +step:1370/1670 train_time:134224ms step_avg:97.97ms +step:1371/1670 train_time:134322ms step_avg:97.97ms +step:1372/1670 train_time:134419ms step_avg:97.97ms +step:1373/1670 train_time:134517ms step_avg:97.97ms +step:1374/1670 train_time:134614ms step_avg:97.97ms +step:1375/1670 train_time:134713ms step_avg:97.97ms +step:1375/1670 val_loss:3.3437 train_time:134810ms step_avg:98.04ms +step:1376/1670 
train_time:134830ms step_avg:97.99ms +step:1377/1670 train_time:134913ms step_avg:97.98ms +step:1378/1670 train_time:135013ms step_avg:97.98ms +step:1379/1670 train_time:135110ms step_avg:97.98ms +step:1380/1670 train_time:135208ms step_avg:97.98ms +step:1381/1670 train_time:135307ms step_avg:97.98ms +step:1382/1670 train_time:135404ms step_avg:97.98ms +step:1383/1670 train_time:135501ms step_avg:97.98ms +step:1384/1670 train_time:135598ms step_avg:97.98ms +step:1385/1670 train_time:135695ms step_avg:97.97ms +step:1386/1670 train_time:135793ms step_avg:97.97ms +step:1387/1670 train_time:135892ms step_avg:97.98ms +step:1388/1670 train_time:135991ms step_avg:97.98ms +step:1389/1670 train_time:136089ms step_avg:97.98ms +step:1390/1670 train_time:136187ms step_avg:97.98ms +step:1391/1670 train_time:136285ms step_avg:97.98ms +step:1392/1670 train_time:136382ms step_avg:97.98ms +step:1393/1670 train_time:136480ms step_avg:97.98ms +step:1394/1670 train_time:136577ms step_avg:97.97ms +step:1395/1670 train_time:136674ms step_avg:97.97ms +step:1396/1670 train_time:136772ms step_avg:97.97ms +step:1397/1670 train_time:136870ms step_avg:97.97ms +step:1398/1670 train_time:136968ms step_avg:97.97ms +step:1399/1670 train_time:137068ms step_avg:97.98ms +step:1400/1670 train_time:137167ms step_avg:97.98ms +step:1401/1670 train_time:137265ms step_avg:97.98ms +step:1402/1670 train_time:137363ms step_avg:97.98ms +step:1403/1670 train_time:137460ms step_avg:97.98ms +step:1404/1670 train_time:137557ms step_avg:97.98ms +step:1405/1670 train_time:137654ms step_avg:97.97ms +step:1406/1670 train_time:137751ms step_avg:97.97ms +step:1407/1670 train_time:137850ms step_avg:97.97ms +step:1408/1670 train_time:137950ms step_avg:97.98ms +step:1409/1670 train_time:138049ms step_avg:97.98ms +step:1410/1670 train_time:138148ms step_avg:97.98ms +step:1411/1670 train_time:138246ms step_avg:97.98ms +step:1412/1670 train_time:138344ms step_avg:97.98ms +step:1413/1670 train_time:138442ms step_avg:97.98ms +step:1414/1670 train_time:138541ms step_avg:97.98ms +step:1415/1670 train_time:138639ms step_avg:97.98ms +step:1416/1670 train_time:138737ms step_avg:97.98ms +step:1417/1670 train_time:138833ms step_avg:97.98ms +step:1418/1670 train_time:138932ms step_avg:97.98ms +step:1419/1670 train_time:139030ms step_avg:97.98ms +step:1420/1670 train_time:139128ms step_avg:97.98ms +step:1421/1670 train_time:139226ms step_avg:97.98ms +step:1422/1670 train_time:139325ms step_avg:97.98ms +step:1423/1670 train_time:139423ms step_avg:97.98ms +step:1424/1670 train_time:139521ms step_avg:97.98ms +step:1425/1670 train_time:139619ms step_avg:97.98ms +step:1426/1670 train_time:139717ms step_avg:97.98ms +step:1427/1670 train_time:139815ms step_avg:97.98ms +step:1428/1670 train_time:139912ms step_avg:97.98ms +step:1429/1670 train_time:140010ms step_avg:97.98ms +step:1430/1670 train_time:140109ms step_avg:97.98ms +step:1431/1670 train_time:140207ms step_avg:97.98ms +step:1432/1670 train_time:140305ms step_avg:97.98ms +step:1433/1670 train_time:140403ms step_avg:97.98ms +step:1434/1670 train_time:140501ms step_avg:97.98ms +step:1435/1670 train_time:140598ms step_avg:97.98ms +step:1436/1670 train_time:140695ms step_avg:97.98ms +step:1437/1670 train_time:140792ms step_avg:97.98ms +step:1438/1670 train_time:140890ms step_avg:97.98ms +step:1439/1670 train_time:140987ms step_avg:97.98ms +step:1440/1670 train_time:141086ms step_avg:97.98ms +step:1441/1670 train_time:141185ms step_avg:97.98ms +step:1442/1670 train_time:141285ms step_avg:97.98ms +step:1443/1670 
train_time:141383ms step_avg:97.98ms +step:1444/1670 train_time:141480ms step_avg:97.98ms +step:1445/1670 train_time:141578ms step_avg:97.98ms +step:1446/1670 train_time:141676ms step_avg:97.98ms +step:1447/1670 train_time:141773ms step_avg:97.98ms +step:1448/1670 train_time:141871ms step_avg:97.98ms +step:1449/1670 train_time:141969ms step_avg:97.98ms +step:1450/1670 train_time:142068ms step_avg:97.98ms +step:1451/1670 train_time:142165ms step_avg:97.98ms +step:1452/1670 train_time:142264ms step_avg:97.98ms +step:1453/1670 train_time:142362ms step_avg:97.98ms +step:1454/1670 train_time:142460ms step_avg:97.98ms +step:1455/1670 train_time:142559ms step_avg:97.98ms +step:1456/1670 train_time:142657ms step_avg:97.98ms +step:1457/1670 train_time:142755ms step_avg:97.98ms +step:1458/1670 train_time:142852ms step_avg:97.98ms +step:1459/1670 train_time:142949ms step_avg:97.98ms +step:1460/1670 train_time:143047ms step_avg:97.98ms +step:1461/1670 train_time:143146ms step_avg:97.98ms +step:1462/1670 train_time:143245ms step_avg:97.98ms +step:1463/1670 train_time:143344ms step_avg:97.98ms +step:1464/1670 train_time:143442ms step_avg:97.98ms +step:1465/1670 train_time:143541ms step_avg:97.98ms +step:1466/1670 train_time:143640ms step_avg:97.98ms +step:1467/1670 train_time:143738ms step_avg:97.98ms +step:1468/1670 train_time:143836ms step_avg:97.98ms +step:1469/1670 train_time:143933ms step_avg:97.98ms +step:1470/1670 train_time:144031ms step_avg:97.98ms +step:1471/1670 train_time:144129ms step_avg:97.98ms +step:1472/1670 train_time:144226ms step_avg:97.98ms +step:1473/1670 train_time:144324ms step_avg:97.98ms +step:1474/1670 train_time:144422ms step_avg:97.98ms +step:1475/1670 train_time:144521ms step_avg:97.98ms +step:1476/1670 train_time:144620ms step_avg:97.98ms +step:1477/1670 train_time:144718ms step_avg:97.98ms +step:1478/1670 train_time:144816ms step_avg:97.98ms +step:1479/1670 train_time:144913ms step_avg:97.98ms +step:1480/1670 train_time:145011ms step_avg:97.98ms +step:1481/1670 train_time:145108ms step_avg:97.98ms +step:1482/1670 train_time:145206ms step_avg:97.98ms +step:1483/1670 train_time:145303ms step_avg:97.98ms +step:1484/1670 train_time:145401ms step_avg:97.98ms +step:1485/1670 train_time:145673ms step_avg:98.10ms +step:1486/1670 train_time:145862ms step_avg:98.16ms +step:1487/1670 train_time:145958ms step_avg:98.16ms +step:1488/1670 train_time:146054ms step_avg:98.15ms +step:1489/1670 train_time:146150ms step_avg:98.15ms +step:1490/1670 train_time:146247ms step_avg:98.15ms +step:1491/1670 train_time:146344ms step_avg:98.15ms +step:1492/1670 train_time:146441ms step_avg:98.15ms +step:1493/1670 train_time:146537ms step_avg:98.15ms +step:1494/1670 train_time:146633ms step_avg:98.15ms +step:1495/1670 train_time:146730ms step_avg:98.15ms +step:1496/1670 train_time:146835ms step_avg:98.15ms +step:1497/1670 train_time:146937ms step_avg:98.15ms +step:1498/1670 train_time:147034ms step_avg:98.15ms +step:1499/1670 train_time:147130ms step_avg:98.15ms +step:1500/1670 train_time:147227ms step_avg:98.15ms +step:1500/1670 val_loss:3.3116 train_time:147324ms step_avg:98.22ms +step:1501/1670 train_time:147345ms step_avg:98.16ms +step:1502/1670 train_time:147430ms step_avg:98.16ms +step:1503/1670 train_time:147530ms step_avg:98.16ms +step:1504/1670 train_time:147628ms step_avg:98.16ms +step:1505/1670 train_time:147725ms step_avg:98.16ms +step:1506/1670 train_time:147821ms step_avg:98.15ms +step:1507/1670 train_time:147919ms step_avg:98.15ms +step:1508/1670 train_time:148016ms step_avg:98.15ms 
+step:1509/1670 train_time:148116ms step_avg:98.15ms +step:1510/1670 train_time:148214ms step_avg:98.16ms +step:1511/1670 train_time:148316ms step_avg:98.16ms +step:1512/1670 train_time:148417ms step_avg:98.16ms +step:1513/1670 train_time:148516ms step_avg:98.16ms +step:1514/1670 train_time:148616ms step_avg:98.16ms +step:1515/1670 train_time:148715ms step_avg:98.16ms +step:1516/1670 train_time:148813ms step_avg:98.16ms +step:1517/1670 train_time:148910ms step_avg:98.16ms +step:1518/1670 train_time:149007ms step_avg:98.16ms +step:1519/1670 train_time:149104ms step_avg:98.16ms +step:1520/1670 train_time:149201ms step_avg:98.16ms +step:1521/1670 train_time:149299ms step_avg:98.16ms +step:1522/1670 train_time:149398ms step_avg:98.16ms +step:1523/1670 train_time:149496ms step_avg:98.16ms +step:1524/1670 train_time:149595ms step_avg:98.16ms +step:1525/1670 train_time:149694ms step_avg:98.16ms +step:1526/1670 train_time:149792ms step_avg:98.16ms +step:1527/1670 train_time:149889ms step_avg:98.16ms +step:1528/1670 train_time:149987ms step_avg:98.16ms +step:1529/1670 train_time:150084ms step_avg:98.16ms +step:1530/1670 train_time:150182ms step_avg:98.16ms +step:1531/1670 train_time:150280ms step_avg:98.16ms +step:1532/1670 train_time:150378ms step_avg:98.16ms +step:1533/1670 train_time:150476ms step_avg:98.16ms +step:1534/1670 train_time:150576ms step_avg:98.16ms +step:1535/1670 train_time:150674ms step_avg:98.16ms +step:1536/1670 train_time:150773ms step_avg:98.16ms +step:1537/1670 train_time:150870ms step_avg:98.16ms +step:1538/1670 train_time:150969ms step_avg:98.16ms +step:1539/1670 train_time:151066ms step_avg:98.16ms +step:1540/1670 train_time:151163ms step_avg:98.16ms +step:1541/1670 train_time:151260ms step_avg:98.16ms +step:1542/1670 train_time:151357ms step_avg:98.16ms +step:1543/1670 train_time:151457ms step_avg:98.16ms +step:1544/1670 train_time:151557ms step_avg:98.16ms +step:1545/1670 train_time:151655ms step_avg:98.16ms +step:1546/1670 train_time:151753ms step_avg:98.16ms +step:1547/1670 train_time:151852ms step_avg:98.16ms +step:1548/1670 train_time:151950ms step_avg:98.16ms +step:1549/1670 train_time:152048ms step_avg:98.16ms +step:1550/1670 train_time:152146ms step_avg:98.16ms +step:1551/1670 train_time:152244ms step_avg:98.16ms +step:1552/1670 train_time:152341ms step_avg:98.16ms +step:1553/1670 train_time:152438ms step_avg:98.16ms +step:1554/1670 train_time:152536ms step_avg:98.16ms +step:1555/1670 train_time:152635ms step_avg:98.16ms +step:1556/1670 train_time:152732ms step_avg:98.16ms +step:1557/1670 train_time:152830ms step_avg:98.16ms +step:1558/1670 train_time:152929ms step_avg:98.16ms +step:1559/1670 train_time:153027ms step_avg:98.16ms +step:1560/1670 train_time:153125ms step_avg:98.16ms +step:1561/1670 train_time:153224ms step_avg:98.16ms +step:1562/1670 train_time:153321ms step_avg:98.16ms +step:1563/1670 train_time:153420ms step_avg:98.16ms +step:1564/1670 train_time:153518ms step_avg:98.16ms +step:1565/1670 train_time:153616ms step_avg:98.16ms +step:1566/1670 train_time:153713ms step_avg:98.16ms +step:1567/1670 train_time:153812ms step_avg:98.16ms +step:1568/1670 train_time:153909ms step_avg:98.16ms +step:1569/1670 train_time:154007ms step_avg:98.16ms +step:1570/1670 train_time:154105ms step_avg:98.16ms +step:1571/1670 train_time:154203ms step_avg:98.16ms +step:1572/1670 train_time:154300ms step_avg:98.16ms +step:1573/1670 train_time:154398ms step_avg:98.16ms +step:1574/1670 train_time:154497ms step_avg:98.16ms +step:1575/1670 train_time:154595ms step_avg:98.16ms 
+step:1576/1670 train_time:154693ms step_avg:98.16ms +step:1577/1670 train_time:154790ms step_avg:98.15ms +step:1578/1670 train_time:154888ms step_avg:98.15ms +step:1579/1670 train_time:154986ms step_avg:98.15ms +step:1580/1670 train_time:155084ms step_avg:98.15ms +step:1581/1670 train_time:155182ms step_avg:98.15ms +step:1582/1670 train_time:155280ms step_avg:98.15ms +step:1583/1670 train_time:155377ms step_avg:98.15ms +step:1584/1670 train_time:155475ms step_avg:98.15ms +step:1585/1670 train_time:155572ms step_avg:98.15ms +step:1586/1670 train_time:155670ms step_avg:98.15ms +step:1587/1670 train_time:155768ms step_avg:98.15ms +step:1588/1670 train_time:155865ms step_avg:98.15ms +step:1589/1670 train_time:155963ms step_avg:98.15ms +step:1590/1670 train_time:156061ms step_avg:98.15ms +step:1591/1670 train_time:156159ms step_avg:98.15ms +step:1592/1670 train_time:156257ms step_avg:98.15ms +step:1593/1670 train_time:156355ms step_avg:98.15ms +step:1594/1670 train_time:156453ms step_avg:98.15ms +step:1595/1670 train_time:156551ms step_avg:98.15ms +step:1596/1670 train_time:156648ms step_avg:98.15ms +step:1597/1670 train_time:156746ms step_avg:98.15ms +step:1598/1670 train_time:156843ms step_avg:98.15ms +step:1599/1670 train_time:156940ms step_avg:98.15ms +step:1600/1670 train_time:157039ms step_avg:98.15ms +step:1601/1670 train_time:157136ms step_avg:98.15ms +step:1602/1670 train_time:157235ms step_avg:98.15ms +step:1603/1670 train_time:157334ms step_avg:98.15ms +step:1604/1670 train_time:157432ms step_avg:98.15ms +step:1605/1670 train_time:157530ms step_avg:98.15ms +step:1606/1670 train_time:157628ms step_avg:98.15ms +step:1607/1670 train_time:157725ms step_avg:98.15ms +step:1608/1670 train_time:157823ms step_avg:98.15ms +step:1609/1670 train_time:157920ms step_avg:98.15ms +step:1610/1670 train_time:158018ms step_avg:98.15ms +step:1611/1670 train_time:158115ms step_avg:98.15ms +step:1612/1670 train_time:158214ms step_avg:98.15ms +step:1613/1670 train_time:158313ms step_avg:98.15ms +step:1614/1670 train_time:158411ms step_avg:98.15ms +step:1615/1670 train_time:158508ms step_avg:98.15ms +step:1616/1670 train_time:158606ms step_avg:98.15ms +step:1617/1670 train_time:158703ms step_avg:98.15ms +step:1618/1670 train_time:158801ms step_avg:98.15ms +step:1619/1670 train_time:158899ms step_avg:98.15ms +step:1620/1670 train_time:158996ms step_avg:98.15ms +step:1621/1670 train_time:159094ms step_avg:98.15ms +step:1622/1670 train_time:159193ms step_avg:98.15ms +step:1623/1670 train_time:159291ms step_avg:98.15ms +step:1624/1670 train_time:159389ms step_avg:98.15ms +step:1625/1670 train_time:159487ms step_avg:98.15ms +step:1625/1670 val_loss:3.2846 train_time:159584ms step_avg:98.21ms +step:1626/1670 train_time:159606ms step_avg:98.16ms +step:1627/1670 train_time:159689ms step_avg:98.15ms +step:1628/1670 train_time:159788ms step_avg:98.15ms +step:1629/1670 train_time:159886ms step_avg:98.15ms +step:1630/1670 train_time:159983ms step_avg:98.15ms +step:1631/1670 train_time:160080ms step_avg:98.15ms +step:1632/1670 train_time:160177ms step_avg:98.15ms +step:1633/1670 train_time:160274ms step_avg:98.15ms +step:1634/1670 train_time:160371ms step_avg:98.15ms +step:1635/1670 train_time:160467ms step_avg:98.15ms +step:1636/1670 train_time:160566ms step_avg:98.15ms +step:1637/1670 train_time:160668ms step_avg:98.15ms +step:1638/1670 train_time:160767ms step_avg:98.15ms +step:1639/1670 train_time:160865ms step_avg:98.15ms +step:1640/1670 train_time:160963ms step_avg:98.15ms +step:1641/1670 train_time:161061ms 
step_avg:98.15ms +step:1642/1670 train_time:161160ms step_avg:98.15ms +step:1643/1670 train_time:161259ms step_avg:98.15ms +step:1644/1670 train_time:161357ms step_avg:98.15ms +step:1645/1670 train_time:161454ms step_avg:98.15ms +step:1646/1670 train_time:161552ms step_avg:98.15ms +step:1647/1670 train_time:161650ms step_avg:98.15ms +step:1648/1670 train_time:161748ms step_avg:98.15ms +step:1649/1670 train_time:161845ms step_avg:98.15ms +step:1650/1670 train_time:161944ms step_avg:98.15ms +step:1651/1670 train_time:162041ms step_avg:98.15ms +step:1652/1670 train_time:162140ms step_avg:98.15ms +step:1653/1670 train_time:162238ms step_avg:98.15ms +step:1654/1670 train_time:162335ms step_avg:98.15ms +step:1655/1670 train_time:162433ms step_avg:98.15ms +step:1656/1670 train_time:162532ms step_avg:98.15ms +step:1657/1670 train_time:162630ms step_avg:98.15ms +step:1658/1670 train_time:162728ms step_avg:98.15ms +step:1659/1670 train_time:162825ms step_avg:98.15ms +step:1660/1670 train_time:162922ms step_avg:98.15ms +step:1661/1670 train_time:163021ms step_avg:98.15ms +step:1662/1670 train_time:163119ms step_avg:98.15ms +step:1663/1670 train_time:163216ms step_avg:98.15ms +step:1664/1670 train_time:163313ms step_avg:98.14ms +step:1665/1670 train_time:163410ms step_avg:98.14ms +step:1666/1670 train_time:163508ms step_avg:98.14ms +step:1667/1670 train_time:163607ms step_avg:98.14ms +step:1668/1670 train_time:163706ms step_avg:98.14ms +step:1669/1670 train_time:163803ms step_avg:98.14ms +step:1670/1670 train_time:163901ms step_avg:98.14ms +step:1670/1670 val_loss:3.2766 train_time:163998ms step_avg:98.20ms +peak memory allocated: 34757 MiB reserved: 49516 MiB diff --git a/records/090325_FA3/README.md b/records/090325_FA3/README.md new file mode 100644 index 000000000..e45e52a3a --- /dev/null +++ b/records/090325_FA3/README.md @@ -0,0 +1,133 @@ +# New record 09/03/25 + +This submission includes recent WR changes by +@ClassicLarry [(08/23/25)](https://github.com/ClassicLarry/modded-nanogpt/tree/master/records/082325_SparseAttnGate) +and @byronxu99 [(07/18/25)](https://github.com/KellerJordan/modded-nanogpt/pull/109). + +Additionally, it has been updated after helpful discussion with @ClassicLarry and @YouJiacheng. + +The main idea of this record is to use [Flash Attention v3](https://github.com/Dao-AILab/flash-attention) instead of Flex Attention. +The official version of this library is incompatible with `torch.compile` and causes graph breaks. +However, a [recent PR](https://github.com/Dao-AILab/flash-attention/pull/1769) by +@guilhermeleobas addresses this issue. + + +## Timing and Validation + +In 1670 training steps, this run achieves a loss <3.28 (`p=0.0001`) in 163.84 seconds on average, validated over 7 runs. + +``` +import scipy.stats +import numpy as np + +accs = [ + 3.2771, 3.2755, 3.2760, 3.2766, 3.2778, 3.2774, 3.2780 +] + +times = [ + 163.871, 163.621, 163.848, 163.998, 163.897, 164.016, 163.618 +] + +print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue) +# p=0.0001 + +print(f"{np.mean(times):.4f}") +# 163.8384 +``` + +In my timing, this is a 4.3 second mean improvement over https://github.com/KellerJordan/modded-nanogpt/pull/117. +The number of steps can also probably be brought down by 5-10 while achieving loss <3.28. + +I used 8 x H100 (SXM5) via Prime Intellect for validation compute.
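For readers who want to double-check the statistics, the `scipy` call above is equivalent to computing the one-sample t statistic by hand. A quick illustrative check (not part of the record):

```
import numpy as np
from scipy import stats

accs = np.array([3.2771, 3.2755, 3.2760, 3.2766, 3.2778, 3.2774, 3.2780])
n = len(accs)
# one-sample t statistic against the 3.28 target
t = (accs.mean() - 3.28) / (accs.std(ddof=1) / np.sqrt(n))
# one-sided alternative: "true mean < 3.28"
p = stats.t.cdf(t, df=n - 1)
print(f"t = {t:.2f}, p = {p:.6f}")
```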
+ +## Further Details + +### Motivation + +Flash Attention v3 achieves greater SM utilization on Hopper GPUs than Flash Attention v2. +Flash Attention 3 is also significantly faster than Flex Attention on batched inputs, and this gap widens as the number of sequences per batch increases: + +[Figure: Flash vs. Flex Attention runtime, varying the number of sequences per batch] + +In order to train with document masking, we use Flash Attention's `flash_attn_varlen_func` (suggested by @YouJiacheng). +We keep the number of tokens per step fixed (`393216`) but pack a variable number of sequences into each batch, +clipping the maximum length of each sequence to `args.train_max_seq_len = 2048`. + +WR#26 by @ClassicLarry found that validation loss decreases when we train only on sequences beginning with the Beginning of Sequence (BOS) token. + + +### Flash Attention 3 + + +As mentioned above, we need to use an unmerged PR in order to use FA3 with `torch.compile`. +You can build the wheel like so: + +``` +pip install -U pip wheel setuptools ninja numpy packaging psutil + +git clone https://github.com/guilhermeleobas/flash-attention.git +cd flash-attention/hopper +git switch guilhermeleobas/fa3-compile + +export MAX_JOBS=32 # Can increase based on machine +export FLASH_ATTENTION_FORCE_BUILD=TRUE # skip prebuilt wheel fetch +export FLASH_ATTENTION_DISABLE_SM80=TRUE # Hopper-only +export FLASH_ATTENTION_DISABLE_FP16=TRUE # leave BF16, FP8 +export FLASH_ATTENTION_DISABLE_HDIM64=TRUE # NanoGPT only uses HDIM = 128 +export FLASH_ATTENTION_DISABLE_HDIM96=TRUE +export FLASH_ATTENTION_DISABLE_HDIM192=TRUE +export FLASH_ATTENTION_DISABLE_HDIM256=TRUE + +python setup.py bdist_wheel +``` + +Additionally, I have uploaded a prebuilt wheel [here](https://github.com/varunneal/flash-attention/releases/tag/v3.0.0b1-alpha). +Downloading this wheel and installing it via pip is likely to be fairly fast. + +For exact reproduction, I recommend that you install Torch Nightly 2.9.0.dev20250718 and +install the FA3 wheel afterward: + +``` +pip install --pre "torch==2.9.0.dev20250718+cu126" --index-url https://download.pytorch.org/whl/nightly/cu126 + +pip install /path/to/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl +``` + +For me, Torch Nightly 2.9.0.dev20250713 was incompatible with PR#109. + +### Attention Masks + +Flash Attention exposes the parameter `window_size`, with which we can specify the number of tokens to attend to. +Unfortunately, it expects this value to be an int, so varying it causes `torch.compile` to +create a new graph. As such, I restricted the run to a small number of distinct window sizes. + +I kept the existing long-short sliding window block mask pattern, as well as the idea +that the window sizes should increase linearly over the length of the training run. +To aid with this, I created a hyperparameter `ws_schedule` and a helper `get_ws(step)`. +I additionally exposed the size of the blocks in a window as a hyperparameter, `block_size=128`. + +I have picked a linear schedule with three stages: `ws_schedule=(3, 7, 11)`. +Each graph needs to be warmed up separately, so I have increased the number +of warmup steps from `10` to `30`. The compile time is dominated by the first iteration, +so compilation will take approximately `len(ws_schedule)` times longer than before. + + +Document masks are implemented by specifying the start and end of each sequence in `cu_seqlens_*`. +In order for the tensor sizes to be fixed, we pad `cu_seqlens_*` to a fixed length, chosen to be larger +than the number of documents we would ever expect in a single input batch. A rough sketch of how these pieces fit together follows below.
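The sketch below is illustrative rather than the record's exact code: `ws_schedule`, `get_ws`, and `block_size` are the names introduced above, but the schedule indexing, the `pad_cu_seqlens` helper, and the padding convention (repeating the final offset so trailing segments have zero length) are assumptions made for illustration. The call itself follows Flash Attention's varlen signature.

```
# Illustrative sketch only -- simplified from the description above,
# not the record's actual implementation.
import torch
from flash_attn_interface import flash_attn_varlen_func  # FA3 (Hopper) interface

ws_schedule = (3, 7, 11)  # window sizes, in blocks
block_size = 128          # tokens per window block

def get_ws(step: int, num_steps: int) -> int:
    # Assumed indexing: walk through the discrete schedule linearly over the run.
    idx = min(len(ws_schedule) * step // num_steps, len(ws_schedule) - 1)
    return ws_schedule[idx]

def pad_cu_seqlens(bos_positions: torch.Tensor, total_tokens: int, max_docs: int) -> torch.Tensor:
    # bos_positions: sorted int tensor of document start indices, beginning with 0.
    # Repeating the final offset yields zero-length tail segments, so the tensor
    # shape stays fixed and torch.compile does not recompile per batch.
    cu = torch.cat([bos_positions.to(torch.int32),
                    bos_positions.new_full((1,), total_tokens, dtype=torch.int32)])
    pad = cu.new_full((max_docs + 1 - cu.numel(),), total_tokens)
    return torch.cat([cu, pad])

def attend(q, k, v, cu_seqlens, max_seq_len, step, num_steps):
    # q, k, v: (total_tokens, n_heads, head_dim) bf16 tensors, documents packed end to end.
    window = get_ws(step, num_steps) * block_size  # left-context width in tokens
    # Depending on the FA3 build, this returns `out` or `(out, softmax_lse)`.
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens, cu_seqlens,    # cu_seqlens_q, cu_seqlens_k
        max_seq_len, max_seq_len,  # max_seqlen_q, max_seqlen_k
        causal=True,
        window_size=(window, 0),   # attend only to the previous `window` tokens
    )
```

Since each distinct `window` value is an int baked into the compiled graph, every entry of `ws_schedule` triggers its own compilation, which is why the warmup count above scales with `len(ws_schedule)`.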
+At training time, sequences are clipped to `args.max_seq_len` tokens. +This clipping helps pack a greater diversity of sequences into each batch. +I believe this change is responsible for the reduction of ~25 training steps. + +In order to implement the above, I have created the helper class `BOSFinder`. + +### Potential Improvements + +- Batch size scheduling: Previously, the block mask acted as a proxy for batch size. +Now the batch size can be controlled explicitly and scheduled according to critical batch size theory. +I have added code in `distributed_data_generator` that allows changing the +batch size, max sequence length, and grad_accum_steps yielded after the generator is created. +- The current block mask window schedule `(3, 7, 11)` can almost certainly be improved upon. +- Hyperparameter tuning might change with the smaller sequence length. Rotary base, validation sequence length, learning rates, +etc. should be re-tuned. I haven't done that for this run. diff --git a/records/090325_FA3/ce3400f2-2ca1-4e0e-a784-089451df1913.txt b/records/090325_FA3/ce3400f2-2ca1-4e0e-a784-089451df1913.txt new file mode 100644 index 000000000..6f1aee1e5 --- /dev/null +++ b/records/090325_FA3/ce3400f2-2ca1-4e0e-a784-089451df1913.txt @@ -0,0 +1,2814 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor,
x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * 
a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % 
M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# 
----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + 
param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = 
torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
+        t = torch.arange(max_seq_len, dtype=torch.float32)
+        theta = torch.einsum("i,j -> ij", t, angular_freq)
+        self.cos = nn.Buffer(theta.cos(), persistent=False)
+        self.sin = nn.Buffer(theta.sin(), persistent=False)
+
+    def forward(self, x_BTHD: Tensor):
+        assert self.cos.size(0) >= x_BTHD.size(-3)
+        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
+        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
+        y1 = x1 * cos + x2 * sin
+        y2 = x1 * (-sin) + x2 * cos
+        return torch.cat((y1, y2), 3).type_as(x_BTHD)
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        hdim = num_heads * head_dim
+        assert hdim == dim, "num_heads * head_dim must equal model_dim"
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
+        # https://x.com/hi_tysam/status/1879699187107033311
+        self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim))
+        with torch.no_grad():
+            self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights
+            self.qkvo_w[3].zero_() # init output weights to zero
+        self.rotary = Rotary(head_dim, max_seq_len)
+        # scale the attention logits by a given constant, instead of the default head_dim**-0.5, by @leloykun
+        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
+        self.attn_scale = 0.12
+
+        # sparse gated attention to enable context-based no-op by @classiclarryd
+        self.attn_gate_dim = 12
+        self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads)
+        self.attn_gate.weight.detach().zero_()
+
+    def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int):
+        B, T = x.size(0), x.size(1) # batch size, sequence length
+        assert B == 1, "varlen attention requires B == 1"
+        assert T % 16 == 0
+
+        q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
+        q, k = norm(q), norm(k) # QK norm @Grad62304977
+        q, k = self.rotary(q), self.rotary(k)
+        if ve is not None:
+            v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
+        else: # skip mid-layers token value embeddings by @YouJiacheng
+            v = lambdas[0] * v
+
+        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
+
+        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
+        y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len,
+                                   causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0))
+        y = y.view(B, T, self.num_heads, self.head_dim)
+        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1)
+        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
+        y = F.linear(y, self.qkvo_w[3].type_as(y))
+        return y
+
+class MLP(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        hdim = 4 * dim
+        # make both matrices have the same shape because the optimizer sorts params by shape
+        # 2 matrices x 12 layers = 24 total, which is divisible by the 8-GPU world size
+        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
+        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
+        std = 0.5 * (dim ** -0.5)
+        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
+        with torch.no_grad():
+            self.c_fc.uniform_(-bound, bound)
+            self.c_proj.zero_() # zero init suggested by @Grad62304977
+
+    def forward(self, x: Tensor):
+        x = F.linear(x, self.c_fc.T.type_as(x))
+        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
+        x = F.linear(x, self.c_proj.type_as(x))
+        return x
+
+class Block(nn.Module):
+    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
+        super().__init__()
+        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
+        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
+        self.mlp = MLP(dim)
+
+    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor,
+                seqlens: Tensor, bm_size: int):
+        x = lambdas[0] * x + lambdas[1] * x0
+        if self.attn is not None:
+            x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size)
+        x = x + self.mlp(norm(x))
+        return x
+
+# -----------------------------------------------------------------------------
+# The main model
+
+def next_multiple_of_n(v: float | int, *, n: int):
+    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
+
+class GPT(nn.Module):
+    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
+        super().__init__()
+        vocab_size = next_multiple_of_n(vocab_size, n=128)
+        self.embed = nn.Embedding(vocab_size, model_dim)
+        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
+        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
+        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
+        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
+        # there are only 50257 unique GPT-2 tokens; we extend to the nearest multiple of 128 for efficiency.
+        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
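+        # Worked example (illustrative, not part of the original run): next_multiple_of_n(50257, n=128)
+        # walks 128, 256, ... and returns 50304 = 128 * 393, the smallest multiple of 128 >= 50257,
+        # so the embedding and lm_head carry 47 padding rows purely for GEMM-friendly shapes.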
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + 
self.size = tokens.numel()
+        self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
+        self.i = 0
+        self.world_size = world_size
+
+    def next_batch(self, num_tokens_local: int, max_seq_len: int):
+        n = len(self.bos_idx)
+        starts = [[] for _ in range(self.world_size)]
+        ends = [[] for _ in range(self.world_size)]
+
+        idx = self.i
+        for r in range(self.world_size):
+            cur_len = 0
+            while cur_len <= num_tokens_local:
+                if idx >= n:
+                    raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.")
+                cur = self.bos_idx[idx]
+                starts[r].append(cur)
+                end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size,
+                          cur + max_seq_len,
+                          cur + num_tokens_local - cur_len + 1)
+                ends[r].append(end)
+                cur_len += end - cur
+                idx += 1
+
+            assert cur_len == num_tokens_local + 1
+        self.i = idx
+
+        return starts, ends
+
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with a BOS token; sequences are truncated to max_seq_len
+    rank = dist.get_rank() if dist.is_initialized() else 0
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world_size * grad_accum_steps"
+    num_tokens = num_tokens // grad_accum_steps
+
+    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
+
+    file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0 # for unaligned case
+
+    while True:
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
+
+        if align_to_bos:
+            try:
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
+            except StopIteration:
+                # This shard is exhausted, load the next one in the next loop iteration.
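+                # The partially packed batch from next_batch() is simply dropped here; loading
+                # the next shard rebuilds BOSFinder (restarting its cursor at 0) and `continue`
+                # retries the packing against the fresh shard.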
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
+                continue
+
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # trim the last document by one token to account for the _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
+
+        else:
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
+            _inputs = buf[:-1].view(num_tokens_local)
+            _targets = buf[1:].view(num_tokens_local)
+
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
+
+        new_params = yield (
+            _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
+        )
+
+        if new_params is not None:
+            # makes it possible for the generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0, "Num tokens must be divisible by world_size * grad_accum_steps"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
+
+
+# -----------------------------------------------------------------------------
+# int main
+
+@dataclass
+class Hyperparameters:
+    # data
+    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
+    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
+    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
+    # optimization
+    num_iterations: int = 1670 # number of iterations to run
+    cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate
+    # evaluation and logging
+    run_id: str = str(uuid.uuid4())
+    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
+    save_checkpoint: bool = False
+    # attention masking
+    block_size: int = 128
+    ws_schedule: tuple = (3, 7, 11)
+
+args = Hyperparameters()
+
+data_path = os.environ.get("DATA_PATH", ".")
+args.train_files = os.path.join(data_path, args.train_files)
+args.val_files = os.path.join(data_path, args.val_files)
+
+# torchrun sets these env variables
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+assert 8 % world_size == 0, "world_size must be a divisor of 8"
+grad_accum_steps = 8 // world_size
+assert torch.cuda.is_available()
+device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
+torch.cuda.set_device(device)
+dist.init_process_group(backend="nccl", device_id=device)
+dist.barrier()
+master_process = (rank == 0) # this process will do logging, checkpointing etc.
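+# Worked example (illustrative): with 8 GPUs, grad_accum_steps = 8 // 8 = 1; with 4 GPUs it is 2,
+# so each rank still processes train_batch_size // 8 tokens per micro-batch and every optimizer
+# step consumes args.train_batch_size tokens globally, independent of world_size.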
+ +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) 
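+# Note (illustrative gloss): the warmup above exists only to trigger torch.compile / Triton
+# autotuning for every window-size graph before the clock starts; restoring the snapshotted
+# model and optimizer states ensures the timed run below begins from the untrained initialization.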
+del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 20:09:36 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU 
Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 31C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 30C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 34C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 29C P0 114W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 54231 C /usr/bin/python 0MiB | +| 0 N/A N/A 54232 C /usr/bin/python 0MiB | +| 0 N/A N/A 54233 C /usr/bin/python 0MiB | +| 0 N/A N/A 54234 C /usr/bin/python 0MiB | +| 0 N/A N/A 54235 C /usr/bin/python 0MiB | +| 0 N/A N/A 54236 C /usr/bin/python 0MiB | +| 0 N/A N/A 54237 C /usr/bin/python 0MiB | +| 0 N/A N/A 54238 C /usr/bin/python 0MiB | +| 1 N/A N/A 54232 C /usr/bin/python 0MiB | +| 2 N/A N/A 54233 C /usr/bin/python 0MiB | +| 3 N/A N/A 54234 C /usr/bin/python 0MiB | +| 4 N/A N/A 54235 C /usr/bin/python 0MiB | +| 5 N/A N/A 54236 C /usr/bin/python 0MiB | +| 6 N/A N/A 54237 C /usr/bin/python 0MiB | +| 7 N/A N/A 54238 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:385ms step_avg:385.21ms +step:2/1670 train_time:406ms step_avg:202.85ms +step:3/1670 train_time:479ms step_avg:159.59ms +step:4/1670 train_time:572ms step_avg:143.09ms +step:5/1670 train_time:667ms step_avg:133.32ms +step:6/1670 train_time:762ms 
step_avg:126.96ms +step:7/1670 train_time:856ms step_avg:122.35ms +step:8/1670 train_time:951ms step_avg:118.90ms +step:9/1670 train_time:1046ms step_avg:116.23ms +step:10/1670 train_time:1140ms step_avg:114.04ms +step:11/1670 train_time:1235ms step_avg:112.31ms +step:12/1670 train_time:1332ms step_avg:110.99ms +step:13/1670 train_time:1431ms step_avg:110.10ms +step:14/1670 train_time:1529ms step_avg:109.22ms +step:15/1670 train_time:1625ms step_avg:108.31ms +step:16/1670 train_time:1720ms step_avg:107.53ms +step:17/1670 train_time:1816ms step_avg:106.82ms +step:18/1670 train_time:1911ms step_avg:106.15ms +step:19/1670 train_time:2006ms step_avg:105.59ms +step:20/1670 train_time:2102ms step_avg:105.09ms +step:21/1670 train_time:2197ms step_avg:104.61ms +step:22/1670 train_time:2293ms step_avg:104.22ms +step:23/1670 train_time:2390ms step_avg:103.91ms +step:24/1670 train_time:2488ms step_avg:103.65ms +step:25/1670 train_time:2584ms step_avg:103.38ms +step:26/1670 train_time:2680ms step_avg:103.07ms +step:27/1670 train_time:2775ms step_avg:102.79ms +step:28/1670 train_time:2870ms step_avg:102.51ms +step:29/1670 train_time:2966ms step_avg:102.28ms +step:30/1670 train_time:3061ms step_avg:102.04ms +step:31/1670 train_time:3157ms step_avg:101.85ms +step:32/1670 train_time:3253ms step_avg:101.65ms +step:33/1670 train_time:3348ms step_avg:101.47ms +step:34/1670 train_time:3445ms step_avg:101.32ms +step:35/1670 train_time:3542ms step_avg:101.20ms +step:36/1670 train_time:3637ms step_avg:101.03ms +step:37/1670 train_time:3733ms step_avg:100.88ms +step:38/1670 train_time:3829ms step_avg:100.77ms +step:39/1670 train_time:3925ms step_avg:100.64ms +step:40/1670 train_time:4021ms step_avg:100.53ms +step:41/1670 train_time:4117ms step_avg:100.41ms +step:42/1670 train_time:4212ms step_avg:100.29ms +step:43/1670 train_time:4308ms step_avg:100.19ms +step:44/1670 train_time:4404ms step_avg:100.09ms +step:45/1670 train_time:4500ms step_avg:100.00ms +step:46/1670 train_time:4596ms step_avg:99.91ms +step:47/1670 train_time:4692ms step_avg:99.84ms +step:48/1670 train_time:4788ms step_avg:99.74ms +step:49/1670 train_time:4883ms step_avg:99.65ms +step:50/1670 train_time:4978ms step_avg:99.56ms +step:51/1670 train_time:5074ms step_avg:99.48ms +step:52/1670 train_time:5169ms step_avg:99.40ms +step:53/1670 train_time:5265ms step_avg:99.33ms +step:54/1670 train_time:5360ms step_avg:99.26ms +step:55/1670 train_time:5456ms step_avg:99.20ms +step:56/1670 train_time:5551ms step_avg:99.13ms +step:57/1670 train_time:5648ms step_avg:99.09ms +step:58/1670 train_time:5743ms step_avg:99.02ms +step:59/1670 train_time:5839ms step_avg:98.97ms +step:60/1670 train_time:5935ms step_avg:98.91ms +step:61/1670 train_time:6030ms step_avg:98.86ms +step:62/1670 train_time:6126ms step_avg:98.81ms +step:63/1670 train_time:6222ms step_avg:98.76ms +step:64/1670 train_time:6317ms step_avg:98.71ms +step:65/1670 train_time:6414ms step_avg:98.67ms +step:66/1670 train_time:6509ms step_avg:98.62ms +step:67/1670 train_time:6605ms step_avg:98.59ms +step:68/1670 train_time:6701ms step_avg:98.54ms +step:69/1670 train_time:6797ms step_avg:98.50ms +step:70/1670 train_time:6893ms step_avg:98.47ms +step:71/1670 train_time:6988ms step_avg:98.43ms +step:72/1670 train_time:7084ms step_avg:98.39ms +step:73/1670 train_time:7180ms step_avg:98.35ms +step:74/1670 train_time:7275ms step_avg:98.31ms +step:75/1670 train_time:7370ms step_avg:98.27ms +step:76/1670 train_time:7466ms step_avg:98.24ms +step:77/1670 train_time:7562ms step_avg:98.21ms +step:78/1670 
train_time:7658ms step_avg:98.18ms +step:79/1670 train_time:7754ms step_avg:98.16ms +step:80/1670 train_time:7850ms step_avg:98.13ms +step:81/1670 train_time:7946ms step_avg:98.10ms +step:82/1670 train_time:8041ms step_avg:98.06ms +step:83/1670 train_time:8136ms step_avg:98.03ms +step:84/1670 train_time:8232ms step_avg:98.00ms +step:85/1670 train_time:8328ms step_avg:97.98ms +step:86/1670 train_time:8423ms step_avg:97.95ms +step:87/1670 train_time:8519ms step_avg:97.92ms +step:88/1670 train_time:8615ms step_avg:97.90ms +step:89/1670 train_time:8711ms step_avg:97.87ms +step:90/1670 train_time:8807ms step_avg:97.85ms +step:91/1670 train_time:8903ms step_avg:97.84ms +step:92/1670 train_time:8998ms step_avg:97.81ms +step:93/1670 train_time:9094ms step_avg:97.78ms +step:94/1670 train_time:9189ms step_avg:97.75ms +step:95/1670 train_time:9285ms step_avg:97.73ms +step:96/1670 train_time:9380ms step_avg:97.70ms +step:97/1670 train_time:9475ms step_avg:97.68ms +step:98/1670 train_time:9571ms step_avg:97.66ms +step:99/1670 train_time:9667ms step_avg:97.64ms +step:100/1670 train_time:9763ms step_avg:97.63ms +step:101/1670 train_time:9859ms step_avg:97.61ms +step:102/1670 train_time:9955ms step_avg:97.60ms +step:103/1670 train_time:10051ms step_avg:97.58ms +step:104/1670 train_time:10147ms step_avg:97.57ms +step:105/1670 train_time:10243ms step_avg:97.56ms +step:106/1670 train_time:10339ms step_avg:97.54ms +step:107/1670 train_time:10434ms step_avg:97.51ms +step:108/1670 train_time:10530ms step_avg:97.50ms +step:109/1670 train_time:10626ms step_avg:97.49ms +step:110/1670 train_time:10722ms step_avg:97.47ms +step:111/1670 train_time:10817ms step_avg:97.45ms +step:112/1670 train_time:10912ms step_avg:97.43ms +step:113/1670 train_time:11008ms step_avg:97.41ms +step:114/1670 train_time:11103ms step_avg:97.40ms +step:115/1670 train_time:11199ms step_avg:97.38ms +step:116/1670 train_time:11295ms step_avg:97.37ms +step:117/1670 train_time:11390ms step_avg:97.35ms +step:118/1670 train_time:11487ms step_avg:97.34ms +step:119/1670 train_time:11583ms step_avg:97.34ms +step:120/1670 train_time:11678ms step_avg:97.32ms +step:121/1670 train_time:11775ms step_avg:97.31ms +step:122/1670 train_time:11871ms step_avg:97.30ms +step:123/1670 train_time:11966ms step_avg:97.29ms +step:124/1670 train_time:12062ms step_avg:97.28ms +step:125/1670 train_time:12158ms step_avg:97.26ms +step:125/1670 val_loss:4.2958 train_time:12253ms step_avg:98.02ms +step:126/1670 train_time:12274ms step_avg:97.41ms +step:127/1670 train_time:12356ms step_avg:97.29ms +step:128/1670 train_time:12461ms step_avg:97.36ms +step:129/1670 train_time:12560ms step_avg:97.36ms +step:130/1670 train_time:12655ms step_avg:97.35ms +step:131/1670 train_time:12750ms step_avg:97.33ms +step:132/1670 train_time:12845ms step_avg:97.31ms +step:133/1670 train_time:12940ms step_avg:97.29ms +step:134/1670 train_time:13035ms step_avg:97.28ms +step:135/1670 train_time:13130ms step_avg:97.26ms +step:136/1670 train_time:13225ms step_avg:97.24ms +step:137/1670 train_time:13322ms step_avg:97.24ms +step:138/1670 train_time:13420ms step_avg:97.25ms +step:139/1670 train_time:13518ms step_avg:97.25ms +step:140/1670 train_time:13613ms step_avg:97.24ms +step:141/1670 train_time:13710ms step_avg:97.23ms +step:142/1670 train_time:13805ms step_avg:97.22ms +step:143/1670 train_time:13900ms step_avg:97.20ms +step:144/1670 train_time:13994ms step_avg:97.18ms +step:145/1670 train_time:14089ms step_avg:97.17ms +step:146/1670 train_time:14184ms step_avg:97.15ms +step:147/1670 
train_time:14278ms step_avg:97.13ms +step:148/1670 train_time:14375ms step_avg:97.13ms +step:149/1670 train_time:14472ms step_avg:97.13ms +step:150/1670 train_time:14569ms step_avg:97.13ms +step:151/1670 train_time:14666ms step_avg:97.12ms +step:152/1670 train_time:14761ms step_avg:97.11ms +step:153/1670 train_time:14856ms step_avg:97.10ms +step:154/1670 train_time:14951ms step_avg:97.09ms +step:155/1670 train_time:15047ms step_avg:97.08ms +step:156/1670 train_time:15142ms step_avg:97.07ms +step:157/1670 train_time:15237ms step_avg:97.05ms +step:158/1670 train_time:15333ms step_avg:97.04ms +step:159/1670 train_time:15430ms step_avg:97.04ms +step:160/1670 train_time:15527ms step_avg:97.04ms +step:161/1670 train_time:15623ms step_avg:97.04ms +step:162/1670 train_time:15719ms step_avg:97.03ms +step:163/1670 train_time:15815ms step_avg:97.02ms +step:164/1670 train_time:15911ms step_avg:97.02ms +step:165/1670 train_time:16006ms step_avg:97.01ms +step:166/1670 train_time:16102ms step_avg:97.00ms +step:167/1670 train_time:16197ms step_avg:96.99ms +step:168/1670 train_time:16292ms step_avg:96.98ms +step:169/1670 train_time:16388ms step_avg:96.97ms +step:170/1670 train_time:16485ms step_avg:96.97ms +step:171/1670 train_time:16582ms step_avg:96.97ms +step:172/1670 train_time:16677ms step_avg:96.96ms +step:173/1670 train_time:16773ms step_avg:96.95ms +step:174/1670 train_time:16869ms step_avg:96.95ms +step:175/1670 train_time:16964ms step_avg:96.94ms +step:176/1670 train_time:17059ms step_avg:96.93ms +step:177/1670 train_time:17153ms step_avg:96.91ms +step:178/1670 train_time:17249ms step_avg:96.90ms +step:179/1670 train_time:17345ms step_avg:96.90ms +step:180/1670 train_time:17441ms step_avg:96.89ms +step:181/1670 train_time:17537ms step_avg:96.89ms +step:182/1670 train_time:17633ms step_avg:96.88ms +step:183/1670 train_time:17729ms step_avg:96.88ms +step:184/1670 train_time:17825ms step_avg:96.88ms +step:185/1670 train_time:17921ms step_avg:96.87ms +step:186/1670 train_time:18016ms step_avg:96.86ms +step:187/1670 train_time:18112ms step_avg:96.85ms +step:188/1670 train_time:18208ms step_avg:96.85ms +step:189/1670 train_time:18303ms step_avg:96.84ms +step:190/1670 train_time:18399ms step_avg:96.84ms +step:191/1670 train_time:18495ms step_avg:96.83ms +step:192/1670 train_time:18591ms step_avg:96.83ms +step:193/1670 train_time:18688ms step_avg:96.83ms +step:194/1670 train_time:18784ms step_avg:96.83ms +step:195/1670 train_time:18880ms step_avg:96.82ms +step:196/1670 train_time:18975ms step_avg:96.81ms +step:197/1670 train_time:19070ms step_avg:96.80ms +step:198/1670 train_time:19166ms step_avg:96.80ms +step:199/1670 train_time:19261ms step_avg:96.79ms +step:200/1670 train_time:19356ms step_avg:96.78ms +step:201/1670 train_time:19452ms step_avg:96.78ms +step:202/1670 train_time:19547ms step_avg:96.77ms +step:203/1670 train_time:19644ms step_avg:96.77ms +step:204/1670 train_time:19740ms step_avg:96.77ms +step:205/1670 train_time:19835ms step_avg:96.76ms +step:206/1670 train_time:19930ms step_avg:96.75ms +step:207/1670 train_time:20026ms step_avg:96.74ms +step:208/1670 train_time:20121ms step_avg:96.74ms +step:209/1670 train_time:20217ms step_avg:96.73ms +step:210/1670 train_time:20312ms step_avg:96.72ms +step:211/1670 train_time:20409ms step_avg:96.72ms +step:212/1670 train_time:20504ms step_avg:96.72ms +step:213/1670 train_time:20834ms step_avg:97.81ms +step:214/1670 train_time:20907ms step_avg:97.69ms +step:215/1670 train_time:21001ms step_avg:97.68ms +step:216/1670 train_time:21095ms step_avg:97.66ms 
+step:217/1670 train_time:21189ms step_avg:97.65ms +step:218/1670 train_time:21284ms step_avg:97.63ms +step:219/1670 train_time:21379ms step_avg:97.62ms +step:220/1670 train_time:21473ms step_avg:97.60ms +step:221/1670 train_time:21567ms step_avg:97.59ms +step:222/1670 train_time:21661ms step_avg:97.57ms +step:223/1670 train_time:21759ms step_avg:97.57ms +step:224/1670 train_time:21857ms step_avg:97.57ms +step:225/1670 train_time:21954ms step_avg:97.57ms +step:226/1670 train_time:22050ms step_avg:97.57ms +step:227/1670 train_time:22145ms step_avg:97.56ms +step:228/1670 train_time:22240ms step_avg:97.54ms +step:229/1670 train_time:22334ms step_avg:97.53ms +step:230/1670 train_time:22429ms step_avg:97.52ms +step:231/1670 train_time:22523ms step_avg:97.50ms +step:232/1670 train_time:22618ms step_avg:97.49ms +step:233/1670 train_time:22714ms step_avg:97.49ms +step:234/1670 train_time:22811ms step_avg:97.48ms +step:235/1670 train_time:22909ms step_avg:97.48ms +step:236/1670 train_time:23006ms step_avg:97.48ms +step:237/1670 train_time:23101ms step_avg:97.47ms +step:238/1670 train_time:23196ms step_avg:97.46ms +step:239/1670 train_time:23291ms step_avg:97.45ms +step:240/1670 train_time:23386ms step_avg:97.44ms +step:241/1670 train_time:23481ms step_avg:97.43ms +step:242/1670 train_time:23576ms step_avg:97.42ms +step:243/1670 train_time:23671ms step_avg:97.41ms +step:244/1670 train_time:23767ms step_avg:97.40ms +step:245/1670 train_time:23863ms step_avg:97.40ms +step:246/1670 train_time:23959ms step_avg:97.40ms +step:247/1670 train_time:24055ms step_avg:97.39ms +step:248/1670 train_time:24150ms step_avg:97.38ms +step:249/1670 train_time:24246ms step_avg:97.37ms +step:250/1670 train_time:24340ms step_avg:97.36ms +step:250/1670 val_loss:3.9752 train_time:24435ms step_avg:97.74ms +step:251/1670 train_time:24456ms step_avg:97.43ms +step:252/1670 train_time:24538ms step_avg:97.37ms +step:253/1670 train_time:24636ms step_avg:97.38ms +step:254/1670 train_time:24732ms step_avg:97.37ms +step:255/1670 train_time:24827ms step_avg:97.36ms +step:256/1670 train_time:24921ms step_avg:97.35ms +step:257/1670 train_time:25016ms step_avg:97.34ms +step:258/1670 train_time:25110ms step_avg:97.33ms +step:259/1670 train_time:25205ms step_avg:97.32ms +step:260/1670 train_time:25300ms step_avg:97.31ms +step:261/1670 train_time:25396ms step_avg:97.30ms +step:262/1670 train_time:25494ms step_avg:97.30ms +step:263/1670 train_time:25591ms step_avg:97.30ms +step:264/1670 train_time:25687ms step_avg:97.30ms +step:265/1670 train_time:25782ms step_avg:97.29ms +step:266/1670 train_time:25878ms step_avg:97.28ms +step:267/1670 train_time:25973ms step_avg:97.28ms +step:268/1670 train_time:26067ms step_avg:97.27ms +step:269/1670 train_time:26162ms step_avg:97.26ms +step:270/1670 train_time:26256ms step_avg:97.25ms +step:271/1670 train_time:26352ms step_avg:97.24ms +step:272/1670 train_time:26448ms step_avg:97.24ms +step:273/1670 train_time:26545ms step_avg:97.23ms +step:274/1670 train_time:26642ms step_avg:97.23ms +step:275/1670 train_time:26739ms step_avg:97.23ms +step:276/1670 train_time:26835ms step_avg:97.23ms +step:277/1670 train_time:26930ms step_avg:97.22ms +step:278/1670 train_time:27026ms step_avg:97.21ms +step:279/1670 train_time:27121ms step_avg:97.21ms +step:280/1670 train_time:27215ms step_avg:97.20ms +step:281/1670 train_time:27310ms step_avg:97.19ms +step:282/1670 train_time:27406ms step_avg:97.19ms +step:283/1670 train_time:27502ms step_avg:97.18ms +step:284/1670 train_time:27599ms step_avg:97.18ms +step:285/1670 
train_time:27696ms step_avg:97.18ms +step:286/1670 train_time:27791ms step_avg:97.17ms +step:287/1670 train_time:27886ms step_avg:97.16ms +step:288/1670 train_time:27982ms step_avg:97.16ms +step:289/1670 train_time:28077ms step_avg:97.15ms +step:290/1670 train_time:28172ms step_avg:97.15ms +step:291/1670 train_time:28267ms step_avg:97.14ms +step:292/1670 train_time:28363ms step_avg:97.13ms +step:293/1670 train_time:28459ms step_avg:97.13ms +step:294/1670 train_time:28556ms step_avg:97.13ms +step:295/1670 train_time:28652ms step_avg:97.12ms +step:296/1670 train_time:28749ms step_avg:97.12ms +step:297/1670 train_time:28844ms step_avg:97.12ms +step:298/1670 train_time:28940ms step_avg:97.11ms +step:299/1670 train_time:29035ms step_avg:97.11ms +step:300/1670 train_time:29130ms step_avg:97.10ms +step:301/1670 train_time:29225ms step_avg:97.09ms +step:302/1670 train_time:29321ms step_avg:97.09ms +step:303/1670 train_time:29416ms step_avg:97.08ms +step:304/1670 train_time:29512ms step_avg:97.08ms +step:305/1670 train_time:29608ms step_avg:97.07ms +step:306/1670 train_time:29703ms step_avg:97.07ms +step:307/1670 train_time:29800ms step_avg:97.07ms +step:308/1670 train_time:29895ms step_avg:97.06ms +step:309/1670 train_time:29991ms step_avg:97.06ms +step:310/1670 train_time:30086ms step_avg:97.05ms +step:311/1670 train_time:30182ms step_avg:97.05ms +step:312/1670 train_time:30277ms step_avg:97.04ms +step:313/1670 train_time:30371ms step_avg:97.03ms +step:314/1670 train_time:30467ms step_avg:97.03ms +step:315/1670 train_time:30564ms step_avg:97.03ms +step:316/1670 train_time:30661ms step_avg:97.03ms +step:317/1670 train_time:30758ms step_avg:97.03ms +step:318/1670 train_time:30854ms step_avg:97.02ms +step:319/1670 train_time:30949ms step_avg:97.02ms +step:320/1670 train_time:31044ms step_avg:97.01ms +step:321/1670 train_time:31139ms step_avg:97.01ms +step:322/1670 train_time:31234ms step_avg:97.00ms +step:323/1670 train_time:31329ms step_avg:96.99ms +step:324/1670 train_time:31424ms step_avg:96.99ms +step:325/1670 train_time:31520ms step_avg:96.99ms +step:326/1670 train_time:31616ms step_avg:96.98ms +step:327/1670 train_time:31711ms step_avg:96.98ms +step:328/1670 train_time:31807ms step_avg:96.97ms +step:329/1670 train_time:31903ms step_avg:96.97ms +step:330/1670 train_time:31999ms step_avg:96.97ms +step:331/1670 train_time:32095ms step_avg:96.96ms +step:332/1670 train_time:32190ms step_avg:96.96ms +step:333/1670 train_time:32285ms step_avg:96.95ms +step:334/1670 train_time:32381ms step_avg:96.95ms +step:335/1670 train_time:32477ms step_avg:96.94ms +step:336/1670 train_time:32573ms step_avg:96.94ms +step:337/1670 train_time:32668ms step_avg:96.94ms +step:338/1670 train_time:32764ms step_avg:96.93ms +step:339/1670 train_time:32861ms step_avg:96.94ms +step:340/1670 train_time:32958ms step_avg:96.93ms +step:341/1670 train_time:33053ms step_avg:96.93ms +step:342/1670 train_time:33148ms step_avg:96.92ms +step:343/1670 train_time:33243ms step_avg:96.92ms +step:344/1670 train_time:33339ms step_avg:96.92ms +step:345/1670 train_time:33434ms step_avg:96.91ms +step:346/1670 train_time:33530ms step_avg:96.91ms +step:347/1670 train_time:33625ms step_avg:96.90ms +step:348/1670 train_time:33721ms step_avg:96.90ms +step:349/1670 train_time:33817ms step_avg:96.90ms +step:350/1670 train_time:33912ms step_avg:96.89ms +step:351/1670 train_time:34008ms step_avg:96.89ms +step:352/1670 train_time:34103ms step_avg:96.88ms +step:353/1670 train_time:34199ms step_avg:96.88ms +step:354/1670 train_time:34294ms step_avg:96.88ms 
+step:355/1670 train_time:34389ms step_avg:96.87ms +step:356/1670 train_time:34484ms step_avg:96.87ms +step:357/1670 train_time:34580ms step_avg:96.86ms +step:358/1670 train_time:34677ms step_avg:96.86ms +step:359/1670 train_time:34772ms step_avg:96.86ms +step:360/1670 train_time:34868ms step_avg:96.85ms +step:361/1670 train_time:34964ms step_avg:96.85ms +step:362/1670 train_time:35059ms step_avg:96.85ms +step:363/1670 train_time:35155ms step_avg:96.84ms +step:364/1670 train_time:35250ms step_avg:96.84ms +step:365/1670 train_time:35346ms step_avg:96.84ms +step:366/1670 train_time:35442ms step_avg:96.84ms +step:367/1670 train_time:35538ms step_avg:96.83ms +step:368/1670 train_time:35633ms step_avg:96.83ms +step:369/1670 train_time:35728ms step_avg:96.83ms +step:370/1670 train_time:35824ms step_avg:96.82ms +step:371/1670 train_time:35919ms step_avg:96.82ms +step:372/1670 train_time:36015ms step_avg:96.81ms +step:373/1670 train_time:36111ms step_avg:96.81ms +step:374/1670 train_time:36206ms step_avg:96.81ms +step:375/1670 train_time:36302ms step_avg:96.81ms +step:375/1670 val_loss:3.8227 train_time:36397ms step_avg:97.06ms +step:376/1670 train_time:36420ms step_avg:96.86ms +step:377/1670 train_time:36501ms step_avg:96.82ms +step:378/1670 train_time:36601ms step_avg:96.83ms +step:379/1670 train_time:36698ms step_avg:96.83ms +step:380/1670 train_time:36792ms step_avg:96.82ms +step:381/1670 train_time:36886ms step_avg:96.81ms +step:382/1670 train_time:36981ms step_avg:96.81ms +step:383/1670 train_time:37076ms step_avg:96.80ms +step:384/1670 train_time:37171ms step_avg:96.80ms +step:385/1670 train_time:37265ms step_avg:96.79ms +step:386/1670 train_time:37361ms step_avg:96.79ms +step:387/1670 train_time:37457ms step_avg:96.79ms +step:388/1670 train_time:37555ms step_avg:96.79ms +step:389/1670 train_time:37652ms step_avg:96.79ms +step:390/1670 train_time:37749ms step_avg:96.79ms +step:391/1670 train_time:37845ms step_avg:96.79ms +step:392/1670 train_time:37940ms step_avg:96.78ms +step:393/1670 train_time:38034ms step_avg:96.78ms +step:394/1670 train_time:38129ms step_avg:96.77ms +step:395/1670 train_time:38223ms step_avg:96.77ms +step:396/1670 train_time:38318ms step_avg:96.76ms +step:397/1670 train_time:38414ms step_avg:96.76ms +step:398/1670 train_time:38512ms step_avg:96.76ms +step:399/1670 train_time:38610ms step_avg:96.77ms +step:400/1670 train_time:38707ms step_avg:96.77ms +step:401/1670 train_time:38803ms step_avg:96.77ms +step:402/1670 train_time:38898ms step_avg:96.76ms +step:403/1670 train_time:38994ms step_avg:96.76ms +step:404/1670 train_time:39088ms step_avg:96.75ms +step:405/1670 train_time:39184ms step_avg:96.75ms +step:406/1670 train_time:39279ms step_avg:96.75ms +step:407/1670 train_time:39374ms step_avg:96.74ms +step:408/1670 train_time:39471ms step_avg:96.74ms +step:409/1670 train_time:39568ms step_avg:96.74ms +step:410/1670 train_time:39665ms step_avg:96.74ms +step:411/1670 train_time:39761ms step_avg:96.74ms +step:412/1670 train_time:39856ms step_avg:96.74ms +step:413/1670 train_time:39952ms step_avg:96.74ms +step:414/1670 train_time:40047ms step_avg:96.73ms +step:415/1670 train_time:40142ms step_avg:96.73ms +step:416/1670 train_time:40237ms step_avg:96.72ms +step:417/1670 train_time:40332ms step_avg:96.72ms +step:418/1670 train_time:40428ms step_avg:96.72ms +step:419/1670 train_time:40523ms step_avg:96.71ms +step:420/1670 train_time:40620ms step_avg:96.71ms +step:421/1670 train_time:40716ms step_avg:96.71ms +step:422/1670 train_time:40812ms step_avg:96.71ms +step:423/1670 
train_time:40909ms step_avg:96.71ms +step:424/1670 train_time:41005ms step_avg:96.71ms +step:425/1670 train_time:41288ms step_avg:97.15ms +step:426/1670 train_time:41401ms step_avg:97.18ms +step:427/1670 train_time:41495ms step_avg:97.18ms +step:428/1670 train_time:41590ms step_avg:97.17ms +step:429/1670 train_time:41684ms step_avg:97.17ms +step:430/1670 train_time:41778ms step_avg:97.16ms +step:431/1670 train_time:41873ms step_avg:97.15ms +step:432/1670 train_time:41968ms step_avg:97.15ms +step:433/1670 train_time:42062ms step_avg:97.14ms +step:434/1670 train_time:42156ms step_avg:97.13ms +step:435/1670 train_time:42255ms step_avg:97.14ms +step:436/1670 train_time:42355ms step_avg:97.15ms +step:437/1670 train_time:42454ms step_avg:97.15ms +step:438/1670 train_time:42551ms step_avg:97.15ms +step:439/1670 train_time:42646ms step_avg:97.14ms +step:440/1670 train_time:42741ms step_avg:97.14ms +step:441/1670 train_time:42836ms step_avg:97.13ms +step:442/1670 train_time:42931ms step_avg:97.13ms +step:443/1670 train_time:43026ms step_avg:97.12ms +step:444/1670 train_time:43120ms step_avg:97.12ms +step:445/1670 train_time:43215ms step_avg:97.11ms +step:446/1670 train_time:43313ms step_avg:97.11ms +step:447/1670 train_time:43412ms step_avg:97.12ms +step:448/1670 train_time:43509ms step_avg:97.12ms +step:449/1670 train_time:43605ms step_avg:97.12ms +step:450/1670 train_time:43700ms step_avg:97.11ms +step:451/1670 train_time:43794ms step_avg:97.11ms +step:452/1670 train_time:43890ms step_avg:97.10ms +step:453/1670 train_time:43985ms step_avg:97.10ms +step:454/1670 train_time:44079ms step_avg:97.09ms +step:455/1670 train_time:44174ms step_avg:97.09ms +step:456/1670 train_time:44272ms step_avg:97.09ms +step:457/1670 train_time:44369ms step_avg:97.09ms +step:458/1670 train_time:44466ms step_avg:97.09ms +step:459/1670 train_time:44562ms step_avg:97.08ms +step:460/1670 train_time:44658ms step_avg:97.08ms +step:461/1670 train_time:44753ms step_avg:97.08ms +step:462/1670 train_time:44849ms step_avg:97.08ms +step:463/1670 train_time:44945ms step_avg:97.07ms +step:464/1670 train_time:45039ms step_avg:97.07ms +step:465/1670 train_time:45134ms step_avg:97.06ms +step:466/1670 train_time:45230ms step_avg:97.06ms +step:467/1670 train_time:45327ms step_avg:97.06ms +step:468/1670 train_time:45423ms step_avg:97.06ms +step:469/1670 train_time:45519ms step_avg:97.05ms +step:470/1670 train_time:45616ms step_avg:97.05ms +step:471/1670 train_time:45711ms step_avg:97.05ms +step:472/1670 train_time:45806ms step_avg:97.05ms +step:473/1670 train_time:45902ms step_avg:97.04ms +step:474/1670 train_time:45997ms step_avg:97.04ms +step:475/1670 train_time:46093ms step_avg:97.04ms +step:476/1670 train_time:46188ms step_avg:97.03ms +step:477/1670 train_time:46284ms step_avg:97.03ms +step:478/1670 train_time:46380ms step_avg:97.03ms +step:479/1670 train_time:46475ms step_avg:97.03ms +step:480/1670 train_time:46572ms step_avg:97.02ms +step:481/1670 train_time:46668ms step_avg:97.02ms +step:482/1670 train_time:46763ms step_avg:97.02ms +step:483/1670 train_time:46858ms step_avg:97.01ms +step:484/1670 train_time:46954ms step_avg:97.01ms +step:485/1670 train_time:47050ms step_avg:97.01ms +step:486/1670 train_time:47145ms step_avg:97.01ms +step:487/1670 train_time:47240ms step_avg:97.00ms +step:488/1670 train_time:47336ms step_avg:97.00ms +step:489/1670 train_time:47432ms step_avg:97.00ms +step:490/1670 train_time:47527ms step_avg:96.99ms +step:491/1670 train_time:47623ms step_avg:96.99ms +step:492/1670 train_time:47719ms step_avg:96.99ms 
+step:493/1670 train_time:47814ms step_avg:96.99ms +step:494/1670 train_time:47910ms step_avg:96.98ms +step:495/1670 train_time:48007ms step_avg:96.98ms +step:496/1670 train_time:48102ms step_avg:96.98ms +step:497/1670 train_time:48197ms step_avg:96.98ms +step:498/1670 train_time:48293ms step_avg:96.97ms +step:499/1670 train_time:48389ms step_avg:96.97ms +step:500/1670 train_time:48485ms step_avg:96.97ms +step:500/1670 val_loss:3.7166 train_time:48580ms step_avg:97.16ms +step:501/1670 train_time:48601ms step_avg:97.01ms +step:502/1670 train_time:48683ms step_avg:96.98ms +step:503/1670 train_time:48782ms step_avg:96.98ms +step:504/1670 train_time:48879ms step_avg:96.98ms +step:505/1670 train_time:48974ms step_avg:96.98ms +step:506/1670 train_time:49068ms step_avg:96.97ms +step:507/1670 train_time:49163ms step_avg:96.97ms +step:508/1670 train_time:49258ms step_avg:96.97ms +step:509/1670 train_time:49353ms step_avg:96.96ms +step:510/1670 train_time:49448ms step_avg:96.96ms +step:511/1670 train_time:49544ms step_avg:96.95ms +step:512/1670 train_time:49641ms step_avg:96.95ms +step:513/1670 train_time:49740ms step_avg:96.96ms +step:514/1670 train_time:49838ms step_avg:96.96ms +step:515/1670 train_time:49934ms step_avg:96.96ms +step:516/1670 train_time:50029ms step_avg:96.96ms +step:517/1670 train_time:50123ms step_avg:96.95ms +step:518/1670 train_time:50219ms step_avg:96.95ms +step:519/1670 train_time:50314ms step_avg:96.94ms +step:520/1670 train_time:50408ms step_avg:96.94ms +step:521/1670 train_time:50504ms step_avg:96.94ms +step:522/1670 train_time:50600ms step_avg:96.94ms +step:523/1670 train_time:50697ms step_avg:96.94ms +step:524/1670 train_time:50793ms step_avg:96.93ms +step:525/1670 train_time:50890ms step_avg:96.93ms +step:526/1670 train_time:50986ms step_avg:96.93ms +step:527/1670 train_time:51081ms step_avg:96.93ms +step:528/1670 train_time:51177ms step_avg:96.93ms +step:529/1670 train_time:51272ms step_avg:96.92ms +step:530/1670 train_time:51366ms step_avg:96.92ms +step:531/1670 train_time:51461ms step_avg:96.91ms +step:532/1670 train_time:51558ms step_avg:96.91ms +step:533/1670 train_time:51654ms step_avg:96.91ms +step:534/1670 train_time:51750ms step_avg:96.91ms +step:535/1670 train_time:51846ms step_avg:96.91ms +step:536/1670 train_time:51942ms step_avg:96.91ms +step:537/1670 train_time:52039ms step_avg:96.91ms +step:538/1670 train_time:52135ms step_avg:96.90ms +step:539/1670 train_time:52230ms step_avg:96.90ms +step:540/1670 train_time:52325ms step_avg:96.90ms +step:541/1670 train_time:52420ms step_avg:96.89ms +step:542/1670 train_time:52515ms step_avg:96.89ms +step:543/1670 train_time:52611ms step_avg:96.89ms +step:544/1670 train_time:52707ms step_avg:96.89ms +step:545/1670 train_time:52803ms step_avg:96.89ms +step:546/1670 train_time:52899ms step_avg:96.88ms +step:547/1670 train_time:52994ms step_avg:96.88ms +step:548/1670 train_time:53090ms step_avg:96.88ms +step:549/1670 train_time:53186ms step_avg:96.88ms +step:550/1670 train_time:53281ms step_avg:96.88ms +step:551/1670 train_time:53376ms step_avg:96.87ms +step:552/1670 train_time:53472ms step_avg:96.87ms +step:553/1670 train_time:53567ms step_avg:96.87ms +step:554/1670 train_time:53663ms step_avg:96.86ms +step:555/1670 train_time:53759ms step_avg:96.86ms +step:556/1670 train_time:53854ms step_avg:96.86ms +step:557/1670 train_time:53950ms step_avg:96.86ms +step:558/1670 train_time:54046ms step_avg:96.86ms +step:559/1670 train_time:54143ms step_avg:96.86ms +step:560/1670 train_time:54239ms step_avg:96.86ms +step:561/1670 
train_time:54337ms step_avg:96.86ms +step:562/1670 train_time:54434ms step_avg:96.86ms +step:563/1670 train_time:54531ms step_avg:96.86ms +step:564/1670 train_time:54627ms step_avg:96.86ms +step:565/1670 train_time:54724ms step_avg:96.86ms +step:566/1670 train_time:54822ms step_avg:96.86ms +step:567/1670 train_time:54919ms step_avg:96.86ms +step:568/1670 train_time:55016ms step_avg:96.86ms +step:569/1670 train_time:55114ms step_avg:96.86ms +step:570/1670 train_time:55211ms step_avg:96.86ms +step:571/1670 train_time:55308ms step_avg:96.86ms +step:572/1670 train_time:55405ms step_avg:96.86ms +step:573/1670 train_time:55502ms step_avg:96.86ms +step:574/1670 train_time:55600ms step_avg:96.86ms +step:575/1670 train_time:55697ms step_avg:96.87ms +step:576/1670 train_time:55794ms step_avg:96.87ms +step:577/1670 train_time:55891ms step_avg:96.87ms +step:578/1670 train_time:55988ms step_avg:96.86ms +step:579/1670 train_time:56087ms step_avg:96.87ms +step:580/1670 train_time:56184ms step_avg:96.87ms +step:581/1670 train_time:56281ms step_avg:96.87ms +step:582/1670 train_time:56380ms step_avg:96.87ms +step:583/1670 train_time:56478ms step_avg:96.87ms +step:584/1670 train_time:56575ms step_avg:96.87ms +step:585/1670 train_time:56671ms step_avg:96.87ms +step:586/1670 train_time:56768ms step_avg:96.87ms +step:587/1670 train_time:56864ms step_avg:96.87ms +step:588/1670 train_time:56962ms step_avg:96.87ms +step:589/1670 train_time:57059ms step_avg:96.87ms +step:590/1670 train_time:57155ms step_avg:96.87ms +step:591/1670 train_time:57253ms step_avg:96.87ms +step:592/1670 train_time:57350ms step_avg:96.88ms +step:593/1670 train_time:57447ms step_avg:96.88ms +step:594/1670 train_time:57544ms step_avg:96.88ms +step:595/1670 train_time:57642ms step_avg:96.88ms +step:596/1670 train_time:57740ms step_avg:96.88ms +step:597/1670 train_time:57838ms step_avg:96.88ms +step:598/1670 train_time:57935ms step_avg:96.88ms +step:599/1670 train_time:58032ms step_avg:96.88ms +step:600/1670 train_time:58129ms step_avg:96.88ms +step:601/1670 train_time:58226ms step_avg:96.88ms +step:602/1670 train_time:58323ms step_avg:96.88ms +step:603/1670 train_time:58420ms step_avg:96.88ms +step:604/1670 train_time:58517ms step_avg:96.88ms +step:605/1670 train_time:58615ms step_avg:96.88ms +step:606/1670 train_time:58712ms step_avg:96.88ms +step:607/1670 train_time:58808ms step_avg:96.88ms +step:608/1670 train_time:58906ms step_avg:96.88ms +step:609/1670 train_time:59004ms step_avg:96.89ms +step:610/1670 train_time:59101ms step_avg:96.89ms +step:611/1670 train_time:59199ms step_avg:96.89ms +step:612/1670 train_time:59295ms step_avg:96.89ms +step:613/1670 train_time:59392ms step_avg:96.89ms +step:614/1670 train_time:59490ms step_avg:96.89ms +step:615/1670 train_time:59586ms step_avg:96.89ms +step:616/1670 train_time:59684ms step_avg:96.89ms +step:617/1670 train_time:59782ms step_avg:96.89ms +step:618/1670 train_time:59879ms step_avg:96.89ms +step:619/1670 train_time:59977ms step_avg:96.89ms +step:620/1670 train_time:60074ms step_avg:96.89ms +step:621/1670 train_time:60171ms step_avg:96.89ms +step:622/1670 train_time:60268ms step_avg:96.89ms +step:623/1670 train_time:60364ms step_avg:96.89ms +step:624/1670 train_time:60462ms step_avg:96.89ms +step:625/1670 train_time:60559ms step_avg:96.90ms +step:625/1670 val_loss:3.6170 train_time:60656ms step_avg:97.05ms +step:626/1670 train_time:60679ms step_avg:96.93ms +step:627/1670 train_time:60765ms step_avg:96.91ms +step:628/1670 train_time:60863ms step_avg:96.92ms +step:629/1670 train_time:60960ms 
step_avg:96.92ms +step:630/1670 train_time:61056ms step_avg:96.91ms +step:631/1670 train_time:61152ms step_avg:96.91ms +step:632/1670 train_time:61248ms step_avg:96.91ms +step:633/1670 train_time:61343ms step_avg:96.91ms +step:634/1670 train_time:61438ms step_avg:96.91ms +step:635/1670 train_time:61534ms step_avg:96.90ms +step:636/1670 train_time:61632ms step_avg:96.91ms +step:637/1670 train_time:61733ms step_avg:96.91ms +step:638/1670 train_time:61833ms step_avg:96.92ms +step:639/1670 train_time:62113ms step_avg:97.20ms +step:640/1670 train_time:62303ms step_avg:97.35ms +step:641/1670 train_time:62398ms step_avg:97.34ms +step:642/1670 train_time:62493ms step_avg:97.34ms +step:643/1670 train_time:62589ms step_avg:97.34ms +step:644/1670 train_time:62685ms step_avg:97.34ms +step:645/1670 train_time:62780ms step_avg:97.33ms +step:646/1670 train_time:62876ms step_avg:97.33ms +step:647/1670 train_time:62973ms step_avg:97.33ms +step:648/1670 train_time:63071ms step_avg:97.33ms +step:649/1670 train_time:63176ms step_avg:97.34ms +step:650/1670 train_time:63278ms step_avg:97.35ms +step:651/1670 train_time:63375ms step_avg:97.35ms +step:652/1670 train_time:63472ms step_avg:97.35ms +step:653/1670 train_time:63568ms step_avg:97.35ms +step:654/1670 train_time:63664ms step_avg:97.35ms +step:655/1670 train_time:63760ms step_avg:97.34ms +step:656/1670 train_time:63856ms step_avg:97.34ms +step:657/1670 train_time:63952ms step_avg:97.34ms +step:658/1670 train_time:64049ms step_avg:97.34ms +step:659/1670 train_time:64148ms step_avg:97.34ms +step:660/1670 train_time:64247ms step_avg:97.34ms +step:661/1670 train_time:64345ms step_avg:97.34ms +step:662/1670 train_time:64442ms step_avg:97.34ms +step:663/1670 train_time:64539ms step_avg:97.34ms +step:664/1670 train_time:64635ms step_avg:97.34ms +step:665/1670 train_time:64731ms step_avg:97.34ms +step:666/1670 train_time:64828ms step_avg:97.34ms +step:667/1670 train_time:64923ms step_avg:97.34ms +step:668/1670 train_time:65019ms step_avg:97.33ms +step:669/1670 train_time:65116ms step_avg:97.33ms +step:670/1670 train_time:65214ms step_avg:97.33ms +step:671/1670 train_time:65313ms step_avg:97.34ms +step:672/1670 train_time:65411ms step_avg:97.34ms +step:673/1670 train_time:65509ms step_avg:97.34ms +step:674/1670 train_time:65607ms step_avg:97.34ms +step:675/1670 train_time:65703ms step_avg:97.34ms +step:676/1670 train_time:65799ms step_avg:97.34ms +step:677/1670 train_time:65896ms step_avg:97.34ms +step:678/1670 train_time:65992ms step_avg:97.33ms +step:679/1670 train_time:66089ms step_avg:97.33ms +step:680/1670 train_time:66186ms step_avg:97.33ms +step:681/1670 train_time:66284ms step_avg:97.33ms +step:682/1670 train_time:66382ms step_avg:97.33ms +step:683/1670 train_time:66479ms step_avg:97.33ms +step:684/1670 train_time:66576ms step_avg:97.33ms +step:685/1670 train_time:66673ms step_avg:97.33ms +step:686/1670 train_time:66770ms step_avg:97.33ms +step:687/1670 train_time:66869ms step_avg:97.33ms +step:688/1670 train_time:66965ms step_avg:97.33ms +step:689/1670 train_time:67062ms step_avg:97.33ms +step:690/1670 train_time:67158ms step_avg:97.33ms +step:691/1670 train_time:67256ms step_avg:97.33ms +step:692/1670 train_time:67354ms step_avg:97.33ms +step:693/1670 train_time:67451ms step_avg:97.33ms +step:694/1670 train_time:67548ms step_avg:97.33ms +step:695/1670 train_time:67645ms step_avg:97.33ms +step:696/1670 train_time:67742ms step_avg:97.33ms +step:697/1670 train_time:67839ms step_avg:97.33ms +step:698/1670 train_time:67936ms step_avg:97.33ms +step:699/1670 
train_time:68033ms step_avg:97.33ms +step:700/1670 train_time:68131ms step_avg:97.33ms +step:701/1670 train_time:68228ms step_avg:97.33ms +step:702/1670 train_time:68324ms step_avg:97.33ms +step:703/1670 train_time:68421ms step_avg:97.33ms +step:704/1670 train_time:68517ms step_avg:97.33ms +step:705/1670 train_time:68614ms step_avg:97.33ms +step:706/1670 train_time:68712ms step_avg:97.33ms +step:707/1670 train_time:68810ms step_avg:97.33ms +step:708/1670 train_time:68906ms step_avg:97.33ms +step:709/1670 train_time:69004ms step_avg:97.33ms +step:710/1670 train_time:69100ms step_avg:97.32ms +step:711/1670 train_time:69197ms step_avg:97.32ms +step:712/1670 train_time:69294ms step_avg:97.32ms +step:713/1670 train_time:69391ms step_avg:97.32ms +step:714/1670 train_time:69488ms step_avg:97.32ms +step:715/1670 train_time:69586ms step_avg:97.32ms +step:716/1670 train_time:69683ms step_avg:97.32ms +step:717/1670 train_time:69781ms step_avg:97.32ms +step:718/1670 train_time:69877ms step_avg:97.32ms +step:719/1670 train_time:69973ms step_avg:97.32ms +step:720/1670 train_time:70071ms step_avg:97.32ms +step:721/1670 train_time:70168ms step_avg:97.32ms +step:722/1670 train_time:70265ms step_avg:97.32ms +step:723/1670 train_time:70362ms step_avg:97.32ms +step:724/1670 train_time:70459ms step_avg:97.32ms +step:725/1670 train_time:70557ms step_avg:97.32ms +step:726/1670 train_time:70656ms step_avg:97.32ms +step:727/1670 train_time:70753ms step_avg:97.32ms +step:728/1670 train_time:70850ms step_avg:97.32ms +step:729/1670 train_time:70947ms step_avg:97.32ms +step:730/1670 train_time:71044ms step_avg:97.32ms +step:731/1670 train_time:71141ms step_avg:97.32ms +step:732/1670 train_time:71237ms step_avg:97.32ms +step:733/1670 train_time:71334ms step_avg:97.32ms +step:734/1670 train_time:71432ms step_avg:97.32ms +step:735/1670 train_time:71531ms step_avg:97.32ms +step:736/1670 train_time:71628ms step_avg:97.32ms +step:737/1670 train_time:71725ms step_avg:97.32ms +step:738/1670 train_time:71822ms step_avg:97.32ms +step:739/1670 train_time:71919ms step_avg:97.32ms +step:740/1670 train_time:72017ms step_avg:97.32ms +step:741/1670 train_time:72113ms step_avg:97.32ms +step:742/1670 train_time:72211ms step_avg:97.32ms +step:743/1670 train_time:72308ms step_avg:97.32ms +step:744/1670 train_time:72405ms step_avg:97.32ms +step:745/1670 train_time:72503ms step_avg:97.32ms +step:746/1670 train_time:72600ms step_avg:97.32ms +step:747/1670 train_time:72697ms step_avg:97.32ms +step:748/1670 train_time:72793ms step_avg:97.32ms +step:749/1670 train_time:72891ms step_avg:97.32ms +step:750/1670 train_time:72988ms step_avg:97.32ms +step:750/1670 val_loss:3.5615 train_time:73084ms step_avg:97.45ms +step:751/1670 train_time:73105ms step_avg:97.34ms +step:752/1670 train_time:73188ms step_avg:97.32ms +step:753/1670 train_time:73288ms step_avg:97.33ms +step:754/1670 train_time:73387ms step_avg:97.33ms +step:755/1670 train_time:73483ms step_avg:97.33ms +step:756/1670 train_time:73578ms step_avg:97.33ms +step:757/1670 train_time:73675ms step_avg:97.32ms +step:758/1670 train_time:73771ms step_avg:97.32ms +step:759/1670 train_time:73868ms step_avg:97.32ms +step:760/1670 train_time:73965ms step_avg:97.32ms +step:761/1670 train_time:74063ms step_avg:97.32ms +step:762/1670 train_time:74161ms step_avg:97.32ms +step:763/1670 train_time:74259ms step_avg:97.33ms +step:764/1670 train_time:74358ms step_avg:97.33ms +step:765/1670 train_time:74456ms step_avg:97.33ms +step:766/1670 train_time:74553ms step_avg:97.33ms +step:767/1670 train_time:74650ms 
step_avg:97.33ms +step:768/1670 train_time:74746ms step_avg:97.33ms +step:769/1670 train_time:74843ms step_avg:97.32ms +step:770/1670 train_time:74939ms step_avg:97.32ms +step:771/1670 train_time:75035ms step_avg:97.32ms +step:772/1670 train_time:75133ms step_avg:97.32ms +step:773/1670 train_time:75232ms step_avg:97.32ms +step:774/1670 train_time:75331ms step_avg:97.33ms +step:775/1670 train_time:75430ms step_avg:97.33ms +step:776/1670 train_time:75528ms step_avg:97.33ms +step:777/1670 train_time:75626ms step_avg:97.33ms +step:778/1670 train_time:75722ms step_avg:97.33ms +step:779/1670 train_time:75819ms step_avg:97.33ms +step:780/1670 train_time:75915ms step_avg:97.33ms +step:781/1670 train_time:76012ms step_avg:97.33ms +step:782/1670 train_time:76110ms step_avg:97.33ms +step:783/1670 train_time:76209ms step_avg:97.33ms +step:784/1670 train_time:76307ms step_avg:97.33ms +step:785/1670 train_time:76405ms step_avg:97.33ms +step:786/1670 train_time:76503ms step_avg:97.33ms +step:787/1670 train_time:76600ms step_avg:97.33ms +step:788/1670 train_time:76697ms step_avg:97.33ms +step:789/1670 train_time:76793ms step_avg:97.33ms +step:790/1670 train_time:76890ms step_avg:97.33ms +step:791/1670 train_time:76987ms step_avg:97.33ms +step:792/1670 train_time:77084ms step_avg:97.33ms +step:793/1670 train_time:77182ms step_avg:97.33ms +step:794/1670 train_time:77279ms step_avg:97.33ms +step:795/1670 train_time:77376ms step_avg:97.33ms +step:796/1670 train_time:77474ms step_avg:97.33ms +step:797/1670 train_time:77571ms step_avg:97.33ms +step:798/1670 train_time:77669ms step_avg:97.33ms +step:799/1670 train_time:77765ms step_avg:97.33ms +step:800/1670 train_time:77862ms step_avg:97.33ms +step:801/1670 train_time:77958ms step_avg:97.33ms +step:802/1670 train_time:78055ms step_avg:97.33ms +step:803/1670 train_time:78152ms step_avg:97.32ms +step:804/1670 train_time:78250ms step_avg:97.33ms +step:805/1670 train_time:78349ms step_avg:97.33ms +step:806/1670 train_time:78448ms step_avg:97.33ms +step:807/1670 train_time:78545ms step_avg:97.33ms +step:808/1670 train_time:78643ms step_avg:97.33ms +step:809/1670 train_time:78739ms step_avg:97.33ms +step:810/1670 train_time:78835ms step_avg:97.33ms +step:811/1670 train_time:78931ms step_avg:97.33ms +step:812/1670 train_time:79028ms step_avg:97.33ms +step:813/1670 train_time:79126ms step_avg:97.33ms +step:814/1670 train_time:79223ms step_avg:97.33ms +step:815/1670 train_time:79322ms step_avg:97.33ms +step:816/1670 train_time:79420ms step_avg:97.33ms +step:817/1670 train_time:79517ms step_avg:97.33ms +step:818/1670 train_time:79614ms step_avg:97.33ms +step:819/1670 train_time:79711ms step_avg:97.33ms +step:820/1670 train_time:79809ms step_avg:97.33ms +step:821/1670 train_time:79905ms step_avg:97.33ms +step:822/1670 train_time:80002ms step_avg:97.33ms +step:823/1670 train_time:80098ms step_avg:97.32ms +step:824/1670 train_time:80195ms step_avg:97.32ms +step:825/1670 train_time:80292ms step_avg:97.32ms +step:826/1670 train_time:80391ms step_avg:97.33ms +step:827/1670 train_time:80489ms step_avg:97.33ms +step:828/1670 train_time:80586ms step_avg:97.33ms +step:829/1670 train_time:80683ms step_avg:97.33ms +step:830/1670 train_time:80779ms step_avg:97.32ms +step:831/1670 train_time:80875ms step_avg:97.32ms +step:832/1670 train_time:80972ms step_avg:97.32ms +step:833/1670 train_time:81069ms step_avg:97.32ms +step:834/1670 train_time:81167ms step_avg:97.32ms +step:835/1670 train_time:81264ms step_avg:97.32ms +step:836/1670 train_time:81362ms step_avg:97.32ms +step:837/1670 
train_time:81459ms step_avg:97.32ms +step:838/1670 train_time:81556ms step_avg:97.32ms +step:839/1670 train_time:81653ms step_avg:97.32ms +step:840/1670 train_time:81751ms step_avg:97.32ms +step:841/1670 train_time:81849ms step_avg:97.32ms +step:842/1670 train_time:81946ms step_avg:97.32ms +step:843/1670 train_time:82044ms step_avg:97.32ms +step:844/1670 train_time:82140ms step_avg:97.32ms +step:845/1670 train_time:82236ms step_avg:97.32ms +step:846/1670 train_time:82334ms step_avg:97.32ms +step:847/1670 train_time:82432ms step_avg:97.32ms +step:848/1670 train_time:82529ms step_avg:97.32ms +step:849/1670 train_time:82626ms step_avg:97.32ms +step:850/1670 train_time:82723ms step_avg:97.32ms +step:851/1670 train_time:83006ms step_avg:97.54ms +step:852/1670 train_time:83168ms step_avg:97.61ms +step:853/1670 train_time:83263ms step_avg:97.61ms +step:854/1670 train_time:83358ms step_avg:97.61ms +step:855/1670 train_time:83454ms step_avg:97.61ms +step:856/1670 train_time:83550ms step_avg:97.60ms +step:857/1670 train_time:83646ms step_avg:97.60ms +step:858/1670 train_time:83741ms step_avg:97.60ms +step:859/1670 train_time:83837ms step_avg:97.60ms +step:860/1670 train_time:83933ms step_avg:97.60ms +step:861/1670 train_time:84033ms step_avg:97.60ms +step:862/1670 train_time:84134ms step_avg:97.60ms +step:863/1670 train_time:84233ms step_avg:97.60ms +step:864/1670 train_time:84330ms step_avg:97.60ms +step:865/1670 train_time:84428ms step_avg:97.60ms +step:866/1670 train_time:84525ms step_avg:97.60ms +step:867/1670 train_time:84620ms step_avg:97.60ms +step:868/1670 train_time:84716ms step_avg:97.60ms +step:869/1670 train_time:84812ms step_avg:97.60ms +step:870/1670 train_time:84908ms step_avg:97.60ms +step:871/1670 train_time:85004ms step_avg:97.59ms +step:872/1670 train_time:85102ms step_avg:97.59ms +step:873/1670 train_time:85200ms step_avg:97.59ms +step:874/1670 train_time:85299ms step_avg:97.60ms +step:875/1670 train_time:85396ms step_avg:97.59ms +step:875/1670 val_loss:3.5212 train_time:85492ms step_avg:97.70ms +step:876/1670 train_time:85513ms step_avg:97.62ms +step:877/1670 train_time:85594ms step_avg:97.60ms +step:878/1670 train_time:85693ms step_avg:97.60ms +step:879/1670 train_time:85790ms step_avg:97.60ms +step:880/1670 train_time:85886ms step_avg:97.60ms +step:881/1670 train_time:85982ms step_avg:97.60ms +step:882/1670 train_time:86077ms step_avg:97.59ms +step:883/1670 train_time:86173ms step_avg:97.59ms +step:884/1670 train_time:86270ms step_avg:97.59ms +step:885/1670 train_time:86365ms step_avg:97.59ms +step:886/1670 train_time:86464ms step_avg:97.59ms +step:887/1670 train_time:86563ms step_avg:97.59ms +step:888/1670 train_time:86664ms step_avg:97.59ms +step:889/1670 train_time:86762ms step_avg:97.60ms +step:890/1670 train_time:86859ms step_avg:97.59ms +step:891/1670 train_time:86955ms step_avg:97.59ms +step:892/1670 train_time:87051ms step_avg:97.59ms +step:893/1670 train_time:87147ms step_avg:97.59ms +step:894/1670 train_time:87243ms step_avg:97.59ms +step:895/1670 train_time:87340ms step_avg:97.59ms +step:896/1670 train_time:87438ms step_avg:97.59ms +step:897/1670 train_time:87538ms step_avg:97.59ms +step:898/1670 train_time:87637ms step_avg:97.59ms +step:899/1670 train_time:87735ms step_avg:97.59ms +step:900/1670 train_time:87832ms step_avg:97.59ms +step:901/1670 train_time:87928ms step_avg:97.59ms +step:902/1670 train_time:88024ms step_avg:97.59ms +step:903/1670 train_time:88120ms step_avg:97.59ms +step:904/1670 train_time:88216ms step_avg:97.58ms +step:905/1670 train_time:88313ms 
step_avg:97.58ms +step:906/1670 train_time:88409ms step_avg:97.58ms +step:907/1670 train_time:88507ms step_avg:97.58ms +step:908/1670 train_time:88605ms step_avg:97.58ms +step:909/1670 train_time:88702ms step_avg:97.58ms +step:910/1670 train_time:88799ms step_avg:97.58ms +step:911/1670 train_time:88898ms step_avg:97.58ms +step:912/1670 train_time:88995ms step_avg:97.58ms +step:913/1670 train_time:89092ms step_avg:97.58ms +step:914/1670 train_time:89188ms step_avg:97.58ms +step:915/1670 train_time:89284ms step_avg:97.58ms +step:916/1670 train_time:89380ms step_avg:97.58ms +step:917/1670 train_time:89479ms step_avg:97.58ms +step:918/1670 train_time:89577ms step_avg:97.58ms +step:919/1670 train_time:89674ms step_avg:97.58ms +step:920/1670 train_time:89771ms step_avg:97.58ms +step:921/1670 train_time:89869ms step_avg:97.58ms +step:922/1670 train_time:89966ms step_avg:97.58ms +step:923/1670 train_time:90063ms step_avg:97.58ms +step:924/1670 train_time:90160ms step_avg:97.58ms +step:925/1670 train_time:90258ms step_avg:97.58ms +step:926/1670 train_time:90354ms step_avg:97.57ms +step:927/1670 train_time:90452ms step_avg:97.58ms +step:928/1670 train_time:90549ms step_avg:97.57ms +step:929/1670 train_time:90647ms step_avg:97.57ms +step:930/1670 train_time:90745ms step_avg:97.58ms +step:931/1670 train_time:90842ms step_avg:97.57ms +step:932/1670 train_time:90940ms step_avg:97.58ms +step:933/1670 train_time:91038ms step_avg:97.58ms +step:934/1670 train_time:91135ms step_avg:97.58ms +step:935/1670 train_time:91232ms step_avg:97.57ms +step:936/1670 train_time:91328ms step_avg:97.57ms +step:937/1670 train_time:91425ms step_avg:97.57ms +step:938/1670 train_time:91522ms step_avg:97.57ms +step:939/1670 train_time:91620ms step_avg:97.57ms +step:940/1670 train_time:91717ms step_avg:97.57ms +step:941/1670 train_time:91815ms step_avg:97.57ms +step:942/1670 train_time:91912ms step_avg:97.57ms +step:943/1670 train_time:92009ms step_avg:97.57ms +step:944/1670 train_time:92106ms step_avg:97.57ms +step:945/1670 train_time:92202ms step_avg:97.57ms +step:946/1670 train_time:92300ms step_avg:97.57ms +step:947/1670 train_time:92398ms step_avg:97.57ms +step:948/1670 train_time:92495ms step_avg:97.57ms +step:949/1670 train_time:92593ms step_avg:97.57ms +step:950/1670 train_time:92689ms step_avg:97.57ms +step:951/1670 train_time:92786ms step_avg:97.57ms +step:952/1670 train_time:92883ms step_avg:97.57ms +step:953/1670 train_time:92981ms step_avg:97.57ms +step:954/1670 train_time:93078ms step_avg:97.57ms +step:955/1670 train_time:93175ms step_avg:97.56ms +step:956/1670 train_time:93271ms step_avg:97.56ms +step:957/1670 train_time:93368ms step_avg:97.56ms +step:958/1670 train_time:93465ms step_avg:97.56ms +step:959/1670 train_time:93563ms step_avg:97.56ms +step:960/1670 train_time:93660ms step_avg:97.56ms +step:961/1670 train_time:93759ms step_avg:97.56ms +step:962/1670 train_time:93856ms step_avg:97.56ms +step:963/1670 train_time:93953ms step_avg:97.56ms +step:964/1670 train_time:94049ms step_avg:97.56ms +step:965/1670 train_time:94146ms step_avg:97.56ms +step:966/1670 train_time:94243ms step_avg:97.56ms +step:967/1670 train_time:94339ms step_avg:97.56ms +step:968/1670 train_time:94437ms step_avg:97.56ms +step:969/1670 train_time:94534ms step_avg:97.56ms +step:970/1670 train_time:94631ms step_avg:97.56ms +step:971/1670 train_time:94729ms step_avg:97.56ms +step:972/1670 train_time:94825ms step_avg:97.56ms +step:973/1670 train_time:94922ms step_avg:97.56ms +step:974/1670 train_time:95021ms step_avg:97.56ms +step:975/1670 
train_time:95118ms step_avg:97.56ms +step:976/1670 train_time:95215ms step_avg:97.56ms +step:977/1670 train_time:95312ms step_avg:97.56ms +step:978/1670 train_time:95408ms step_avg:97.55ms +step:979/1670 train_time:95506ms step_avg:97.55ms +step:980/1670 train_time:95603ms step_avg:97.55ms +step:981/1670 train_time:95700ms step_avg:97.55ms +step:982/1670 train_time:95798ms step_avg:97.55ms +step:983/1670 train_time:95896ms step_avg:97.55ms +step:984/1670 train_time:95993ms step_avg:97.55ms +step:985/1670 train_time:96089ms step_avg:97.55ms +step:986/1670 train_time:96186ms step_avg:97.55ms +step:987/1670 train_time:96285ms step_avg:97.55ms +step:988/1670 train_time:96382ms step_avg:97.55ms +step:989/1670 train_time:96480ms step_avg:97.55ms +step:990/1670 train_time:96577ms step_avg:97.55ms +step:991/1670 train_time:96674ms step_avg:97.55ms +step:992/1670 train_time:96771ms step_avg:97.55ms +step:993/1670 train_time:96869ms step_avg:97.55ms +step:994/1670 train_time:96966ms step_avg:97.55ms +step:995/1670 train_time:97063ms step_avg:97.55ms +step:996/1670 train_time:97160ms step_avg:97.55ms +step:997/1670 train_time:97257ms step_avg:97.55ms +step:998/1670 train_time:97354ms step_avg:97.55ms +step:999/1670 train_time:97453ms step_avg:97.55ms +step:1000/1670 train_time:97550ms step_avg:97.55ms +step:1000/1670 val_loss:3.4775 train_time:97645ms step_avg:97.65ms +step:1001/1670 train_time:97667ms step_avg:97.57ms +step:1002/1670 train_time:97751ms step_avg:97.56ms +step:1003/1670 train_time:97852ms step_avg:97.56ms +step:1004/1670 train_time:97949ms step_avg:97.56ms +step:1005/1670 train_time:98046ms step_avg:97.56ms +step:1006/1670 train_time:98142ms step_avg:97.56ms +step:1007/1670 train_time:98238ms step_avg:97.56ms +step:1008/1670 train_time:98334ms step_avg:97.55ms +step:1009/1670 train_time:98430ms step_avg:97.55ms +step:1010/1670 train_time:98525ms step_avg:97.55ms +step:1011/1670 train_time:98623ms step_avg:97.55ms +step:1012/1670 train_time:98721ms step_avg:97.55ms +step:1013/1670 train_time:98820ms step_avg:97.55ms +step:1014/1670 train_time:98917ms step_avg:97.55ms +step:1015/1670 train_time:99015ms step_avg:97.55ms +step:1016/1670 train_time:99113ms step_avg:97.55ms +step:1017/1670 train_time:99209ms step_avg:97.55ms +step:1018/1670 train_time:99305ms step_avg:97.55ms +step:1019/1670 train_time:99402ms step_avg:97.55ms +step:1020/1670 train_time:99498ms step_avg:97.55ms +step:1021/1670 train_time:99595ms step_avg:97.55ms +step:1022/1670 train_time:99692ms step_avg:97.55ms +step:1023/1670 train_time:99790ms step_avg:97.55ms +step:1024/1670 train_time:99887ms step_avg:97.55ms +step:1025/1670 train_time:99985ms step_avg:97.55ms +step:1026/1670 train_time:100084ms step_avg:97.55ms +step:1027/1670 train_time:100181ms step_avg:97.55ms +step:1028/1670 train_time:100278ms step_avg:97.55ms +step:1029/1670 train_time:100375ms step_avg:97.55ms +step:1030/1670 train_time:100471ms step_avg:97.54ms +step:1031/1670 train_time:100568ms step_avg:97.54ms +step:1032/1670 train_time:100664ms step_avg:97.54ms +step:1033/1670 train_time:100761ms step_avg:97.54ms +step:1034/1670 train_time:100859ms step_avg:97.54ms +step:1035/1670 train_time:100957ms step_avg:97.54ms +step:1036/1670 train_time:101054ms step_avg:97.54ms +step:1037/1670 train_time:101152ms step_avg:97.54ms +step:1038/1670 train_time:101249ms step_avg:97.54ms +step:1039/1670 train_time:101345ms step_avg:97.54ms +step:1040/1670 train_time:101442ms step_avg:97.54ms +step:1041/1670 train_time:101539ms step_avg:97.54ms +step:1042/1670 
train_time:101636ms step_avg:97.54ms +step:1043/1670 train_time:101734ms step_avg:97.54ms +step:1044/1670 train_time:101831ms step_avg:97.54ms +step:1045/1670 train_time:101927ms step_avg:97.54ms +step:1046/1670 train_time:102024ms step_avg:97.54ms +step:1047/1670 train_time:102122ms step_avg:97.54ms +step:1048/1670 train_time:102219ms step_avg:97.54ms +step:1049/1670 train_time:102317ms step_avg:97.54ms +step:1050/1670 train_time:102415ms step_avg:97.54ms +step:1051/1670 train_time:102512ms step_avg:97.54ms +step:1052/1670 train_time:102609ms step_avg:97.54ms +step:1053/1670 train_time:102705ms step_avg:97.54ms +step:1054/1670 train_time:102803ms step_avg:97.54ms +step:1055/1670 train_time:102900ms step_avg:97.54ms +step:1056/1670 train_time:102997ms step_avg:97.54ms +step:1057/1670 train_time:103095ms step_avg:97.54ms +step:1058/1670 train_time:103192ms step_avg:97.53ms +step:1059/1670 train_time:103288ms step_avg:97.53ms +step:1060/1670 train_time:103385ms step_avg:97.53ms +step:1061/1670 train_time:103482ms step_avg:97.53ms +step:1062/1670 train_time:103761ms step_avg:97.70ms +step:1063/1670 train_time:103848ms step_avg:97.69ms +step:1064/1670 train_time:103943ms step_avg:97.69ms +step:1065/1670 train_time:104039ms step_avg:97.69ms +step:1066/1670 train_time:104135ms step_avg:97.69ms +step:1067/1670 train_time:104231ms step_avg:97.69ms +step:1068/1670 train_time:104326ms step_avg:97.68ms +step:1069/1670 train_time:104423ms step_avg:97.68ms +step:1070/1670 train_time:104519ms step_avg:97.68ms +step:1071/1670 train_time:104615ms step_avg:97.68ms +step:1072/1670 train_time:104716ms step_avg:97.68ms +step:1073/1670 train_time:104817ms step_avg:97.69ms +step:1074/1670 train_time:104917ms step_avg:97.69ms +step:1075/1670 train_time:105014ms step_avg:97.69ms +step:1076/1670 train_time:105110ms step_avg:97.69ms +step:1077/1670 train_time:105206ms step_avg:97.68ms +step:1078/1670 train_time:105302ms step_avg:97.68ms +step:1079/1670 train_time:105398ms step_avg:97.68ms +step:1080/1670 train_time:105494ms step_avg:97.68ms +step:1081/1670 train_time:105590ms step_avg:97.68ms +step:1082/1670 train_time:105688ms step_avg:97.68ms +step:1083/1670 train_time:105786ms step_avg:97.68ms +step:1084/1670 train_time:105884ms step_avg:97.68ms +step:1085/1670 train_time:105983ms step_avg:97.68ms +step:1086/1670 train_time:106080ms step_avg:97.68ms +step:1087/1670 train_time:106178ms step_avg:97.68ms +step:1088/1670 train_time:106275ms step_avg:97.68ms +step:1089/1670 train_time:106371ms step_avg:97.68ms +step:1090/1670 train_time:106467ms step_avg:97.68ms +step:1091/1670 train_time:106563ms step_avg:97.67ms +step:1092/1670 train_time:106660ms step_avg:97.67ms +step:1093/1670 train_time:106758ms step_avg:97.67ms +step:1094/1670 train_time:106857ms step_avg:97.68ms +step:1095/1670 train_time:106957ms step_avg:97.68ms +step:1096/1670 train_time:107055ms step_avg:97.68ms +step:1097/1670 train_time:107152ms step_avg:97.68ms +step:1098/1670 train_time:107249ms step_avg:97.68ms +step:1099/1670 train_time:107345ms step_avg:97.68ms +step:1100/1670 train_time:107442ms step_avg:97.67ms +step:1101/1670 train_time:107539ms step_avg:97.67ms +step:1102/1670 train_time:107635ms step_avg:97.67ms +step:1103/1670 train_time:107732ms step_avg:97.67ms +step:1104/1670 train_time:107829ms step_avg:97.67ms +step:1105/1670 train_time:107927ms step_avg:97.67ms +step:1106/1670 train_time:108025ms step_avg:97.67ms +step:1107/1670 train_time:108122ms step_avg:97.67ms +step:1108/1670 train_time:108220ms step_avg:97.67ms +step:1109/1670 
train_time:108317ms step_avg:97.67ms +step:1110/1670 train_time:108415ms step_avg:97.67ms +step:1111/1670 train_time:108512ms step_avg:97.67ms +step:1112/1670 train_time:108608ms step_avg:97.67ms +step:1113/1670 train_time:108704ms step_avg:97.67ms +step:1114/1670 train_time:108801ms step_avg:97.67ms +step:1115/1670 train_time:108898ms step_avg:97.67ms +step:1116/1670 train_time:108997ms step_avg:97.67ms +step:1117/1670 train_time:109095ms step_avg:97.67ms +step:1118/1670 train_time:109194ms step_avg:97.67ms +step:1119/1670 train_time:109291ms step_avg:97.67ms +step:1120/1670 train_time:109389ms step_avg:97.67ms +step:1121/1670 train_time:109486ms step_avg:97.67ms +step:1122/1670 train_time:109583ms step_avg:97.67ms +step:1123/1670 train_time:109680ms step_avg:97.67ms +step:1124/1670 train_time:109778ms step_avg:97.67ms +step:1125/1670 train_time:109875ms step_avg:97.67ms +step:1125/1670 val_loss:3.4241 train_time:109973ms step_avg:97.75ms +step:1126/1670 train_time:109994ms step_avg:97.69ms +step:1127/1670 train_time:110086ms step_avg:97.68ms +step:1128/1670 train_time:110188ms step_avg:97.68ms +step:1129/1670 train_time:110285ms step_avg:97.68ms +step:1130/1670 train_time:110382ms step_avg:97.68ms +step:1131/1670 train_time:110479ms step_avg:97.68ms +step:1132/1670 train_time:110576ms step_avg:97.68ms +step:1133/1670 train_time:110673ms step_avg:97.68ms +step:1134/1670 train_time:110769ms step_avg:97.68ms +step:1135/1670 train_time:110865ms step_avg:97.68ms +step:1136/1670 train_time:110965ms step_avg:97.68ms +step:1137/1670 train_time:111067ms step_avg:97.68ms +step:1138/1670 train_time:111166ms step_avg:97.69ms +step:1139/1670 train_time:111265ms step_avg:97.69ms +step:1140/1670 train_time:111363ms step_avg:97.69ms +step:1141/1670 train_time:111461ms step_avg:97.69ms +step:1142/1670 train_time:111558ms step_avg:97.69ms +step:1143/1670 train_time:111655ms step_avg:97.69ms +step:1144/1670 train_time:111752ms step_avg:97.68ms +step:1145/1670 train_time:111848ms step_avg:97.68ms +step:1146/1670 train_time:111945ms step_avg:97.68ms +step:1147/1670 train_time:112043ms step_avg:97.68ms +step:1148/1670 train_time:112142ms step_avg:97.68ms +step:1149/1670 train_time:112240ms step_avg:97.69ms +step:1150/1670 train_time:112339ms step_avg:97.69ms +step:1151/1670 train_time:112437ms step_avg:97.69ms +step:1152/1670 train_time:112534ms step_avg:97.69ms +step:1153/1670 train_time:112631ms step_avg:97.69ms +step:1154/1670 train_time:112728ms step_avg:97.68ms +step:1155/1670 train_time:112825ms step_avg:97.68ms +step:1156/1670 train_time:112923ms step_avg:97.68ms +step:1157/1670 train_time:113020ms step_avg:97.68ms +step:1158/1670 train_time:113119ms step_avg:97.68ms +step:1159/1670 train_time:113217ms step_avg:97.68ms +step:1160/1670 train_time:113315ms step_avg:97.68ms +step:1161/1670 train_time:113413ms step_avg:97.69ms +step:1162/1670 train_time:113511ms step_avg:97.69ms +step:1163/1670 train_time:113609ms step_avg:97.69ms +step:1164/1670 train_time:113705ms step_avg:97.68ms +step:1165/1670 train_time:113802ms step_avg:97.68ms +step:1166/1670 train_time:113900ms step_avg:97.68ms +step:1167/1670 train_time:113998ms step_avg:97.68ms +step:1168/1670 train_time:114095ms step_avg:97.68ms +step:1169/1670 train_time:114194ms step_avg:97.69ms +step:1170/1670 train_time:114292ms step_avg:97.69ms +step:1171/1670 train_time:114392ms step_avg:97.69ms +step:1172/1670 train_time:114492ms step_avg:97.69ms +step:1173/1670 train_time:114590ms step_avg:97.69ms +step:1174/1670 train_time:114688ms step_avg:97.69ms 
+step:1175/1670 train_time:114786ms step_avg:97.69ms +step:1176/1670 train_time:114883ms step_avg:97.69ms +step:1177/1670 train_time:114980ms step_avg:97.69ms +step:1178/1670 train_time:115077ms step_avg:97.69ms +step:1179/1670 train_time:115175ms step_avg:97.69ms +step:1180/1670 train_time:115273ms step_avg:97.69ms +step:1181/1670 train_time:115371ms step_avg:97.69ms +step:1182/1670 train_time:115470ms step_avg:97.69ms +step:1183/1670 train_time:115569ms step_avg:97.69ms +step:1184/1670 train_time:115668ms step_avg:97.69ms +step:1185/1670 train_time:115766ms step_avg:97.69ms +step:1186/1670 train_time:115865ms step_avg:97.69ms +step:1187/1670 train_time:115962ms step_avg:97.69ms +step:1188/1670 train_time:116059ms step_avg:97.69ms +step:1189/1670 train_time:116157ms step_avg:97.69ms +step:1190/1670 train_time:116254ms step_avg:97.69ms +step:1191/1670 train_time:116352ms step_avg:97.69ms +step:1192/1670 train_time:116451ms step_avg:97.69ms +step:1193/1670 train_time:116549ms step_avg:97.69ms +step:1194/1670 train_time:116647ms step_avg:97.69ms +step:1195/1670 train_time:116744ms step_avg:97.69ms +step:1196/1670 train_time:116842ms step_avg:97.69ms +step:1197/1670 train_time:116939ms step_avg:97.69ms +step:1198/1670 train_time:117036ms step_avg:97.69ms +step:1199/1670 train_time:117134ms step_avg:97.69ms +step:1200/1670 train_time:117231ms step_avg:97.69ms +step:1201/1670 train_time:117329ms step_avg:97.69ms +step:1202/1670 train_time:117429ms step_avg:97.69ms +step:1203/1670 train_time:117526ms step_avg:97.69ms +step:1204/1670 train_time:117623ms step_avg:97.69ms +step:1205/1670 train_time:117721ms step_avg:97.69ms +step:1206/1670 train_time:117819ms step_avg:97.69ms +step:1207/1670 train_time:117917ms step_avg:97.69ms +step:1208/1670 train_time:118015ms step_avg:97.69ms +step:1209/1670 train_time:118112ms step_avg:97.69ms +step:1210/1670 train_time:118210ms step_avg:97.69ms +step:1211/1670 train_time:118307ms step_avg:97.69ms +step:1212/1670 train_time:118405ms step_avg:97.69ms +step:1213/1670 train_time:118503ms step_avg:97.69ms +step:1214/1670 train_time:118600ms step_avg:97.69ms +step:1215/1670 train_time:118697ms step_avg:97.69ms +step:1216/1670 train_time:118795ms step_avg:97.69ms +step:1217/1670 train_time:118893ms step_avg:97.69ms +step:1218/1670 train_time:118992ms step_avg:97.69ms +step:1219/1670 train_time:119089ms step_avg:97.69ms +step:1220/1670 train_time:119186ms step_avg:97.69ms +step:1221/1670 train_time:119284ms step_avg:97.69ms +step:1222/1670 train_time:119382ms step_avg:97.69ms +step:1223/1670 train_time:119479ms step_avg:97.69ms +step:1224/1670 train_time:119576ms step_avg:97.69ms +step:1225/1670 train_time:119674ms step_avg:97.69ms +step:1226/1670 train_time:119772ms step_avg:97.69ms +step:1227/1670 train_time:119871ms step_avg:97.69ms +step:1228/1670 train_time:119972ms step_avg:97.70ms +step:1229/1670 train_time:120071ms step_avg:97.70ms +step:1230/1670 train_time:120169ms step_avg:97.70ms +step:1231/1670 train_time:120267ms step_avg:97.70ms +step:1232/1670 train_time:120365ms step_avg:97.70ms +step:1233/1670 train_time:120463ms step_avg:97.70ms +step:1234/1670 train_time:120560ms step_avg:97.70ms +step:1235/1670 train_time:120657ms step_avg:97.70ms +step:1236/1670 train_time:120755ms step_avg:97.70ms +step:1237/1670 train_time:120853ms step_avg:97.70ms +step:1238/1670 train_time:120949ms step_avg:97.70ms +step:1239/1670 train_time:121046ms step_avg:97.70ms +step:1240/1670 train_time:121144ms step_avg:97.70ms +step:1241/1670 train_time:121241ms step_avg:97.70ms 
+step:1242/1670 train_time:121339ms step_avg:97.70ms +step:1243/1670 train_time:121437ms step_avg:97.70ms +step:1244/1670 train_time:121535ms step_avg:97.70ms +step:1245/1670 train_time:121633ms step_avg:97.70ms +step:1246/1670 train_time:121733ms step_avg:97.70ms +step:1247/1670 train_time:121831ms step_avg:97.70ms +step:1248/1670 train_time:121929ms step_avg:97.70ms +step:1249/1670 train_time:122028ms step_avg:97.70ms +step:1250/1670 train_time:122126ms step_avg:97.70ms +step:1250/1670 val_loss:3.3814 train_time:122222ms step_avg:97.78ms +step:1251/1670 train_time:122244ms step_avg:97.72ms +step:1252/1670 train_time:122329ms step_avg:97.71ms +step:1253/1670 train_time:122428ms step_avg:97.71ms +step:1254/1670 train_time:122526ms step_avg:97.71ms +step:1255/1670 train_time:122623ms step_avg:97.71ms +step:1256/1670 train_time:122720ms step_avg:97.71ms +step:1257/1670 train_time:122816ms step_avg:97.71ms +step:1258/1670 train_time:122912ms step_avg:97.70ms +step:1259/1670 train_time:123008ms step_avg:97.70ms +step:1260/1670 train_time:123104ms step_avg:97.70ms +step:1261/1670 train_time:123202ms step_avg:97.70ms +step:1262/1670 train_time:123303ms step_avg:97.70ms +step:1263/1670 train_time:123402ms step_avg:97.71ms +step:1264/1670 train_time:123501ms step_avg:97.71ms +step:1265/1670 train_time:123600ms step_avg:97.71ms +step:1266/1670 train_time:123697ms step_avg:97.71ms +step:1267/1670 train_time:123795ms step_avg:97.71ms +step:1268/1670 train_time:123892ms step_avg:97.71ms +step:1269/1670 train_time:123989ms step_avg:97.71ms +step:1270/1670 train_time:124085ms step_avg:97.71ms +step:1271/1670 train_time:124182ms step_avg:97.70ms +step:1272/1670 train_time:124281ms step_avg:97.71ms +step:1273/1670 train_time:124382ms step_avg:97.71ms +step:1274/1670 train_time:124776ms step_avg:97.94ms +step:1275/1670 train_time:124877ms step_avg:97.94ms +step:1276/1670 train_time:124972ms step_avg:97.94ms +step:1277/1670 train_time:125069ms step_avg:97.94ms +step:1278/1670 train_time:125165ms step_avg:97.94ms +step:1279/1670 train_time:125262ms step_avg:97.94ms +step:1280/1670 train_time:125359ms step_avg:97.94ms +step:1281/1670 train_time:125456ms step_avg:97.94ms +step:1282/1670 train_time:125552ms step_avg:97.93ms +step:1283/1670 train_time:125649ms step_avg:97.93ms +step:1284/1670 train_time:125746ms step_avg:97.93ms +step:1285/1670 train_time:125847ms step_avg:97.94ms +step:1286/1670 train_time:125946ms step_avg:97.94ms +step:1287/1670 train_time:126044ms step_avg:97.94ms +step:1288/1670 train_time:126142ms step_avg:97.94ms +step:1289/1670 train_time:126240ms step_avg:97.94ms +step:1290/1670 train_time:126337ms step_avg:97.94ms +step:1291/1670 train_time:126434ms step_avg:97.93ms +step:1292/1670 train_time:126530ms step_avg:97.93ms +step:1293/1670 train_time:126627ms step_avg:97.93ms +step:1294/1670 train_time:126724ms step_avg:97.93ms +step:1295/1670 train_time:126823ms step_avg:97.93ms +step:1296/1670 train_time:126922ms step_avg:97.93ms +step:1297/1670 train_time:127021ms step_avg:97.93ms +step:1298/1670 train_time:127119ms step_avg:97.93ms +step:1299/1670 train_time:127217ms step_avg:97.93ms +step:1300/1670 train_time:127314ms step_avg:97.93ms +step:1301/1670 train_time:127411ms step_avg:97.93ms +step:1302/1670 train_time:127509ms step_avg:97.93ms +step:1303/1670 train_time:127605ms step_avg:97.93ms +step:1304/1670 train_time:127703ms step_avg:97.93ms +step:1305/1670 train_time:127800ms step_avg:97.93ms +step:1306/1670 train_time:127898ms step_avg:97.93ms +step:1307/1670 train_time:127997ms 
step_avg:97.93ms +step:1308/1670 train_time:128097ms step_avg:97.93ms +step:1309/1670 train_time:128195ms step_avg:97.93ms +step:1310/1670 train_time:128291ms step_avg:97.93ms +step:1311/1670 train_time:128388ms step_avg:97.93ms +step:1312/1670 train_time:128485ms step_avg:97.93ms +step:1313/1670 train_time:128582ms step_avg:97.93ms +step:1314/1670 train_time:128681ms step_avg:97.93ms +step:1315/1670 train_time:128778ms step_avg:97.93ms +step:1316/1670 train_time:128877ms step_avg:97.93ms +step:1317/1670 train_time:128975ms step_avg:97.93ms +step:1318/1670 train_time:129073ms step_avg:97.93ms +step:1319/1670 train_time:129171ms step_avg:97.93ms +step:1320/1670 train_time:129268ms step_avg:97.93ms +step:1321/1670 train_time:129366ms step_avg:97.93ms +step:1322/1670 train_time:129464ms step_avg:97.93ms +step:1323/1670 train_time:129562ms step_avg:97.93ms +step:1324/1670 train_time:129660ms step_avg:97.93ms +step:1325/1670 train_time:129759ms step_avg:97.93ms +step:1326/1670 train_time:129856ms step_avg:97.93ms +step:1327/1670 train_time:129954ms step_avg:97.93ms +step:1328/1670 train_time:130051ms step_avg:97.93ms +step:1329/1670 train_time:130149ms step_avg:97.93ms +step:1330/1670 train_time:130247ms step_avg:97.93ms +step:1331/1670 train_time:130345ms step_avg:97.93ms +step:1332/1670 train_time:130442ms step_avg:97.93ms +step:1333/1670 train_time:130540ms step_avg:97.93ms +step:1334/1670 train_time:130637ms step_avg:97.93ms +step:1335/1670 train_time:130736ms step_avg:97.93ms +step:1336/1670 train_time:130833ms step_avg:97.93ms +step:1337/1670 train_time:130931ms step_avg:97.93ms +step:1338/1670 train_time:131029ms step_avg:97.93ms +step:1339/1670 train_time:131129ms step_avg:97.93ms +step:1340/1670 train_time:131227ms step_avg:97.93ms +step:1341/1670 train_time:131324ms step_avg:97.93ms +step:1342/1670 train_time:131421ms step_avg:97.93ms +step:1343/1670 train_time:131518ms step_avg:97.93ms +step:1344/1670 train_time:131616ms step_avg:97.93ms +step:1345/1670 train_time:131714ms step_avg:97.93ms +step:1346/1670 train_time:131812ms step_avg:97.93ms +step:1347/1670 train_time:131909ms step_avg:97.93ms +step:1348/1670 train_time:132006ms step_avg:97.93ms +step:1349/1670 train_time:132106ms step_avg:97.93ms +step:1350/1670 train_time:132204ms step_avg:97.93ms +step:1351/1670 train_time:132301ms step_avg:97.93ms +step:1352/1670 train_time:132399ms step_avg:97.93ms +step:1353/1670 train_time:132496ms step_avg:97.93ms +step:1354/1670 train_time:132595ms step_avg:97.93ms +step:1355/1670 train_time:132693ms step_avg:97.93ms +step:1356/1670 train_time:132790ms step_avg:97.93ms +step:1357/1670 train_time:132887ms step_avg:97.93ms +step:1358/1670 train_time:132985ms step_avg:97.93ms +step:1359/1670 train_time:133082ms step_avg:97.93ms +step:1360/1670 train_time:133181ms step_avg:97.93ms +step:1361/1670 train_time:133278ms step_avg:97.93ms +step:1362/1670 train_time:133377ms step_avg:97.93ms +step:1363/1670 train_time:133475ms step_avg:97.93ms +step:1364/1670 train_time:133573ms step_avg:97.93ms +step:1365/1670 train_time:133671ms step_avg:97.93ms +step:1366/1670 train_time:133768ms step_avg:97.93ms +step:1367/1670 train_time:133866ms step_avg:97.93ms +step:1368/1670 train_time:133963ms step_avg:97.93ms +step:1369/1670 train_time:134062ms step_avg:97.93ms +step:1370/1670 train_time:134161ms step_avg:97.93ms +step:1371/1670 train_time:134259ms step_avg:97.93ms +step:1372/1670 train_time:134357ms step_avg:97.93ms +step:1373/1670 train_time:134455ms step_avg:97.93ms +step:1374/1670 train_time:134553ms 
step_avg:97.93ms +step:1375/1670 train_time:134651ms step_avg:97.93ms +step:1375/1670 val_loss:3.3443 train_time:134749ms step_avg:98.00ms +step:1376/1670 train_time:134770ms step_avg:97.94ms +step:1377/1670 train_time:134855ms step_avg:97.93ms +step:1378/1670 train_time:134959ms step_avg:97.94ms +step:1379/1670 train_time:135057ms step_avg:97.94ms +step:1380/1670 train_time:135154ms step_avg:97.94ms +step:1381/1670 train_time:135250ms step_avg:97.94ms +step:1382/1670 train_time:135347ms step_avg:97.94ms +step:1383/1670 train_time:135444ms step_avg:97.94ms +step:1384/1670 train_time:135541ms step_avg:97.93ms +step:1385/1670 train_time:135638ms step_avg:97.93ms +step:1386/1670 train_time:135737ms step_avg:97.93ms +step:1387/1670 train_time:135840ms step_avg:97.94ms +step:1388/1670 train_time:135941ms step_avg:97.94ms +step:1389/1670 train_time:136042ms step_avg:97.94ms +step:1390/1670 train_time:136140ms step_avg:97.94ms +step:1391/1670 train_time:136238ms step_avg:97.94ms +step:1392/1670 train_time:136335ms step_avg:97.94ms +step:1393/1670 train_time:136432ms step_avg:97.94ms +step:1394/1670 train_time:136529ms step_avg:97.94ms +step:1395/1670 train_time:136625ms step_avg:97.94ms +step:1396/1670 train_time:136723ms step_avg:97.94ms +step:1397/1670 train_time:136822ms step_avg:97.94ms +step:1398/1670 train_time:136922ms step_avg:97.94ms +step:1399/1670 train_time:137023ms step_avg:97.94ms +step:1400/1670 train_time:137122ms step_avg:97.94ms +step:1401/1670 train_time:137222ms step_avg:97.95ms +step:1402/1670 train_time:137320ms step_avg:97.95ms +step:1403/1670 train_time:137418ms step_avg:97.95ms +step:1404/1670 train_time:137515ms step_avg:97.94ms +step:1405/1670 train_time:137611ms step_avg:97.94ms +step:1406/1670 train_time:137708ms step_avg:97.94ms +step:1407/1670 train_time:137806ms step_avg:97.94ms +step:1408/1670 train_time:137905ms step_avg:97.94ms +step:1409/1670 train_time:138005ms step_avg:97.95ms +step:1410/1670 train_time:138106ms step_avg:97.95ms +step:1411/1670 train_time:138204ms step_avg:97.95ms +step:1412/1670 train_time:138302ms step_avg:97.95ms +step:1413/1670 train_time:138400ms step_avg:97.95ms +step:1414/1670 train_time:138497ms step_avg:97.95ms +step:1415/1670 train_time:138594ms step_avg:97.95ms +step:1416/1670 train_time:138691ms step_avg:97.95ms +step:1417/1670 train_time:138789ms step_avg:97.95ms +step:1418/1670 train_time:138886ms step_avg:97.95ms +step:1419/1670 train_time:138985ms step_avg:97.95ms +step:1420/1670 train_time:139085ms step_avg:97.95ms +step:1421/1670 train_time:139183ms step_avg:97.95ms +step:1422/1670 train_time:139282ms step_avg:97.95ms +step:1423/1670 train_time:139380ms step_avg:97.95ms +step:1424/1670 train_time:139477ms step_avg:97.95ms +step:1425/1670 train_time:139575ms step_avg:97.95ms +step:1426/1670 train_time:139673ms step_avg:97.95ms +step:1427/1670 train_time:139771ms step_avg:97.95ms +step:1428/1670 train_time:139868ms step_avg:97.95ms +step:1429/1670 train_time:139966ms step_avg:97.95ms +step:1430/1670 train_time:140065ms step_avg:97.95ms +step:1431/1670 train_time:140164ms step_avg:97.95ms +step:1432/1670 train_time:140262ms step_avg:97.95ms +step:1433/1670 train_time:140360ms step_avg:97.95ms +step:1434/1670 train_time:140458ms step_avg:97.95ms +step:1435/1670 train_time:140555ms step_avg:97.95ms +step:1436/1670 train_time:140652ms step_avg:97.95ms +step:1437/1670 train_time:140749ms step_avg:97.95ms +step:1438/1670 train_time:140846ms step_avg:97.95ms +step:1439/1670 train_time:140944ms step_avg:97.95ms +step:1440/1670 
train_time:141042ms step_avg:97.95ms +step:1441/1670 train_time:141141ms step_avg:97.95ms +step:1442/1670 train_time:141240ms step_avg:97.95ms +step:1443/1670 train_time:141336ms step_avg:97.95ms +step:1444/1670 train_time:141434ms step_avg:97.95ms +step:1445/1670 train_time:141532ms step_avg:97.95ms +step:1446/1670 train_time:141630ms step_avg:97.95ms +step:1447/1670 train_time:141727ms step_avg:97.95ms +step:1448/1670 train_time:141824ms step_avg:97.94ms +step:1449/1670 train_time:141922ms step_avg:97.94ms +step:1450/1670 train_time:142020ms step_avg:97.94ms +step:1451/1670 train_time:142117ms step_avg:97.94ms +step:1452/1670 train_time:142215ms step_avg:97.94ms +step:1453/1670 train_time:142312ms step_avg:97.94ms +step:1454/1670 train_time:142410ms step_avg:97.94ms +step:1455/1670 train_time:142509ms step_avg:97.94ms +step:1456/1670 train_time:142607ms step_avg:97.94ms +step:1457/1670 train_time:142705ms step_avg:97.94ms +step:1458/1670 train_time:142803ms step_avg:97.94ms +step:1459/1670 train_time:142902ms step_avg:97.95ms +step:1460/1670 train_time:142999ms step_avg:97.94ms +step:1461/1670 train_time:143097ms step_avg:97.94ms +step:1462/1670 train_time:143195ms step_avg:97.94ms +step:1463/1670 train_time:143292ms step_avg:97.94ms +step:1464/1670 train_time:143390ms step_avg:97.94ms +step:1465/1670 train_time:143487ms step_avg:97.94ms +step:1466/1670 train_time:143585ms step_avg:97.94ms +step:1467/1670 train_time:143684ms step_avg:97.94ms +step:1468/1670 train_time:143783ms step_avg:97.94ms +step:1469/1670 train_time:143881ms step_avg:97.94ms +step:1470/1670 train_time:143978ms step_avg:97.94ms +step:1471/1670 train_time:144076ms step_avg:97.94ms +step:1472/1670 train_time:144174ms step_avg:97.94ms +step:1473/1670 train_time:144272ms step_avg:97.94ms +step:1474/1670 train_time:144370ms step_avg:97.94ms +step:1475/1670 train_time:144467ms step_avg:97.94ms +step:1476/1670 train_time:144564ms step_avg:97.94ms +step:1477/1670 train_time:144663ms step_avg:97.94ms +step:1478/1670 train_time:144761ms step_avg:97.94ms +step:1479/1670 train_time:144859ms step_avg:97.94ms +step:1480/1670 train_time:144957ms step_avg:97.94ms +step:1481/1670 train_time:145054ms step_avg:97.94ms +step:1482/1670 train_time:145152ms step_avg:97.94ms +step:1483/1670 train_time:145250ms step_avg:97.94ms +step:1484/1670 train_time:145348ms step_avg:97.94ms +step:1485/1670 train_time:145693ms step_avg:98.11ms +step:1486/1670 train_time:145768ms step_avg:98.09ms +step:1487/1670 train_time:145865ms step_avg:98.09ms +step:1488/1670 train_time:145961ms step_avg:98.09ms +step:1489/1670 train_time:146058ms step_avg:98.09ms +step:1490/1670 train_time:146154ms step_avg:98.09ms +step:1491/1670 train_time:146251ms step_avg:98.09ms +step:1492/1670 train_time:146347ms step_avg:98.09ms +step:1493/1670 train_time:146444ms step_avg:98.09ms +step:1494/1670 train_time:146540ms step_avg:98.09ms +step:1495/1670 train_time:146643ms step_avg:98.09ms +step:1496/1670 train_time:146746ms step_avg:98.09ms +step:1497/1670 train_time:146847ms step_avg:98.09ms +step:1498/1670 train_time:146945ms step_avg:98.09ms +step:1499/1670 train_time:147042ms step_avg:98.09ms +step:1500/1670 train_time:147140ms step_avg:98.09ms +step:1500/1670 val_loss:3.3126 train_time:147236ms step_avg:98.16ms +step:1501/1670 train_time:147258ms step_avg:98.11ms +step:1502/1670 train_time:147342ms step_avg:98.10ms +step:1503/1670 train_time:147443ms step_avg:98.10ms +step:1504/1670 train_time:147541ms step_avg:98.10ms +step:1505/1670 train_time:147638ms step_avg:98.10ms 
+step:1506/1670 train_time:147735ms step_avg:98.10ms +step:1507/1670 train_time:147832ms step_avg:98.10ms +step:1508/1670 train_time:147929ms step_avg:98.10ms +step:1509/1670 train_time:148027ms step_avg:98.10ms +step:1510/1670 train_time:148124ms step_avg:98.10ms +step:1511/1670 train_time:148223ms step_avg:98.10ms +step:1512/1670 train_time:148324ms step_avg:98.10ms +step:1513/1670 train_time:148422ms step_avg:98.10ms +step:1514/1670 train_time:148521ms step_avg:98.10ms +step:1515/1670 train_time:148618ms step_avg:98.10ms +step:1516/1670 train_time:148715ms step_avg:98.10ms +step:1517/1670 train_time:148811ms step_avg:98.10ms +step:1518/1670 train_time:148908ms step_avg:98.09ms +step:1519/1670 train_time:149006ms step_avg:98.09ms +step:1520/1670 train_time:149104ms step_avg:98.09ms +step:1521/1670 train_time:149202ms step_avg:98.09ms +step:1522/1670 train_time:149301ms step_avg:98.10ms +step:1523/1670 train_time:149399ms step_avg:98.10ms +step:1524/1670 train_time:149498ms step_avg:98.10ms +step:1525/1670 train_time:149596ms step_avg:98.10ms +step:1526/1670 train_time:149693ms step_avg:98.09ms +step:1527/1670 train_time:149790ms step_avg:98.09ms +step:1528/1670 train_time:149888ms step_avg:98.09ms +step:1529/1670 train_time:149985ms step_avg:98.09ms +step:1530/1670 train_time:150083ms step_avg:98.09ms +step:1531/1670 train_time:150181ms step_avg:98.09ms +step:1532/1670 train_time:150279ms step_avg:98.09ms +step:1533/1670 train_time:150378ms step_avg:98.09ms +step:1534/1670 train_time:150477ms step_avg:98.09ms +step:1535/1670 train_time:150575ms step_avg:98.09ms +step:1536/1670 train_time:150673ms step_avg:98.09ms +step:1537/1670 train_time:150770ms step_avg:98.09ms +step:1538/1670 train_time:150868ms step_avg:98.09ms +step:1539/1670 train_time:150966ms step_avg:98.09ms +step:1540/1670 train_time:151063ms step_avg:98.09ms +step:1541/1670 train_time:151160ms step_avg:98.09ms +step:1542/1670 train_time:151258ms step_avg:98.09ms +step:1543/1670 train_time:151356ms step_avg:98.09ms +step:1544/1670 train_time:151454ms step_avg:98.09ms +step:1545/1670 train_time:151552ms step_avg:98.09ms +step:1546/1670 train_time:151650ms step_avg:98.09ms +step:1547/1670 train_time:151748ms step_avg:98.09ms +step:1548/1670 train_time:151847ms step_avg:98.09ms +step:1549/1670 train_time:151945ms step_avg:98.09ms +step:1550/1670 train_time:152043ms step_avg:98.09ms +step:1551/1670 train_time:152141ms step_avg:98.09ms +step:1552/1670 train_time:152240ms step_avg:98.09ms +step:1553/1670 train_time:152339ms step_avg:98.09ms +step:1554/1670 train_time:152437ms step_avg:98.09ms +step:1555/1670 train_time:152535ms step_avg:98.09ms +step:1556/1670 train_time:152633ms step_avg:98.09ms +step:1557/1670 train_time:152731ms step_avg:98.09ms +step:1558/1670 train_time:152829ms step_avg:98.09ms +step:1559/1670 train_time:152926ms step_avg:98.09ms +step:1560/1670 train_time:153024ms step_avg:98.09ms +step:1561/1670 train_time:153122ms step_avg:98.09ms +step:1562/1670 train_time:153220ms step_avg:98.09ms +step:1563/1670 train_time:153318ms step_avg:98.09ms +step:1564/1670 train_time:153415ms step_avg:98.09ms +step:1565/1670 train_time:153513ms step_avg:98.09ms +step:1566/1670 train_time:153611ms step_avg:98.09ms +step:1567/1670 train_time:153709ms step_avg:98.09ms +step:1568/1670 train_time:153808ms step_avg:98.09ms +step:1569/1670 train_time:153905ms step_avg:98.09ms +step:1570/1670 train_time:154003ms step_avg:98.09ms +step:1571/1670 train_time:154100ms step_avg:98.09ms +step:1572/1670 train_time:154197ms step_avg:98.09ms 
+step:1573/1670 train_time:154295ms step_avg:98.09ms +step:1574/1670 train_time:154392ms step_avg:98.09ms +step:1575/1670 train_time:154491ms step_avg:98.09ms +step:1576/1670 train_time:154590ms step_avg:98.09ms +step:1577/1670 train_time:154688ms step_avg:98.09ms +step:1578/1670 train_time:154785ms step_avg:98.09ms +step:1579/1670 train_time:154883ms step_avg:98.09ms +step:1580/1670 train_time:154980ms step_avg:98.09ms +step:1581/1670 train_time:155078ms step_avg:98.09ms +step:1582/1670 train_time:155176ms step_avg:98.09ms +step:1583/1670 train_time:155274ms step_avg:98.09ms +step:1584/1670 train_time:155372ms step_avg:98.09ms +step:1585/1670 train_time:155470ms step_avg:98.09ms +step:1586/1670 train_time:155569ms step_avg:98.09ms +step:1587/1670 train_time:155668ms step_avg:98.09ms +step:1588/1670 train_time:155766ms step_avg:98.09ms +step:1589/1670 train_time:155863ms step_avg:98.09ms +step:1590/1670 train_time:155961ms step_avg:98.09ms +step:1591/1670 train_time:156059ms step_avg:98.09ms +step:1592/1670 train_time:156157ms step_avg:98.09ms +step:1593/1670 train_time:156254ms step_avg:98.09ms +step:1594/1670 train_time:156352ms step_avg:98.09ms +step:1595/1670 train_time:156450ms step_avg:98.09ms +step:1596/1670 train_time:156547ms step_avg:98.09ms +step:1597/1670 train_time:156645ms step_avg:98.09ms +step:1598/1670 train_time:156742ms step_avg:98.09ms +step:1599/1670 train_time:156839ms step_avg:98.09ms +step:1600/1670 train_time:156937ms step_avg:98.09ms +step:1601/1670 train_time:157035ms step_avg:98.09ms +step:1602/1670 train_time:157133ms step_avg:98.09ms +step:1603/1670 train_time:157230ms step_avg:98.08ms +step:1604/1670 train_time:157329ms step_avg:98.09ms +step:1605/1670 train_time:157427ms step_avg:98.09ms +step:1606/1670 train_time:157524ms step_avg:98.08ms +step:1607/1670 train_time:157622ms step_avg:98.08ms +step:1608/1670 train_time:157719ms step_avg:98.08ms +step:1609/1670 train_time:157817ms step_avg:98.08ms +step:1610/1670 train_time:157915ms step_avg:98.08ms +step:1611/1670 train_time:158013ms step_avg:98.08ms +step:1612/1670 train_time:158110ms step_avg:98.08ms +step:1613/1670 train_time:158208ms step_avg:98.08ms +step:1614/1670 train_time:158307ms step_avg:98.08ms +step:1615/1670 train_time:158405ms step_avg:98.08ms +step:1616/1670 train_time:158502ms step_avg:98.08ms +step:1617/1670 train_time:158599ms step_avg:98.08ms +step:1618/1670 train_time:158697ms step_avg:98.08ms +step:1619/1670 train_time:158795ms step_avg:98.08ms +step:1620/1670 train_time:158893ms step_avg:98.08ms +step:1621/1670 train_time:158993ms step_avg:98.08ms +step:1622/1670 train_time:159092ms step_avg:98.08ms +step:1623/1670 train_time:159189ms step_avg:98.08ms +step:1624/1670 train_time:159287ms step_avg:98.08ms +step:1625/1670 train_time:159386ms step_avg:98.08ms +step:1625/1670 val_loss:3.2856 train_time:159483ms step_avg:98.14ms +step:1626/1670 train_time:159507ms step_avg:98.10ms +step:1627/1670 train_time:159588ms step_avg:98.09ms +step:1628/1670 train_time:159688ms step_avg:98.09ms +step:1629/1670 train_time:159786ms step_avg:98.09ms +step:1630/1670 train_time:159883ms step_avg:98.09ms +step:1631/1670 train_time:159980ms step_avg:98.09ms +step:1632/1670 train_time:160077ms step_avg:98.09ms +step:1633/1670 train_time:160175ms step_avg:98.09ms +step:1634/1670 train_time:160272ms step_avg:98.09ms +step:1635/1670 train_time:160368ms step_avg:98.08ms +step:1636/1670 train_time:160467ms step_avg:98.08ms +step:1637/1670 train_time:160567ms step_avg:98.09ms +step:1638/1670 train_time:160668ms 
step_avg:98.09ms +step:1639/1670 train_time:160767ms step_avg:98.09ms +step:1640/1670 train_time:160864ms step_avg:98.09ms +step:1641/1670 train_time:160962ms step_avg:98.09ms +step:1642/1670 train_time:161058ms step_avg:98.09ms +step:1643/1670 train_time:161156ms step_avg:98.09ms +step:1644/1670 train_time:161253ms step_avg:98.09ms +step:1645/1670 train_time:161350ms step_avg:98.09ms +step:1646/1670 train_time:161449ms step_avg:98.09ms +step:1647/1670 train_time:161548ms step_avg:98.09ms +step:1648/1670 train_time:161647ms step_avg:98.09ms +step:1649/1670 train_time:161746ms step_avg:98.09ms +step:1650/1670 train_time:161843ms step_avg:98.09ms +step:1651/1670 train_time:161940ms step_avg:98.09ms +step:1652/1670 train_time:162038ms step_avg:98.09ms +step:1653/1670 train_time:162136ms step_avg:98.09ms +step:1654/1670 train_time:162234ms step_avg:98.09ms +step:1655/1670 train_time:162332ms step_avg:98.09ms +step:1656/1670 train_time:162429ms step_avg:98.09ms +step:1657/1670 train_time:162528ms step_avg:98.09ms +step:1658/1670 train_time:162626ms step_avg:98.09ms +step:1659/1670 train_time:162725ms step_avg:98.09ms +step:1660/1670 train_time:162822ms step_avg:98.09ms +step:1661/1670 train_time:162920ms step_avg:98.09ms +step:1662/1670 train_time:163018ms step_avg:98.09ms +step:1663/1670 train_time:163117ms step_avg:98.09ms +step:1664/1670 train_time:163215ms step_avg:98.09ms +step:1665/1670 train_time:163312ms step_avg:98.09ms +step:1666/1670 train_time:163410ms step_avg:98.08ms +step:1667/1670 train_time:163507ms step_avg:98.08ms +step:1668/1670 train_time:163605ms step_avg:98.08ms +step:1669/1670 train_time:163703ms step_avg:98.08ms +step:1670/1670 train_time:163800ms step_avg:98.08ms +step:1670/1670 val_loss:3.2778 train_time:163897ms step_avg:98.14ms +peak memory allocated: 34217 MiB reserved: 49676 MiB diff --git a/records/090325_FA3/d5d05889-69c7-4887-ac9b-baaae1a5f499.txt b/records/090325_FA3/d5d05889-69c7-4887-ac9b-baaae1a5f499.txt new file mode 100644 index 000000000..f800a7966 --- /dev/null +++ b/records/090325_FA3/d5d05889-69c7-4887-ac9b-baaae1a5f499.txt @@ -0,0 +1,2814 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + 
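# scale_a and scale_b act as dequantization factors: torch._scaled_mm computes + # (x_f8 * x_s) @ (w_f8 * w_s).T, undoing the div() above, so the bf16 output + # approximates x @ w.T up to fp8 rounding error. +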
scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, 
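# C = A @ A.T is symmetric, so the kernel computes blocks on only one side of + # the diagonal (LOWER_UPPER picks the side) and stores each finished block + # twice, mirrored across the diagonal, roughly halving the matmul work. +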
a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we 
use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral 
norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
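+ # Distilled, each rank applies to its shard of 2D parameters: + # buf <- momentum * buf + (1 - momentum) * grad + # g <- (1 - momentum) * grad + momentum * buf # Nesterov-style blend + # p <- (1 - lr * weight_decay) * p - eff_lr * NewtonSchulz(g) + # where eff_lr = lr * max(1, rows / cols) ** 0.5 (times any per-param lr_mul), + # so tall matrices take a proportionally larger step after orthogonalization.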
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
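# note that both Adam moments are allocated only for this rank's slice, + # i.e. optimizer state is sharded world_size ways, ZeRO-style. +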
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 19:35:35 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 35C P0 118W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 30C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 35C P0 125W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 30C P0 115W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 29C P0 114W / 700W | 1520MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 47396 C /usr/bin/python 0MiB | +| 0 N/A N/A 47397 C /usr/bin/python 0MiB | +| 0 N/A N/A 47398 C /usr/bin/python 0MiB | +| 0 N/A N/A 47399 C /usr/bin/python 0MiB | +| 0 N/A N/A 47400 C /usr/bin/python 0MiB | +| 0 N/A N/A 47401 C /usr/bin/python 0MiB | +| 0 N/A N/A 47402 C /usr/bin/python 0MiB | +| 0 N/A N/A 47403 C /usr/bin/python 0MiB | +| 1 N/A N/A 47397 C /usr/bin/python 0MiB | +| 2 N/A N/A 47398 C /usr/bin/python 0MiB | +| 3 N/A N/A 47399 C /usr/bin/python 0MiB | +| 4 N/A N/A 47400 C /usr/bin/python 0MiB | +| 5 N/A N/A 47401 C /usr/bin/python 0MiB | +| 6 N/A N/A 47402 C /usr/bin/python 0MiB | +| 7 N/A N/A 47403 C /usr/bin/python 0MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms +step:1/1670 train_time:473ms step_avg:472.59ms +step:2/1670 train_time:494ms step_avg:247.17ms +step:3/1670 train_time:566ms step_avg:188.74ms +step:4/1670 train_time:659ms step_avg:164.86ms +step:5/1670 train_time:754ms step_avg:150.77ms +step:6/1670 train_time:848ms step_avg:141.33ms +step:7/1670 train_time:943ms step_avg:134.68ms +step:8/1670 train_time:1038ms step_avg:129.80ms +step:9/1670 train_time:1133ms 
step_avg:125.94ms +step:10/1670 train_time:1228ms step_avg:122.75ms +step:11/1670 train_time:1322ms step_avg:120.19ms +step:12/1670 train_time:1421ms step_avg:118.45ms +step:13/1670 train_time:1523ms step_avg:117.16ms +step:14/1670 train_time:1621ms step_avg:115.75ms +step:15/1670 train_time:1716ms step_avg:114.40ms +step:16/1670 train_time:1811ms step_avg:113.19ms +step:17/1670 train_time:1906ms step_avg:112.12ms +step:18/1670 train_time:2001ms step_avg:111.15ms +step:19/1670 train_time:2096ms step_avg:110.34ms +step:20/1670 train_time:2191ms step_avg:109.57ms +step:21/1670 train_time:2287ms step_avg:108.89ms +step:22/1670 train_time:2383ms step_avg:108.31ms +step:23/1670 train_time:2479ms step_avg:107.79ms +step:24/1670 train_time:2576ms step_avg:107.35ms +step:25/1670 train_time:2673ms step_avg:106.92ms +step:26/1670 train_time:2768ms step_avg:106.47ms +step:27/1670 train_time:2863ms step_avg:106.06ms +step:28/1670 train_time:2959ms step_avg:105.70ms +step:29/1670 train_time:3055ms step_avg:105.34ms +step:30/1670 train_time:3151ms step_avg:105.03ms +step:31/1670 train_time:3247ms step_avg:104.73ms +step:32/1670 train_time:3342ms step_avg:104.43ms +step:33/1670 train_time:3438ms step_avg:104.17ms +step:34/1670 train_time:3534ms step_avg:103.94ms +step:35/1670 train_time:3630ms step_avg:103.72ms +step:36/1670 train_time:3727ms step_avg:103.52ms +step:37/1670 train_time:3823ms step_avg:103.31ms +step:38/1670 train_time:3918ms step_avg:103.11ms +step:39/1670 train_time:4014ms step_avg:102.91ms +step:40/1670 train_time:4109ms step_avg:102.73ms +step:41/1670 train_time:4204ms step_avg:102.55ms +step:42/1670 train_time:4300ms step_avg:102.38ms +step:43/1670 train_time:4396ms step_avg:102.23ms +step:44/1670 train_time:4492ms step_avg:102.09ms +step:45/1670 train_time:4587ms step_avg:101.94ms +step:46/1670 train_time:4684ms step_avg:101.83ms +step:47/1670 train_time:4780ms step_avg:101.71ms +step:48/1670 train_time:4877ms step_avg:101.59ms +step:49/1670 train_time:4973ms step_avg:101.48ms +step:50/1670 train_time:5068ms step_avg:101.36ms +step:51/1670 train_time:5164ms step_avg:101.25ms +step:52/1670 train_time:5259ms step_avg:101.13ms +step:53/1670 train_time:5355ms step_avg:101.03ms +step:54/1670 train_time:5451ms step_avg:100.94ms +step:55/1670 train_time:5547ms step_avg:100.85ms +step:56/1670 train_time:5643ms step_avg:100.77ms +step:57/1670 train_time:5739ms step_avg:100.68ms +step:58/1670 train_time:5834ms step_avg:100.59ms +step:59/1670 train_time:5930ms step_avg:100.51ms +step:60/1670 train_time:6026ms step_avg:100.43ms +step:61/1670 train_time:6122ms step_avg:100.36ms +step:62/1670 train_time:6218ms step_avg:100.29ms +step:63/1670 train_time:6313ms step_avg:100.21ms +step:64/1670 train_time:6408ms step_avg:100.13ms +step:65/1670 train_time:6504ms step_avg:100.05ms +step:66/1670 train_time:6600ms step_avg:100.00ms +step:67/1670 train_time:6695ms step_avg:99.93ms +step:68/1670 train_time:6791ms step_avg:99.87ms +step:69/1670 train_time:6887ms step_avg:99.81ms +step:70/1670 train_time:6983ms step_avg:99.76ms +step:71/1670 train_time:7079ms step_avg:99.70ms +step:72/1670 train_time:7176ms step_avg:99.66ms +step:73/1670 train_time:7272ms step_avg:99.62ms +step:74/1670 train_time:7367ms step_avg:99.56ms +step:75/1670 train_time:7463ms step_avg:99.51ms +step:76/1670 train_time:7559ms step_avg:99.46ms +step:77/1670 train_time:7655ms step_avg:99.42ms +step:78/1670 train_time:7752ms step_avg:99.38ms +step:79/1670 train_time:7847ms step_avg:99.32ms +step:80/1670 train_time:7943ms step_avg:99.28ms 
+step:81/1670 train_time:8038ms step_avg:99.24ms +step:82/1670 train_time:8135ms step_avg:99.21ms +step:83/1670 train_time:8230ms step_avg:99.16ms +step:84/1670 train_time:8325ms step_avg:99.11ms +step:85/1670 train_time:8421ms step_avg:99.07ms +step:86/1670 train_time:8517ms step_avg:99.04ms +step:87/1670 train_time:8613ms step_avg:99.00ms +step:88/1670 train_time:8709ms step_avg:98.97ms +step:89/1670 train_time:8804ms step_avg:98.93ms +step:90/1670 train_time:8900ms step_avg:98.89ms +step:91/1670 train_time:8995ms step_avg:98.85ms +step:92/1670 train_time:9091ms step_avg:98.81ms +step:93/1670 train_time:9186ms step_avg:98.78ms +step:94/1670 train_time:9282ms step_avg:98.74ms +step:95/1670 train_time:9379ms step_avg:98.72ms +step:96/1670 train_time:9476ms step_avg:98.70ms +step:97/1670 train_time:9572ms step_avg:98.68ms +step:98/1670 train_time:9667ms step_avg:98.65ms +step:99/1670 train_time:9763ms step_avg:98.62ms +step:100/1670 train_time:9859ms step_avg:98.59ms +step:101/1670 train_time:9955ms step_avg:98.57ms +step:102/1670 train_time:10051ms step_avg:98.54ms +step:103/1670 train_time:10146ms step_avg:98.51ms +step:104/1670 train_time:10241ms step_avg:98.48ms +step:105/1670 train_time:10337ms step_avg:98.45ms +step:106/1670 train_time:10433ms step_avg:98.43ms +step:107/1670 train_time:10529ms step_avg:98.40ms +step:108/1670 train_time:10624ms step_avg:98.37ms +step:109/1670 train_time:10720ms step_avg:98.35ms +step:110/1670 train_time:10816ms step_avg:98.33ms +step:111/1670 train_time:10912ms step_avg:98.31ms +step:112/1670 train_time:11009ms step_avg:98.29ms +step:113/1670 train_time:11104ms step_avg:98.27ms +step:114/1670 train_time:11199ms step_avg:98.24ms +step:115/1670 train_time:11295ms step_avg:98.21ms +step:116/1670 train_time:11390ms step_avg:98.19ms +step:117/1670 train_time:11485ms step_avg:98.17ms +step:118/1670 train_time:11581ms step_avg:98.14ms +step:119/1670 train_time:11677ms step_avg:98.12ms +step:120/1670 train_time:11773ms step_avg:98.10ms +step:121/1670 train_time:11868ms step_avg:98.08ms +step:122/1670 train_time:11964ms step_avg:98.06ms +step:123/1670 train_time:12060ms step_avg:98.05ms +step:124/1670 train_time:12156ms step_avg:98.03ms +step:125/1670 train_time:12251ms step_avg:98.01ms +step:125/1670 val_loss:4.3007 train_time:12346ms step_avg:98.77ms +step:126/1670 train_time:12367ms step_avg:98.15ms +step:127/1670 train_time:12446ms step_avg:98.00ms +step:128/1670 train_time:12552ms step_avg:98.06ms +step:129/1670 train_time:12649ms step_avg:98.06ms +step:130/1670 train_time:12745ms step_avg:98.04ms +step:131/1670 train_time:12840ms step_avg:98.01ms +step:132/1670 train_time:12934ms step_avg:97.99ms +step:133/1670 train_time:13029ms step_avg:97.96ms +step:134/1670 train_time:13123ms step_avg:97.94ms +step:135/1670 train_time:13219ms step_avg:97.92ms +step:136/1670 train_time:13313ms step_avg:97.89ms +step:137/1670 train_time:13409ms step_avg:97.88ms +step:138/1670 train_time:13507ms step_avg:97.88ms +step:139/1670 train_time:13605ms step_avg:97.88ms +step:140/1670 train_time:13703ms step_avg:97.88ms +step:141/1670 train_time:13797ms step_avg:97.85ms +step:142/1670 train_time:13892ms step_avg:97.83ms +step:143/1670 train_time:13988ms step_avg:97.82ms +step:144/1670 train_time:14082ms step_avg:97.79ms +step:145/1670 train_time:14177ms step_avg:97.77ms +step:146/1670 train_time:14272ms step_avg:97.75ms +step:147/1670 train_time:14367ms step_avg:97.74ms +step:148/1670 train_time:14463ms step_avg:97.73ms +step:149/1670 train_time:14560ms step_avg:97.72ms 
+step:150/1670 train_time:14657ms step_avg:97.71ms
+step:151/1670 train_time:14752ms step_avg:97.70ms
+step:152/1670 train_time:14848ms step_avg:97.68ms
+step:153/1670 train_time:14943ms step_avg:97.67ms
+step:154/1670 train_time:15038ms step_avg:97.65ms
+step:155/1670 train_time:15133ms step_avg:97.63ms
+step:156/1670 train_time:15228ms step_avg:97.61ms
+step:157/1670 train_time:15323ms step_avg:97.60ms
+step:158/1670 train_time:15419ms step_avg:97.59ms
+step:159/1670 train_time:15515ms step_avg:97.58ms
+step:160/1670 train_time:15611ms step_avg:97.57ms
+step:161/1670 train_time:15707ms step_avg:97.56ms
+step:162/1670 train_time:15803ms step_avg:97.55ms
+step:163/1670 train_time:15899ms step_avg:97.54ms
+step:164/1670 train_time:15995ms step_avg:97.53ms
+step:165/1670 train_time:16090ms step_avg:97.51ms
+step:166/1670 train_time:16184ms step_avg:97.50ms
+step:167/1670 train_time:16279ms step_avg:97.48ms
+step:168/1670 train_time:16374ms step_avg:97.47ms
+step:169/1670 train_time:16470ms step_avg:97.46ms
+step:170/1670 train_time:16567ms step_avg:97.45ms
+step:171/1670 train_time:16663ms step_avg:97.45ms
+step:172/1670 train_time:16759ms step_avg:97.44ms
+step:173/1670 train_time:16855ms step_avg:97.43ms
+step:174/1670 train_time:16950ms step_avg:97.41ms
+step:175/1670 train_time:17046ms step_avg:97.40ms
+step:176/1670 train_time:17141ms step_avg:97.39ms
+step:177/1670 train_time:17237ms step_avg:97.38ms
+step:178/1670 train_time:17331ms step_avg:97.37ms
+step:179/1670 train_time:17427ms step_avg:97.36ms
+step:180/1670 train_time:17522ms step_avg:97.35ms
+step:181/1670 train_time:17618ms step_avg:97.34ms
+step:182/1670 train_time:17714ms step_avg:97.33ms
+step:183/1670 train_time:17809ms step_avg:97.32ms
+step:184/1670 train_time:17906ms step_avg:97.31ms
+step:185/1670 train_time:18001ms step_avg:97.30ms
+step:186/1670 train_time:18097ms step_avg:97.29ms
+step:187/1670 train_time:18192ms step_avg:97.28ms
+step:188/1670 train_time:18287ms step_avg:97.27ms
+step:189/1670 train_time:18382ms step_avg:97.26ms
+step:190/1670 train_time:18478ms step_avg:97.25ms
+step:191/1670 train_time:18574ms step_avg:97.25ms
+step:192/1670 train_time:18669ms step_avg:97.23ms
+step:193/1670 train_time:18766ms step_avg:97.23ms
+step:194/1670 train_time:18862ms step_avg:97.23ms
+step:195/1670 train_time:18957ms step_avg:97.22ms
+step:196/1670 train_time:19053ms step_avg:97.21ms
+step:197/1670 train_time:19148ms step_avg:97.20ms
+step:198/1670 train_time:19244ms step_avg:97.19ms
+step:199/1670 train_time:19340ms step_avg:97.18ms
+step:200/1670 train_time:19435ms step_avg:97.17ms
+step:201/1670 train_time:19530ms step_avg:97.17ms
+step:202/1670 train_time:19625ms step_avg:97.15ms
+step:203/1670 train_time:19721ms step_avg:97.15ms
+step:204/1670 train_time:19817ms step_avg:97.14ms
+step:205/1670 train_time:19913ms step_avg:97.14ms
+step:206/1670 train_time:20007ms step_avg:97.12ms
+step:207/1670 train_time:20103ms step_avg:97.12ms
+step:208/1670 train_time:20199ms step_avg:97.11ms
+step:209/1670 train_time:20294ms step_avg:97.10ms
+step:210/1670 train_time:20389ms step_avg:97.09ms
+step:211/1670 train_time:20485ms step_avg:97.09ms
+step:212/1670 train_time:20579ms step_avg:97.07ms
+step:213/1670 train_time:20860ms step_avg:97.93ms
+step:214/1670 train_time:20955ms step_avg:97.92ms
+step:215/1670 train_time:21048ms step_avg:97.90ms
+step:216/1670 train_time:21143ms step_avg:97.88ms
+step:217/1670 train_time:21238ms step_avg:97.87ms
+step:218/1670 train_time:21333ms step_avg:97.86ms
+step:219/1670 train_time:21427ms step_avg:97.84ms
+step:220/1670 train_time:21522ms step_avg:97.83ms
+step:221/1670 train_time:21617ms step_avg:97.81ms
+step:222/1670 train_time:21711ms step_avg:97.80ms
+step:223/1670 train_time:21808ms step_avg:97.79ms
+step:224/1670 train_time:21908ms step_avg:97.80ms
+step:225/1670 train_time:22005ms step_avg:97.80ms
+step:226/1670 train_time:22101ms step_avg:97.79ms
+step:227/1670 train_time:22196ms step_avg:97.78ms
+step:228/1670 train_time:22291ms step_avg:97.77ms
+step:229/1670 train_time:22385ms step_avg:97.75ms
+step:230/1670 train_time:22481ms step_avg:97.74ms
+step:231/1670 train_time:22575ms step_avg:97.73ms
+step:232/1670 train_time:22671ms step_avg:97.72ms
+step:233/1670 train_time:22766ms step_avg:97.71ms
+step:234/1670 train_time:22863ms step_avg:97.71ms
+step:235/1670 train_time:22961ms step_avg:97.71ms
+step:236/1670 train_time:23058ms step_avg:97.70ms
+step:237/1670 train_time:23154ms step_avg:97.70ms
+step:238/1670 train_time:23250ms step_avg:97.69ms
+step:239/1670 train_time:23345ms step_avg:97.68ms
+step:240/1670 train_time:23440ms step_avg:97.67ms
+step:241/1670 train_time:23535ms step_avg:97.65ms
+step:242/1670 train_time:23630ms step_avg:97.64ms
+step:243/1670 train_time:23725ms step_avg:97.63ms
+step:244/1670 train_time:23821ms step_avg:97.63ms
+step:245/1670 train_time:23918ms step_avg:97.62ms
+step:246/1670 train_time:24014ms step_avg:97.62ms
+step:247/1670 train_time:24109ms step_avg:97.61ms
+step:248/1670 train_time:24205ms step_avg:97.60ms
+step:249/1670 train_time:24301ms step_avg:97.59ms
+step:250/1670 train_time:24396ms step_avg:97.58ms
+step:250/1670 val_loss:3.9790 train_time:24490ms step_avg:97.96ms
+step:251/1670 train_time:24515ms step_avg:97.67ms
+step:252/1670 train_time:24593ms step_avg:97.59ms
+step:253/1670 train_time:24693ms step_avg:97.60ms
+step:254/1670 train_time:24789ms step_avg:97.60ms
+step:255/1670 train_time:24884ms step_avg:97.59ms
+step:256/1670 train_time:24979ms step_avg:97.57ms
+step:257/1670 train_time:25074ms step_avg:97.56ms
+step:258/1670 train_time:25168ms step_avg:97.55ms
+step:259/1670 train_time:25263ms step_avg:97.54ms
+step:260/1670 train_time:25358ms step_avg:97.53ms
+step:261/1670 train_time:25454ms step_avg:97.52ms
+step:262/1670 train_time:25552ms step_avg:97.53ms
+step:263/1670 train_time:25650ms step_avg:97.53ms
+step:264/1670 train_time:25747ms step_avg:97.53ms
+step:265/1670 train_time:25842ms step_avg:97.52ms
+step:266/1670 train_time:25938ms step_avg:97.51ms
+step:267/1670 train_time:26032ms step_avg:97.50ms
+step:268/1670 train_time:26127ms step_avg:97.49ms
+step:269/1670 train_time:26222ms step_avg:97.48ms
+step:270/1670 train_time:26316ms step_avg:97.47ms
+step:271/1670 train_time:26412ms step_avg:97.46ms
+step:272/1670 train_time:26510ms step_avg:97.46ms
+step:273/1670 train_time:26605ms step_avg:97.45ms
+step:274/1670 train_time:26701ms step_avg:97.45ms
+step:275/1670 train_time:26798ms step_avg:97.45ms
+step:276/1670 train_time:26894ms step_avg:97.44ms
+step:277/1670 train_time:26990ms step_avg:97.44ms
+step:278/1670 train_time:27084ms step_avg:97.42ms
+step:279/1670 train_time:27180ms step_avg:97.42ms
+step:280/1670 train_time:27275ms step_avg:97.41ms
+step:281/1670 train_time:27370ms step_avg:97.40ms
+step:282/1670 train_time:27465ms step_avg:97.39ms
+step:283/1670 train_time:27561ms step_avg:97.39ms
+step:284/1670 train_time:27658ms step_avg:97.39ms
+step:285/1670 train_time:27755ms step_avg:97.39ms
+step:286/1670 train_time:27853ms step_avg:97.39ms
+step:287/1670 train_time:27949ms step_avg:97.38ms
+step:288/1670 train_time:28044ms step_avg:97.37ms
+step:289/1670 train_time:28139ms step_avg:97.37ms
+step:290/1670 train_time:28234ms step_avg:97.36ms
+step:291/1670 train_time:28329ms step_avg:97.35ms
+step:292/1670 train_time:28425ms step_avg:97.34ms
+step:293/1670 train_time:28520ms step_avg:97.34ms
+step:294/1670 train_time:28616ms step_avg:97.33ms
+step:295/1670 train_time:28713ms step_avg:97.33ms
+step:296/1670 train_time:28809ms step_avg:97.33ms
+step:297/1670 train_time:28904ms step_avg:97.32ms
+step:298/1670 train_time:29000ms step_avg:97.32ms
+step:299/1670 train_time:29097ms step_avg:97.31ms
+step:300/1670 train_time:29192ms step_avg:97.31ms
+step:301/1670 train_time:29287ms step_avg:97.30ms
+step:302/1670 train_time:29383ms step_avg:97.29ms
+step:303/1670 train_time:29478ms step_avg:97.29ms
+step:304/1670 train_time:29574ms step_avg:97.28ms
+step:305/1670 train_time:29671ms step_avg:97.28ms
+step:306/1670 train_time:29767ms step_avg:97.28ms
+step:307/1670 train_time:29863ms step_avg:97.27ms
+step:308/1670 train_time:29959ms step_avg:97.27ms
+step:309/1670 train_time:30056ms step_avg:97.27ms
+step:310/1670 train_time:30152ms step_avg:97.27ms
+step:311/1670 train_time:30248ms step_avg:97.26ms
+step:312/1670 train_time:30343ms step_avg:97.25ms
+step:313/1670 train_time:30437ms step_avg:97.24ms
+step:314/1670 train_time:30532ms step_avg:97.24ms
+step:315/1670 train_time:30628ms step_avg:97.23ms
+step:316/1670 train_time:30724ms step_avg:97.23ms
+step:317/1670 train_time:30819ms step_avg:97.22ms
+step:318/1670 train_time:30916ms step_avg:97.22ms
+step:319/1670 train_time:31013ms step_avg:97.22ms
+step:320/1670 train_time:31109ms step_avg:97.22ms
+step:321/1670 train_time:31204ms step_avg:97.21ms
+step:322/1670 train_time:31299ms step_avg:97.20ms
+step:323/1670 train_time:31395ms step_avg:97.20ms
+step:324/1670 train_time:31491ms step_avg:97.19ms
+step:325/1670 train_time:31586ms step_avg:97.19ms
+step:326/1670 train_time:31681ms step_avg:97.18ms
+step:327/1670 train_time:31777ms step_avg:97.18ms
+step:328/1670 train_time:31872ms step_avg:97.17ms
+step:329/1670 train_time:31968ms step_avg:97.17ms
+step:330/1670 train_time:32063ms step_avg:97.16ms
+step:331/1670 train_time:32160ms step_avg:97.16ms
+step:332/1670 train_time:32255ms step_avg:97.16ms
+step:333/1670 train_time:32351ms step_avg:97.15ms
+step:334/1670 train_time:32447ms step_avg:97.15ms
+step:335/1670 train_time:32542ms step_avg:97.14ms
+step:336/1670 train_time:32638ms step_avg:97.14ms
+step:337/1670 train_time:32733ms step_avg:97.13ms
+step:338/1670 train_time:32829ms step_avg:97.13ms
+step:339/1670 train_time:32925ms step_avg:97.12ms
+step:340/1670 train_time:33020ms step_avg:97.12ms
+step:341/1670 train_time:33116ms step_avg:97.11ms
+step:342/1670 train_time:33212ms step_avg:97.11ms
+step:343/1670 train_time:33308ms step_avg:97.11ms
+step:344/1670 train_time:33402ms step_avg:97.10ms
+step:345/1670 train_time:33497ms step_avg:97.09ms
+step:346/1670 train_time:33593ms step_avg:97.09ms
+step:347/1670 train_time:33689ms step_avg:97.09ms
+step:348/1670 train_time:33784ms step_avg:97.08ms
+step:349/1670 train_time:33880ms step_avg:97.08ms
+step:350/1670 train_time:33976ms step_avg:97.07ms
+step:351/1670 train_time:34072ms step_avg:97.07ms
+step:352/1670 train_time:34167ms step_avg:97.07ms
+step:353/1670 train_time:34263ms step_avg:97.06ms
+step:354/1670 train_time:34358ms step_avg:97.06ms
+step:355/1670 train_time:34454ms step_avg:97.05ms
+step:356/1670 train_time:34550ms step_avg:97.05ms
+step:357/1670 train_time:34645ms step_avg:97.05ms
+step:358/1670 train_time:34741ms step_avg:97.04ms
+step:359/1670 train_time:34836ms step_avg:97.04ms
+step:360/1670 train_time:34932ms step_avg:97.03ms
+step:361/1670 train_time:35029ms step_avg:97.03ms
+step:362/1670 train_time:35124ms step_avg:97.03ms
+step:363/1670 train_time:35220ms step_avg:97.02ms
+step:364/1670 train_time:35316ms step_avg:97.02ms
+step:365/1670 train_time:35412ms step_avg:97.02ms
+step:366/1670 train_time:35507ms step_avg:97.01ms
+step:367/1670 train_time:35602ms step_avg:97.01ms
+step:368/1670 train_time:35697ms step_avg:97.00ms
+step:369/1670 train_time:35794ms step_avg:97.00ms
+step:370/1670 train_time:35889ms step_avg:97.00ms
+step:371/1670 train_time:35984ms step_avg:96.99ms
+step:372/1670 train_time:36080ms step_avg:96.99ms
+step:373/1670 train_time:36177ms step_avg:96.99ms
+step:374/1670 train_time:36273ms step_avg:96.99ms
+step:375/1670 train_time:36369ms step_avg:96.98ms
+step:375/1670 val_loss:3.8252 train_time:36465ms step_avg:97.24ms
+step:376/1670 train_time:36489ms step_avg:97.04ms
+step:377/1670 train_time:36567ms step_avg:96.99ms
+step:378/1670 train_time:36663ms step_avg:96.99ms
+step:379/1670 train_time:36759ms step_avg:96.99ms
+step:380/1670 train_time:36854ms step_avg:96.98ms
+step:381/1670 train_time:36949ms step_avg:96.98ms
+step:382/1670 train_time:37043ms step_avg:96.97ms
+step:383/1670 train_time:37138ms step_avg:96.97ms
+step:384/1670 train_time:37233ms step_avg:96.96ms
+step:385/1670 train_time:37328ms step_avg:96.95ms
+step:386/1670 train_time:37425ms step_avg:96.95ms
+step:387/1670 train_time:37522ms step_avg:96.96ms
+step:388/1670 train_time:37620ms step_avg:96.96ms
+step:389/1670 train_time:37717ms step_avg:96.96ms
+step:390/1670 train_time:37812ms step_avg:96.95ms
+step:391/1670 train_time:37907ms step_avg:96.95ms
+step:392/1670 train_time:38002ms step_avg:96.94ms
+step:393/1670 train_time:38098ms step_avg:96.94ms
+step:394/1670 train_time:38193ms step_avg:96.94ms
+step:395/1670 train_time:38287ms step_avg:96.93ms
+step:396/1670 train_time:38382ms step_avg:96.92ms
+step:397/1670 train_time:38479ms step_avg:96.92ms
+step:398/1670 train_time:38576ms step_avg:96.93ms
+step:399/1670 train_time:38673ms step_avg:96.92ms
+step:400/1670 train_time:38768ms step_avg:96.92ms
+step:401/1670 train_time:38864ms step_avg:96.92ms
+step:402/1670 train_time:38959ms step_avg:96.91ms
+step:403/1670 train_time:39054ms step_avg:96.91ms
+step:404/1670 train_time:39150ms step_avg:96.90ms
+step:405/1670 train_time:39245ms step_avg:96.90ms
+step:406/1670 train_time:39340ms step_avg:96.90ms
+step:407/1670 train_time:39436ms step_avg:96.89ms
+step:408/1670 train_time:39533ms step_avg:96.89ms
+step:409/1670 train_time:39629ms step_avg:96.89ms
+step:410/1670 train_time:39725ms step_avg:96.89ms
+step:411/1670 train_time:39820ms step_avg:96.89ms
+step:412/1670 train_time:39917ms step_avg:96.89ms
+step:413/1670 train_time:40013ms step_avg:96.88ms
+step:414/1670 train_time:40108ms step_avg:96.88ms
+step:415/1670 train_time:40203ms step_avg:96.87ms
+step:416/1670 train_time:40299ms step_avg:96.87ms
+step:417/1670 train_time:40395ms step_avg:96.87ms
+step:418/1670 train_time:40491ms step_avg:96.87ms
+step:419/1670 train_time:40587ms step_avg:96.87ms
+step:420/1670 train_time:40682ms step_avg:96.86ms
+step:421/1670 train_time:40777ms step_avg:96.86ms
+step:422/1670 train_time:40873ms step_avg:96.86ms
+step:423/1670 train_time:40969ms step_avg:96.85ms
+step:424/1670 train_time:41065ms step_avg:96.85ms
+step:425/1670 train_time:41353ms step_avg:97.30ms
+step:426/1670 train_time:41446ms step_avg:97.29ms
+step:427/1670 train_time:41540ms step_avg:97.28ms
+step:428/1670 train_time:41634ms step_avg:97.28ms
+step:429/1670 train_time:41729ms step_avg:97.27ms
+step:430/1670 train_time:41824ms step_avg:97.26ms
+step:431/1670 train_time:41918ms step_avg:97.26ms
+step:432/1670 train_time:42013ms step_avg:97.25ms
+step:433/1670 train_time:42108ms step_avg:97.25ms
+step:434/1670 train_time:42203ms step_avg:97.24ms
+step:435/1670 train_time:42304ms step_avg:97.25ms
+step:436/1670 train_time:42401ms step_avg:97.25ms
+step:437/1670 train_time:42498ms step_avg:97.25ms
+step:438/1670 train_time:42594ms step_avg:97.25ms
+step:439/1670 train_time:42689ms step_avg:97.24ms
+step:440/1670 train_time:42784ms step_avg:97.24ms
+step:441/1670 train_time:42879ms step_avg:97.23ms
+step:442/1670 train_time:42974ms step_avg:97.23ms
+step:443/1670 train_time:43068ms step_avg:97.22ms
+step:444/1670 train_time:43163ms step_avg:97.21ms
+step:445/1670 train_time:43260ms step_avg:97.21ms
+step:446/1670 train_time:43357ms step_avg:97.21ms
+step:447/1670 train_time:43453ms step_avg:97.21ms
+step:448/1670 train_time:43549ms step_avg:97.21ms
+step:449/1670 train_time:43644ms step_avg:97.20ms
+step:450/1670 train_time:43740ms step_avg:97.20ms
+step:451/1670 train_time:43835ms step_avg:97.19ms
+step:452/1670 train_time:43930ms step_avg:97.19ms
+step:453/1670 train_time:44024ms step_avg:97.18ms
+step:454/1670 train_time:44120ms step_avg:97.18ms
+step:455/1670 train_time:44216ms step_avg:97.18ms
+step:456/1670 train_time:44313ms step_avg:97.18ms
+step:457/1670 train_time:44409ms step_avg:97.18ms
+step:458/1670 train_time:44506ms step_avg:97.17ms
+step:459/1670 train_time:44601ms step_avg:97.17ms
+step:460/1670 train_time:44697ms step_avg:97.17ms
+step:461/1670 train_time:44793ms step_avg:97.16ms
+step:462/1670 train_time:44887ms step_avg:97.16ms
+step:463/1670 train_time:44982ms step_avg:97.15ms
+step:464/1670 train_time:45078ms step_avg:97.15ms
+step:465/1670 train_time:45174ms step_avg:97.15ms
+step:466/1670 train_time:45270ms step_avg:97.15ms
+step:467/1670 train_time:45366ms step_avg:97.14ms
+step:468/1670 train_time:45461ms step_avg:97.14ms
+step:469/1670 train_time:45557ms step_avg:97.14ms
+step:470/1670 train_time:45653ms step_avg:97.13ms
+step:471/1670 train_time:45749ms step_avg:97.13ms
+step:472/1670 train_time:45845ms step_avg:97.13ms
+step:473/1670 train_time:45940ms step_avg:97.12ms
+step:474/1670 train_time:46035ms step_avg:97.12ms
+step:475/1670 train_time:46130ms step_avg:97.12ms
+step:476/1670 train_time:46226ms step_avg:97.11ms
+step:477/1670 train_time:46322ms step_avg:97.11ms
+step:478/1670 train_time:46418ms step_avg:97.11ms
+step:479/1670 train_time:46515ms step_avg:97.11ms
+step:480/1670 train_time:46611ms step_avg:97.11ms
+step:481/1670 train_time:46706ms step_avg:97.10ms
+step:482/1670 train_time:46801ms step_avg:97.10ms
+step:483/1670 train_time:46897ms step_avg:97.10ms
+step:484/1670 train_time:46993ms step_avg:97.09ms
+step:485/1670 train_time:47087ms step_avg:97.09ms
+step:486/1670 train_time:47183ms step_avg:97.08ms
+step:487/1670 train_time:47278ms step_avg:97.08ms
+step:488/1670 train_time:47375ms step_avg:97.08ms
+step:489/1670 train_time:47471ms step_avg:97.08ms
+step:490/1670 train_time:47567ms step_avg:97.08ms
+step:491/1670 train_time:47662ms step_avg:97.07ms
+step:492/1670 train_time:47758ms step_avg:97.07ms
+step:493/1670 train_time:47854ms step_avg:97.07ms
+step:494/1670 train_time:47950ms step_avg:97.06ms
+step:495/1670 train_time:48045ms step_avg:97.06ms
+step:496/1670 train_time:48140ms step_avg:97.06ms
+step:497/1670 train_time:48236ms step_avg:97.05ms
+step:498/1670 train_time:48332ms step_avg:97.05ms
+step:499/1670 train_time:48428ms step_avg:97.05ms
+step:500/1670 train_time:48523ms step_avg:97.05ms
+step:500/1670 val_loss:3.7242 train_time:48619ms step_avg:97.24ms
+step:501/1670 train_time:48643ms step_avg:97.09ms
+step:502/1670 train_time:48722ms step_avg:97.06ms
+step:503/1670 train_time:48821ms step_avg:97.06ms
+step:504/1670 train_time:48916ms step_avg:97.06ms
+step:505/1670 train_time:49011ms step_avg:97.05ms
+step:506/1670 train_time:49106ms step_avg:97.05ms
+step:507/1670 train_time:49201ms step_avg:97.04ms
+step:508/1670 train_time:49295ms step_avg:97.04ms
+step:509/1670 train_time:49390ms step_avg:97.03ms
+step:510/1670 train_time:49485ms step_avg:97.03ms
+step:511/1670 train_time:49581ms step_avg:97.03ms
+step:512/1670 train_time:49678ms step_avg:97.03ms
+step:513/1670 train_time:49775ms step_avg:97.03ms
+step:514/1670 train_time:49873ms step_avg:97.03ms
+step:515/1670 train_time:49968ms step_avg:97.03ms
+step:516/1670 train_time:50064ms step_avg:97.02ms
+step:517/1670 train_time:50159ms step_avg:97.02ms
+step:518/1670 train_time:50254ms step_avg:97.01ms
+step:519/1670 train_time:50348ms step_avg:97.01ms
+step:520/1670 train_time:50443ms step_avg:97.01ms
+step:521/1670 train_time:50537ms step_avg:97.00ms
+step:522/1670 train_time:50634ms step_avg:97.00ms
+step:523/1670 train_time:50731ms step_avg:97.00ms
+step:524/1670 train_time:50828ms step_avg:97.00ms
+step:525/1670 train_time:50925ms step_avg:97.00ms
+step:526/1670 train_time:51021ms step_avg:97.00ms
+step:527/1670 train_time:51116ms step_avg:97.00ms
+step:528/1670 train_time:51211ms step_avg:96.99ms
+step:529/1670 train_time:51306ms step_avg:96.99ms
+step:530/1670 train_time:51401ms step_avg:96.98ms
+step:531/1670 train_time:51496ms step_avg:96.98ms
+step:532/1670 train_time:51591ms step_avg:96.98ms
+step:533/1670 train_time:51688ms step_avg:96.98ms
+step:534/1670 train_time:51785ms step_avg:96.98ms
+step:535/1670 train_time:51880ms step_avg:96.97ms
+step:536/1670 train_time:51976ms step_avg:96.97ms
+step:537/1670 train_time:52072ms step_avg:96.97ms
+step:538/1670 train_time:52168ms step_avg:96.97ms
+step:539/1670 train_time:52264ms step_avg:96.96ms
+step:540/1670 train_time:52359ms step_avg:96.96ms
+step:541/1670 train_time:52453ms step_avg:96.96ms
+step:542/1670 train_time:52549ms step_avg:96.95ms
+step:543/1670 train_time:52645ms step_avg:96.95ms
+step:544/1670 train_time:52740ms step_avg:96.95ms
+step:545/1670 train_time:52836ms step_avg:96.95ms
+step:546/1670 train_time:52932ms step_avg:96.95ms
+step:547/1670 train_time:53029ms step_avg:96.95ms
+step:548/1670 train_time:53125ms step_avg:96.94ms
+step:549/1670 train_time:53221ms step_avg:96.94ms
+step:550/1670 train_time:53316ms step_avg:96.94ms
+step:551/1670 train_time:53411ms step_avg:96.93ms
+step:552/1670 train_time:53506ms step_avg:96.93ms
+step:553/1670 train_time:53602ms step_avg:96.93ms
+step:554/1670 train_time:53698ms step_avg:96.93ms
+step:555/1670 train_time:53794ms step_avg:96.93ms
+step:556/1670 train_time:53890ms step_avg:96.92ms
+step:557/1670 train_time:53986ms step_avg:96.92ms
+step:558/1670 train_time:54082ms step_avg:96.92ms
+step:559/1670 train_time:54179ms step_avg:96.92ms
+step:560/1670 train_time:54276ms step_avg:96.92ms
+step:561/1670 train_time:54372ms step_avg:96.92ms
+step:562/1670 train_time:54469ms step_avg:96.92ms
+step:563/1670 train_time:54565ms step_avg:96.92ms
+step:564/1670 train_time:54663ms step_avg:96.92ms
+step:565/1670 train_time:54759ms step_avg:96.92ms
+step:566/1670 train_time:54856ms step_avg:96.92ms
+step:567/1670 train_time:54954ms step_avg:96.92ms
+step:568/1670 train_time:55052ms step_avg:96.92ms
+step:569/1670 train_time:55150ms step_avg:96.93ms
+step:570/1670 train_time:55249ms step_avg:96.93ms
+step:571/1670 train_time:55346ms step_avg:96.93ms
+step:572/1670 train_time:55443ms step_avg:96.93ms
+step:573/1670 train_time:55540ms step_avg:96.93ms
+step:574/1670 train_time:55636ms step_avg:96.93ms
+step:575/1670 train_time:55732ms step_avg:96.93ms
+step:576/1670 train_time:55830ms step_avg:96.93ms
+step:577/1670 train_time:55927ms step_avg:96.93ms
+step:578/1670 train_time:56024ms step_avg:96.93ms
+step:579/1670 train_time:56122ms step_avg:96.93ms
+step:580/1670 train_time:56218ms step_avg:96.93ms
+step:581/1670 train_time:56315ms step_avg:96.93ms
+step:582/1670 train_time:56413ms step_avg:96.93ms
+step:583/1670 train_time:56510ms step_avg:96.93ms
+step:584/1670 train_time:56608ms step_avg:96.93ms
+step:585/1670 train_time:56705ms step_avg:96.93ms
+step:586/1670 train_time:56802ms step_avg:96.93ms
+step:587/1670 train_time:56900ms step_avg:96.93ms
+step:588/1670 train_time:56996ms step_avg:96.93ms
+step:589/1670 train_time:57093ms step_avg:96.93ms
+step:590/1670 train_time:57191ms step_avg:96.93ms
+step:591/1670 train_time:57289ms step_avg:96.94ms
+step:592/1670 train_time:57386ms step_avg:96.93ms
+step:593/1670 train_time:57483ms step_avg:96.94ms
+step:594/1670 train_time:57580ms step_avg:96.94ms
+step:595/1670 train_time:57676ms step_avg:96.93ms
+step:596/1670 train_time:57773ms step_avg:96.93ms
+step:597/1670 train_time:57871ms step_avg:96.94ms
+step:598/1670 train_time:57969ms step_avg:96.94ms
+step:599/1670 train_time:58066ms step_avg:96.94ms
+step:600/1670 train_time:58163ms step_avg:96.94ms
+step:601/1670 train_time:58260ms step_avg:96.94ms
+step:602/1670 train_time:58356ms step_avg:96.94ms
+step:603/1670 train_time:58454ms step_avg:96.94ms
+step:604/1670 train_time:58552ms step_avg:96.94ms
+step:605/1670 train_time:58650ms step_avg:96.94ms
+step:606/1670 train_time:58748ms step_avg:96.94ms
+step:607/1670 train_time:58845ms step_avg:96.94ms
+step:608/1670 train_time:58942ms step_avg:96.94ms
+step:609/1670 train_time:59038ms step_avg:96.94ms
+step:610/1670 train_time:59135ms step_avg:96.94ms
+step:611/1670 train_time:59232ms step_avg:96.94ms
+step:612/1670 train_time:59330ms step_avg:96.94ms
+step:613/1670 train_time:59428ms step_avg:96.95ms
+step:614/1670 train_time:59525ms step_avg:96.95ms
+step:615/1670 train_time:59622ms step_avg:96.95ms
+step:616/1670 train_time:59719ms step_avg:96.95ms
+step:617/1670 train_time:59816ms step_avg:96.95ms
+step:618/1670 train_time:59913ms step_avg:96.95ms
+step:619/1670 train_time:60010ms step_avg:96.95ms
+step:620/1670 train_time:60107ms step_avg:96.95ms
+step:621/1670 train_time:60205ms step_avg:96.95ms
+step:622/1670 train_time:60302ms step_avg:96.95ms
+step:623/1670 train_time:60398ms step_avg:96.95ms
+step:624/1670 train_time:60496ms step_avg:96.95ms
+step:625/1670 train_time:60593ms step_avg:96.95ms
+step:625/1670 val_loss:3.6224 train_time:60691ms step_avg:97.11ms
+step:626/1670 train_time:60717ms step_avg:96.99ms
+step:627/1670 train_time:60794ms step_avg:96.96ms
+step:628/1670 train_time:60891ms step_avg:96.96ms
+step:629/1670 train_time:60987ms step_avg:96.96ms
+step:630/1670 train_time:61082ms step_avg:96.96ms
+step:631/1670 train_time:61177ms step_avg:96.95ms
+step:632/1670 train_time:61273ms step_avg:96.95ms
+step:633/1670 train_time:61369ms step_avg:96.95ms
+step:634/1670 train_time:61465ms step_avg:96.95ms
+step:635/1670 train_time:61561ms step_avg:96.95ms
+step:636/1670 train_time:61661ms step_avg:96.95ms
+step:637/1670 train_time:61761ms step_avg:96.96ms
+step:638/1670 train_time:61859ms step_avg:96.96ms
+step:639/1670 train_time:62129ms step_avg:97.23ms
+step:640/1670 train_time:62324ms step_avg:97.38ms
+step:641/1670 train_time:62419ms step_avg:97.38ms
+step:642/1670 train_time:62515ms step_avg:97.37ms
+step:643/1670 train_time:62610ms step_avg:97.37ms
+step:644/1670 train_time:62706ms step_avg:97.37ms
+step:645/1670 train_time:62802ms step_avg:97.37ms
+step:646/1670 train_time:62897ms step_avg:97.36ms
+step:647/1670 train_time:62993ms step_avg:97.36ms
+step:648/1670 train_time:63089ms step_avg:97.36ms
+step:649/1670 train_time:63187ms step_avg:97.36ms
+step:650/1670 train_time:63291ms step_avg:97.37ms
+step:651/1670 train_time:63392ms step_avg:97.38ms
+step:652/1670 train_time:63490ms step_avg:97.38ms
+step:653/1670 train_time:63587ms step_avg:97.38ms
+step:654/1670 train_time:63683ms step_avg:97.38ms
+step:655/1670 train_time:63779ms step_avg:97.37ms
+step:656/1670 train_time:63875ms step_avg:97.37ms
+step:657/1670 train_time:63971ms step_avg:97.37ms
+step:658/1670 train_time:64068ms step_avg:97.37ms
+step:659/1670 train_time:64165ms step_avg:97.37ms
+step:660/1670 train_time:64264ms step_avg:97.37ms
+step:661/1670 train_time:64362ms step_avg:97.37ms
+step:662/1670 train_time:64461ms step_avg:97.37ms
+step:663/1670 train_time:64558ms step_avg:97.37ms
+step:664/1670 train_time:64654ms step_avg:97.37ms
+step:665/1670 train_time:64751ms step_avg:97.37ms
+step:666/1670 train_time:64847ms step_avg:97.37ms
+step:667/1670 train_time:64944ms step_avg:97.37ms
+step:668/1670 train_time:65041ms step_avg:97.37ms
+step:669/1670 train_time:65137ms step_avg:97.36ms
+step:670/1670 train_time:65235ms step_avg:97.36ms
+step:671/1670 train_time:65332ms step_avg:97.37ms
+step:672/1670 train_time:65431ms step_avg:97.37ms
+step:673/1670 train_time:65528ms step_avg:97.37ms
+step:674/1670 train_time:65627ms step_avg:97.37ms
+step:675/1670 train_time:65724ms step_avg:97.37ms
+step:676/1670 train_time:65820ms step_avg:97.37ms
+step:677/1670 train_time:65916ms step_avg:97.37ms
+step:678/1670 train_time:66012ms step_avg:97.36ms
+step:679/1670 train_time:66110ms step_avg:97.36ms
+step:680/1670 train_time:66207ms step_avg:97.36ms
+step:681/1670 train_time:66304ms step_avg:97.36ms
+step:682/1670 train_time:66402ms step_avg:97.36ms
+step:683/1670 train_time:66499ms step_avg:97.36ms
+step:684/1670 train_time:66596ms step_avg:97.36ms
+step:685/1670 train_time:66692ms step_avg:97.36ms
+step:686/1670 train_time:66790ms step_avg:97.36ms
+step:687/1670 train_time:66888ms step_avg:97.36ms
+step:688/1670 train_time:66984ms step_avg:97.36ms
+step:689/1670 train_time:67081ms step_avg:97.36ms
+step:690/1670 train_time:67178ms step_avg:97.36ms
+step:691/1670 train_time:67275ms step_avg:97.36ms
+step:692/1670 train_time:67371ms step_avg:97.36ms
+step:693/1670 train_time:67470ms step_avg:97.36ms
+step:694/1670 train_time:67568ms step_avg:97.36ms
+step:695/1670 train_time:67665ms step_avg:97.36ms
+step:696/1670 train_time:67763ms step_avg:97.36ms
+step:697/1670 train_time:67859ms step_avg:97.36ms
+step:698/1670 train_time:67956ms step_avg:97.36ms
+step:699/1670 train_time:68052ms step_avg:97.36ms
+step:700/1670 train_time:68150ms step_avg:97.36ms
+step:701/1670 train_time:68248ms step_avg:97.36ms
+step:702/1670 train_time:68346ms step_avg:97.36ms
+step:703/1670 train_time:68443ms step_avg:97.36ms
+step:704/1670 train_time:68541ms step_avg:97.36ms
+step:705/1670 train_time:68638ms step_avg:97.36ms
+step:706/1670 train_time:68735ms step_avg:97.36ms
+step:707/1670 train_time:68832ms step_avg:97.36ms
+step:708/1670 train_time:68929ms step_avg:97.36ms
+step:709/1670 train_time:69026ms step_avg:97.36ms
+step:710/1670 train_time:69123ms step_avg:97.36ms
+step:711/1670 train_time:69220ms step_avg:97.36ms
+step:712/1670 train_time:69317ms step_avg:97.36ms
+step:713/1670 train_time:69414ms step_avg:97.35ms
+step:714/1670 train_time:69511ms step_avg:97.35ms
+step:715/1670 train_time:69609ms step_avg:97.35ms
+step:716/1670 train_time:69707ms step_avg:97.36ms
+step:717/1670 train_time:69803ms step_avg:97.35ms
+step:718/1670 train_time:69900ms step_avg:97.35ms
+step:719/1670 train_time:69997ms step_avg:97.35ms
+step:720/1670 train_time:70092ms step_avg:97.35ms
+step:721/1670 train_time:70190ms step_avg:97.35ms
+step:722/1670 train_time:70288ms step_avg:97.35ms
+step:723/1670 train_time:70385ms step_avg:97.35ms
+step:724/1670 train_time:70482ms step_avg:97.35ms
+step:725/1670 train_time:70580ms step_avg:97.35ms
+step:726/1670 train_time:70677ms step_avg:97.35ms
+step:727/1670 train_time:70774ms step_avg:97.35ms
+step:728/1670 train_time:70872ms step_avg:97.35ms
+step:729/1670 train_time:70969ms step_avg:97.35ms
+step:730/1670 train_time:71066ms step_avg:97.35ms
+step:731/1670 train_time:71163ms step_avg:97.35ms
+step:732/1670 train_time:71261ms step_avg:97.35ms
+step:733/1670 train_time:71357ms step_avg:97.35ms
+step:734/1670 train_time:71453ms step_avg:97.35ms
+step:735/1670 train_time:71551ms step_avg:97.35ms
+step:736/1670 train_time:71648ms step_avg:97.35ms
+step:737/1670 train_time:71745ms step_avg:97.35ms
+step:738/1670 train_time:71842ms step_avg:97.35ms
+step:739/1670 train_time:71940ms step_avg:97.35ms
+step:740/1670 train_time:72037ms step_avg:97.35ms
+step:741/1670 train_time:72133ms step_avg:97.35ms
+step:742/1670 train_time:72230ms step_avg:97.35ms
+step:743/1670 train_time:72328ms step_avg:97.35ms
+step:744/1670 train_time:72425ms step_avg:97.35ms
+step:745/1670 train_time:72523ms step_avg:97.35ms
+step:746/1670 train_time:72619ms step_avg:97.35ms
+step:747/1670 train_time:72716ms step_avg:97.34ms
+step:748/1670 train_time:72814ms step_avg:97.34ms
+step:749/1670 train_time:72910ms step_avg:97.34ms
+step:750/1670 train_time:73008ms step_avg:97.34ms
+step:750/1670 val_loss:3.5662 train_time:73105ms step_avg:97.47ms
+step:751/1670 train_time:73129ms step_avg:97.38ms
+step:752/1670 train_time:73210ms step_avg:97.35ms
+step:753/1670 train_time:73310ms step_avg:97.36ms
+step:754/1670 train_time:73407ms step_avg:97.36ms
+step:755/1670 train_time:73503ms step_avg:97.35ms
+step:756/1670 train_time:73599ms step_avg:97.35ms
+step:757/1670 train_time:73695ms step_avg:97.35ms
+step:758/1670 train_time:73791ms step_avg:97.35ms
+step:759/1670 train_time:73888ms step_avg:97.35ms
+step:760/1670 train_time:73983ms step_avg:97.35ms
+step:761/1670 train_time:74081ms step_avg:97.35ms
+step:762/1670 train_time:74181ms step_avg:97.35ms
+step:763/1670 train_time:74279ms step_avg:97.35ms
+step:764/1670 train_time:74377ms step_avg:97.35ms
+step:765/1670 train_time:74473ms step_avg:97.35ms
+step:766/1670 train_time:74569ms step_avg:97.35ms
+step:767/1670 train_time:74666ms step_avg:97.35ms
+step:768/1670 train_time:74763ms step_avg:97.35ms
+step:769/1670 train_time:74860ms step_avg:97.35ms
+step:770/1670 train_time:74956ms step_avg:97.35ms
+step:771/1670 train_time:75053ms step_avg:97.34ms
+step:772/1670 train_time:75150ms step_avg:97.34ms
+step:773/1670 train_time:75249ms step_avg:97.35ms
+step:774/1670 train_time:75347ms step_avg:97.35ms
+step:775/1670 train_time:75444ms step_avg:97.35ms
+step:776/1670 train_time:75541ms step_avg:97.35ms
+step:777/1670 train_time:75637ms step_avg:97.35ms
+step:778/1670 train_time:75733ms step_avg:97.34ms
+step:779/1670 train_time:75830ms step_avg:97.34ms
+step:780/1670 train_time:75927ms step_avg:97.34ms
+step:781/1670 train_time:76025ms step_avg:97.34ms
+step:782/1670 train_time:76123ms step_avg:97.34ms
+step:783/1670 train_time:76221ms step_avg:97.34ms
+step:784/1670 train_time:76318ms step_avg:97.34ms
+step:785/1670 train_time:76414ms step_avg:97.34ms
+step:786/1670 train_time:76511ms step_avg:97.34ms
+step:787/1670 train_time:76610ms step_avg:97.34ms
+step:788/1670 train_time:76707ms step_avg:97.34ms
+step:789/1670 train_time:76804ms step_avg:97.34ms
+step:790/1670 train_time:76901ms step_avg:97.34ms
+step:791/1670 train_time:76998ms step_avg:97.34ms
+step:792/1670 train_time:77094ms step_avg:97.34ms
+step:793/1670 train_time:77191ms step_avg:97.34ms
+step:794/1670 train_time:77289ms step_avg:97.34ms
+step:795/1670 train_time:77386ms step_avg:97.34ms
+step:796/1670 train_time:77484ms step_avg:97.34ms
+step:797/1670 train_time:77581ms step_avg:97.34ms
+step:798/1670 train_time:77678ms step_avg:97.34ms
+step:799/1670 train_time:77774ms step_avg:97.34ms
+step:800/1670 train_time:77870ms step_avg:97.34ms
+step:801/1670 train_time:77969ms step_avg:97.34ms
+step:802/1670 train_time:78068ms step_avg:97.34ms
+step:803/1670 train_time:78165ms step_avg:97.34ms
+step:804/1670 train_time:78263ms step_avg:97.34ms
+step:805/1670 train_time:78360ms step_avg:97.34ms
+step:806/1670 train_time:78457ms step_avg:97.34ms
+step:807/1670 train_time:78555ms step_avg:97.34ms
+step:808/1670 train_time:78651ms step_avg:97.34ms
+step:809/1670 train_time:78748ms step_avg:97.34ms
+step:810/1670 train_time:78846ms step_avg:97.34ms
+step:811/1670 train_time:78942ms step_avg:97.34ms
+step:812/1670 train_time:79040ms step_avg:97.34ms
+step:813/1670 train_time:79136ms step_avg:97.34ms
+step:814/1670 train_time:79234ms step_avg:97.34ms
+step:815/1670 train_time:79332ms step_avg:97.34ms
+step:816/1670 train_time:79429ms step_avg:97.34ms
+step:817/1670 train_time:79526ms step_avg:97.34ms
+step:818/1670 train_time:79624ms step_avg:97.34ms
+step:819/1670 train_time:79720ms step_avg:97.34ms
+step:820/1670 train_time:79817ms step_avg:97.34ms
+step:821/1670 train_time:79914ms step_avg:97.34ms
+step:822/1670 train_time:80011ms step_avg:97.34ms
+step:823/1670 train_time:80109ms step_avg:97.34ms
+step:824/1670 train_time:80206ms step_avg:97.34ms
+step:825/1670 train_time:80303ms step_avg:97.34ms
+step:826/1670 train_time:80400ms step_avg:97.34ms
+step:827/1670 train_time:80497ms step_avg:97.34ms
+step:828/1670 train_time:80594ms step_avg:97.34ms
+step:829/1670 train_time:80690ms step_avg:97.33ms
+step:830/1670 train_time:80787ms step_avg:97.33ms
+step:831/1670 train_time:80884ms step_avg:97.33ms
+step:832/1670 train_time:80981ms step_avg:97.33ms
+step:833/1670 train_time:81079ms step_avg:97.33ms
+step:834/1670 train_time:81176ms step_avg:97.33ms
+step:835/1670 train_time:81273ms step_avg:97.33ms
+step:836/1670 train_time:81370ms step_avg:97.33ms
+step:837/1670 train_time:81467ms step_avg:97.33ms
+step:838/1670 train_time:81565ms step_avg:97.33ms
+step:839/1670 train_time:81662ms step_avg:97.33ms
+step:840/1670 train_time:81759ms step_avg:97.33ms
+step:841/1670 train_time:81855ms step_avg:97.33ms
+step:842/1670 train_time:81952ms step_avg:97.33ms
+step:843/1670 train_time:82050ms step_avg:97.33ms
+step:844/1670 train_time:82147ms step_avg:97.33ms
+step:845/1670 train_time:82245ms step_avg:97.33ms
+step:846/1670 train_time:82343ms step_avg:97.33ms
+step:847/1670 train_time:82441ms step_avg:97.33ms
+step:848/1670 train_time:82537ms step_avg:97.33ms
+step:849/1670 train_time:82634ms step_avg:97.33ms
+step:850/1670 train_time:82731ms step_avg:97.33ms
+step:851/1670 train_time:83001ms step_avg:97.53ms
+step:852/1670 train_time:83142ms step_avg:97.58ms
+step:853/1670 train_time:83238ms step_avg:97.58ms
+step:854/1670 train_time:83333ms step_avg:97.58ms
+step:855/1670 train_time:83429ms step_avg:97.58ms
+step:856/1670 train_time:83524ms step_avg:97.57ms
+step:857/1670 train_time:83620ms step_avg:97.57ms
+step:858/1670 train_time:83716ms step_avg:97.57ms
+step:859/1670 train_time:83812ms step_avg:97.57ms
+step:860/1670 train_time:83908ms step_avg:97.57ms
+step:861/1670 train_time:84007ms step_avg:97.57ms
+step:862/1670 train_time:84112ms step_avg:97.58ms
+step:863/1670 train_time:84210ms step_avg:97.58ms
+step:864/1670 train_time:84308ms step_avg:97.58ms
+step:865/1670 train_time:84405ms step_avg:97.58ms
+step:866/1670 train_time:84501ms step_avg:97.58ms
+step:867/1670 train_time:84597ms step_avg:97.57ms
+step:868/1670 train_time:84692ms step_avg:97.57ms
+step:869/1670 train_time:84788ms step_avg:97.57ms
+step:870/1670 train_time:84884ms step_avg:97.57ms
+step:871/1670 train_time:84982ms step_avg:97.57ms
+step:872/1670 train_time:85081ms step_avg:97.57ms
+step:873/1670 train_time:85179ms step_avg:97.57ms
+step:874/1670 train_time:85277ms step_avg:97.57ms
+step:875/1670 train_time:85373ms step_avg:97.57ms
+step:875/1670 val_loss:3.5252 train_time:85468ms step_avg:97.68ms
+step:876/1670 train_time:85492ms step_avg:97.59ms
+step:877/1670 train_time:85572ms step_avg:97.57ms
+step:878/1670 train_time:85671ms step_avg:97.58ms
+step:879/1670 train_time:85768ms step_avg:97.57ms
+step:880/1670 train_time:85863ms step_avg:97.57ms
+step:881/1670 train_time:85959ms step_avg:97.57ms
+step:882/1670 train_time:86055ms step_avg:97.57ms
+step:883/1670 train_time:86150ms step_avg:97.57ms
+step:884/1670 train_time:86247ms step_avg:97.56ms
+step:885/1670 train_time:86343ms step_avg:97.56ms
+step:886/1670 train_time:86441ms step_avg:97.56ms
+step:887/1670 train_time:86540ms step_avg:97.57ms
+step:888/1670 train_time:86639ms step_avg:97.57ms
+step:889/1670 train_time:86736ms step_avg:97.57ms
+step:890/1670 train_time:86834ms step_avg:97.57ms
+step:891/1670 train_time:86930ms step_avg:97.56ms
+step:892/1670 train_time:87026ms step_avg:97.56ms
+step:893/1670 train_time:87121ms step_avg:97.56ms
+step:894/1670 train_time:87217ms step_avg:97.56ms
+step:895/1670 train_time:87315ms step_avg:97.56ms
+step:896/1670 train_time:87413ms step_avg:97.56ms
+step:897/1670 train_time:87512ms step_avg:97.56ms
+step:898/1670 train_time:87610ms step_avg:97.56ms
+step:899/1670 train_time:87708ms step_avg:97.56ms
+step:900/1670 train_time:87805ms step_avg:97.56ms
+step:901/1670 train_time:87902ms step_avg:97.56ms
+step:902/1670 train_time:87998ms step_avg:97.56ms
+step:903/1670 train_time:88095ms step_avg:97.56ms
+step:904/1670 train_time:88192ms step_avg:97.56ms
+step:905/1670 train_time:88289ms step_avg:97.56ms
+step:906/1670 train_time:88386ms step_avg:97.56ms
+step:907/1670 train_time:88483ms step_avg:97.56ms
+step:908/1670 train_time:88580ms step_avg:97.56ms
+step:909/1670 train_time:88678ms step_avg:97.56ms
+step:910/1670 train_time:88778ms step_avg:97.56ms
+step:911/1670 train_time:88876ms step_avg:97.56ms
+step:912/1670 train_time:88972ms step_avg:97.56ms
+step:913/1670 train_time:89069ms step_avg:97.56ms
+step:914/1670 train_time:89166ms step_avg:97.56ms
+step:915/1670 train_time:89262ms step_avg:97.55ms
+step:916/1670 train_time:89358ms step_avg:97.55ms
+step:917/1670 train_time:89457ms step_avg:97.55ms
+step:918/1670 train_time:89555ms step_avg:97.55ms
+step:919/1670 train_time:89653ms step_avg:97.56ms
+step:920/1670 train_time:89751ms step_avg:97.56ms
+step:921/1670 train_time:89849ms step_avg:97.56ms
+step:922/1670 train_time:89946ms step_avg:97.56ms
+step:923/1670 train_time:90043ms step_avg:97.55ms
+step:924/1670 train_time:90141ms step_avg:97.55ms
+step:925/1670 train_time:90237ms step_avg:97.55ms
+step:926/1670 train_time:90334ms step_avg:97.55ms
+step:927/1670 train_time:90432ms step_avg:97.55ms
+step:928/1670 train_time:90529ms step_avg:97.55ms
+step:929/1670 train_time:90626ms step_avg:97.55ms
+step:930/1670 train_time:90724ms step_avg:97.55ms
+step:931/1670 train_time:90821ms step_avg:97.55ms
+step:932/1670 train_time:90918ms step_avg:97.55ms
+step:933/1670 train_time:91015ms step_avg:97.55ms
+step:934/1670 train_time:91112ms step_avg:97.55ms
+step:935/1670 train_time:91209ms step_avg:97.55ms
+step:936/1670 train_time:91306ms step_avg:97.55ms
+step:937/1670 train_time:91402ms step_avg:97.55ms
+step:938/1670 train_time:91499ms step_avg:97.55ms
+step:939/1670 train_time:91596ms step_avg:97.55ms
+step:940/1670 train_time:91694ms step_avg:97.55ms
+step:941/1670 train_time:91792ms step_avg:97.55ms
+step:942/1670 train_time:91889ms step_avg:97.55ms
+step:943/1670 train_time:91986ms step_avg:97.55ms
+step:944/1670 train_time:92082ms step_avg:97.54ms
+step:945/1670 train_time:92179ms step_avg:97.54ms
+step:946/1670 train_time:92277ms step_avg:97.54ms
+step:947/1670 train_time:92376ms step_avg:97.55ms
+step:948/1670 train_time:92472ms step_avg:97.54ms
+step:949/1670 train_time:92571ms step_avg:97.55ms
+step:950/1670 train_time:92668ms step_avg:97.55ms
+step:951/1670 train_time:92765ms step_avg:97.54ms
+step:952/1670 train_time:92862ms step_avg:97.54ms
+step:953/1670 train_time:92958ms step_avg:97.54ms
+step:954/1670 train_time:93056ms step_avg:97.54ms
+step:955/1670 train_time:93153ms step_avg:97.54ms
+step:956/1670 train_time:93250ms step_avg:97.54ms
+step:957/1670 train_time:93347ms step_avg:97.54ms
+step:958/1670 train_time:93444ms step_avg:97.54ms
+step:959/1670 train_time:93541ms step_avg:97.54ms
+step:960/1670 train_time:93638ms step_avg:97.54ms
+step:961/1670 train_time:93736ms step_avg:97.54ms
+step:962/1670 train_time:93834ms step_avg:97.54ms
+step:963/1670 train_time:93931ms step_avg:97.54ms
+step:964/1670 train_time:94028ms step_avg:97.54ms
+step:965/1670 train_time:94126ms step_avg:97.54ms
+step:966/1670 train_time:94222ms step_avg:97.54ms
+step:967/1670 train_time:94319ms step_avg:97.54ms
+step:968/1670 train_time:94418ms step_avg:97.54ms
+step:969/1670 train_time:94515ms step_avg:97.54ms
+step:970/1670 train_time:94612ms step_avg:97.54ms
+step:971/1670 train_time:94711ms step_avg:97.54ms
+step:972/1670 train_time:94808ms step_avg:97.54ms
+step:973/1670 train_time:94905ms step_avg:97.54ms
+step:974/1670 train_time:95002ms step_avg:97.54ms
+step:975/1670 train_time:95099ms step_avg:97.54ms
+step:976/1670 train_time:95197ms step_avg:97.54ms
+step:977/1670 train_time:95295ms step_avg:97.54ms
+step:978/1670 train_time:95392ms step_avg:97.54ms
+step:979/1670 train_time:95489ms step_avg:97.54ms
+step:980/1670 train_time:95586ms step_avg:97.54ms
+step:981/1670 train_time:95682ms step_avg:97.54ms
+step:982/1670 train_time:95779ms step_avg:97.53ms
+step:983/1670 train_time:95878ms step_avg:97.54ms
+step:984/1670 train_time:95975ms step_avg:97.54ms
+step:985/1670 train_time:96072ms step_avg:97.53ms
+step:986/1670 train_time:96170ms step_avg:97.54ms
+step:987/1670 train_time:96268ms step_avg:97.54ms
+step:988/1670 train_time:96365ms step_avg:97.54ms
+step:989/1670 train_time:96462ms step_avg:97.53ms
+step:990/1670 train_time:96558ms step_avg:97.53ms
+step:991/1670 train_time:96656ms step_avg:97.53ms
+step:992/1670 train_time:96754ms step_avg:97.53ms
+step:993/1670 train_time:96852ms step_avg:97.53ms
+step:994/1670 train_time:96950ms step_avg:97.53ms
+step:995/1670 train_time:97046ms step_avg:97.53ms
+step:996/1670 train_time:97142ms step_avg:97.53ms
+step:997/1670 train_time:97239ms step_avg:97.53ms
+step:998/1670 train_time:97337ms step_avg:97.53ms
+step:999/1670 train_time:97435ms step_avg:97.53ms
+step:1000/1670 train_time:97533ms step_avg:97.53ms
+step:1000/1670 val_loss:3.4801 train_time:97629ms step_avg:97.63ms
+step:1001/1670 train_time:97658ms step_avg:97.56ms
+step:1002/1670 train_time:97734ms step_avg:97.54ms
+step:1003/1670 train_time:97835ms step_avg:97.54ms
+step:1004/1670 train_time:97931ms step_avg:97.54ms
+step:1005/1670 train_time:98027ms step_avg:97.54ms
+step:1006/1670 train_time:98123ms step_avg:97.54ms
+step:1007/1670 train_time:98219ms step_avg:97.54ms
+step:1008/1670 train_time:98315ms step_avg:97.53ms
+step:1009/1670 train_time:98410ms step_avg:97.53ms
+step:1010/1670 train_time:98506ms step_avg:97.53ms
+step:1011/1670 train_time:98602ms step_avg:97.53ms
+step:1012/1670 train_time:98702ms step_avg:97.53ms
+step:1013/1670 train_time:98801ms step_avg:97.53ms
+step:1014/1670 train_time:98899ms step_avg:97.53ms
+step:1015/1670 train_time:98996ms step_avg:97.53ms
+step:1016/1670 train_time:99094ms step_avg:97.53ms
+step:1017/1670 train_time:99191ms step_avg:97.53ms
+step:1018/1670 train_time:99286ms step_avg:97.53ms
+step:1019/1670 train_time:99382ms step_avg:97.53ms
+step:1020/1670 train_time:99478ms step_avg:97.53ms
+step:1021/1670 train_time:99575ms step_avg:97.53ms
+step:1022/1670 train_time:99673ms step_avg:97.53ms
+step:1023/1670 train_time:99771ms step_avg:97.53ms
+step:1024/1670 train_time:99868ms step_avg:97.53ms
+step:1025/1670 train_time:99965ms step_avg:97.53ms
+step:1026/1670 train_time:100063ms step_avg:97.53ms
+step:1027/1670 train_time:100160ms step_avg:97.53ms
+step:1028/1670 train_time:100258ms step_avg:97.53ms
+step:1029/1670 train_time:100354ms step_avg:97.53ms
+step:1030/1670 train_time:100451ms step_avg:97.52ms
+step:1031/1670 train_time:100548ms step_avg:97.52ms
+step:1032/1670 train_time:100644ms step_avg:97.52ms
+step:1033/1670 train_time:100740ms step_avg:97.52ms
+step:1034/1670 train_time:100838ms step_avg:97.52ms
+step:1035/1670 train_time:100935ms step_avg:97.52ms
+step:1036/1670 train_time:101033ms step_avg:97.52ms
+step:1037/1670 train_time:101131ms step_avg:97.52ms
+step:1038/1670 train_time:101228ms step_avg:97.52ms
+step:1039/1670 train_time:101324ms step_avg:97.52ms
+step:1040/1670 train_time:101420ms step_avg:97.52ms
+step:1041/1670 train_time:101517ms step_avg:97.52ms
+step:1042/1670 train_time:101615ms step_avg:97.52ms
+step:1043/1670 train_time:101712ms step_avg:97.52ms
+step:1044/1670 train_time:101809ms step_avg:97.52ms
+step:1045/1670 train_time:101905ms step_avg:97.52ms
+step:1046/1670 train_time:102003ms step_avg:97.52ms
+step:1047/1670 train_time:102100ms step_avg:97.52ms
+step:1048/1670 train_time:102198ms step_avg:97.52ms
+step:1049/1670 train_time:102296ms step_avg:97.52ms
+step:1050/1670 train_time:102393ms step_avg:97.52ms
+step:1051/1670 train_time:102490ms step_avg:97.52ms
+step:1052/1670 train_time:102587ms step_avg:97.52ms
+step:1053/1670 train_time:102684ms step_avg:97.52ms
+step:1054/1670 train_time:102781ms step_avg:97.52ms
+step:1055/1670 train_time:102880ms step_avg:97.52ms
+step:1056/1670 train_time:102977ms step_avg:97.52ms
+step:1057/1670 train_time:103075ms step_avg:97.52ms
+step:1058/1670 train_time:103172ms step_avg:97.52ms
+step:1059/1670 train_time:103269ms step_avg:97.52ms
+step:1060/1670 train_time:103365ms step_avg:97.51ms
+step:1061/1670 train_time:103462ms step_avg:97.51ms
+step:1062/1670 train_time:103731ms step_avg:97.67ms
+step:1063/1670 train_time:103969ms step_avg:97.81ms
+step:1064/1670 train_time:104063ms step_avg:97.80ms
+step:1065/1670 train_time:104159ms step_avg:97.80ms
+step:1066/1670 train_time:104255ms step_avg:97.80ms
+step:1067/1670 train_time:104350ms step_avg:97.80ms
+step:1068/1670 train_time:104445ms step_avg:97.79ms
+step:1069/1670 train_time:104541ms step_avg:97.79ms
+step:1070/1670 train_time:104637ms step_avg:97.79ms
+step:1071/1670 train_time:104733ms step_avg:97.79ms
+step:1072/1670 train_time:104835ms step_avg:97.79ms
+step:1073/1670 train_time:104935ms step_avg:97.80ms
+step:1074/1670 train_time:105036ms step_avg:97.80ms
+step:1075/1670 train_time:105134ms step_avg:97.80ms
+step:1076/1670 train_time:105230ms step_avg:97.80ms
+step:1077/1670 train_time:105327ms step_avg:97.80ms
+step:1078/1670 train_time:105422ms step_avg:97.79ms
+step:1079/1670 train_time:105518ms step_avg:97.79ms
+step:1080/1670 train_time:105614ms step_avg:97.79ms
+step:1081/1670 train_time:105710ms step_avg:97.79ms
+step:1082/1670 train_time:105809ms step_avg:97.79ms
+step:1083/1670 train_time:105907ms step_avg:97.79ms
+step:1084/1670 train_time:106005ms step_avg:97.79ms
+step:1085/1670 train_time:106104ms step_avg:97.79ms
+step:1086/1670 train_time:106202ms step_avg:97.79ms
+step:1087/1670 train_time:106300ms step_avg:97.79ms
+step:1088/1670 train_time:106396ms step_avg:97.79ms
+step:1089/1670 train_time:106493ms step_avg:97.79ms
+step:1090/1670 train_time:106589ms step_avg:97.79ms
+step:1091/1670 train_time:106684ms step_avg:97.79ms
+step:1092/1670 train_time:106781ms step_avg:97.78ms
+step:1093/1670 train_time:106880ms step_avg:97.79ms
+step:1094/1670 train_time:106978ms step_avg:97.79ms
+step:1095/1670 train_time:107076ms step_avg:97.79ms
+step:1096/1670 train_time:107175ms step_avg:97.79ms
+step:1097/1670 train_time:107272ms step_avg:97.79ms
+step:1098/1670 train_time:107369ms step_avg:97.79ms
+step:1099/1670 train_time:107466ms step_avg:97.79ms
+step:1100/1670 train_time:107562ms step_avg:97.78ms
+step:1101/1670 train_time:107659ms step_avg:97.78ms
+step:1102/1670 train_time:107755ms step_avg:97.78ms
+step:1103/1670 train_time:107853ms step_avg:97.78ms
+step:1104/1670 train_time:107949ms step_avg:97.78ms
+step:1105/1670 train_time:108046ms step_avg:97.78ms
+step:1106/1670 train_time:108144ms step_avg:97.78ms
+step:1107/1670 train_time:108243ms step_avg:97.78ms
+step:1108/1670 train_time:108341ms step_avg:97.78ms
+step:1109/1670 train_time:108439ms step_avg:97.78ms
+step:1110/1670 train_time:108537ms step_avg:97.78ms
+step:1111/1670 train_time:108635ms step_avg:97.78ms
+step:1112/1670 train_time:108731ms step_avg:97.78ms
+step:1113/1670 train_time:108828ms step_avg:97.78ms
+step:1114/1670 train_time:108925ms step_avg:97.78ms
+step:1115/1670 train_time:109021ms step_avg:97.78ms
+step:1116/1670 train_time:109119ms step_avg:97.78ms
+step:1117/1670 train_time:109217ms step_avg:97.78ms
+step:1118/1670 train_time:109315ms step_avg:97.78ms
+step:1119/1670 train_time:109413ms step_avg:97.78ms
+step:1120/1670 train_time:109510ms step_avg:97.78ms
+step:1121/1670 train_time:109608ms step_avg:97.78ms
+step:1122/1670 train_time:109706ms step_avg:97.78ms
+step:1123/1670 train_time:109804ms step_avg:97.78ms
+step:1124/1670 train_time:109902ms step_avg:97.78ms
+step:1125/1670 train_time:110000ms step_avg:97.78ms
+step:1125/1670 val_loss:3.4271 train_time:110097ms step_avg:97.86ms
+step:1126/1670 train_time:110124ms step_avg:97.80ms
+step:1127/1670 train_time:110208ms step_avg:97.79ms
+step:1128/1670 train_time:110305ms step_avg:97.79ms
+step:1129/1670 train_time:110402ms step_avg:97.79ms
+step:1130/1670 train_time:110499ms step_avg:97.79ms
+step:1131/1670 train_time:110595ms step_avg:97.79ms
+step:1132/1670 train_time:110693ms step_avg:97.78ms
+step:1133/1670 train_time:110788ms step_avg:97.78ms
+step:1134/1670 train_time:110884ms step_avg:97.78ms
+step:1135/1670 train_time:110981ms step_avg:97.78ms
+step:1136/1670 train_time:111083ms step_avg:97.78ms
+step:1137/1670 train_time:111185ms step_avg:97.79ms
+step:1138/1670 train_time:111284ms step_avg:97.79ms
+step:1139/1670 train_time:111383ms step_avg:97.79ms
+step:1140/1670 train_time:111480ms step_avg:97.79ms
+step:1141/1670 train_time:111577ms step_avg:97.79ms
+step:1142/1670 train_time:111674ms step_avg:97.79ms
+step:1143/1670 train_time:111771ms step_avg:97.79ms
+step:1144/1670 train_time:111868ms step_avg:97.79ms
+step:1145/1670 train_time:111965ms step_avg:97.79ms
+step:1146/1670 train_time:112065ms step_avg:97.79ms
+step:1147/1670 train_time:112165ms step_avg:97.79ms
+step:1148/1670 train_time:112266ms step_avg:97.79ms
+step:1149/1670 train_time:112364ms step_avg:97.79ms
+step:1150/1670 train_time:112461ms step_avg:97.79ms
+step:1151/1670 train_time:112557ms step_avg:97.79ms
+step:1152/1670 train_time:112655ms step_avg:97.79ms
+step:1153/1670 train_time:112752ms step_avg:97.79ms
+step:1154/1670 train_time:112849ms step_avg:97.79ms
+step:1155/1670 train_time:112946ms step_avg:97.79ms
+step:1156/1670 train_time:113043ms step_avg:97.79ms
+step:1157/1670 train_time:113142ms step_avg:97.79ms
+step:1158/1670 train_time:113240ms step_avg:97.79ms
+step:1159/1670 train_time:113339ms step_avg:97.79ms
+step:1160/1670 train_time:113437ms step_avg:97.79ms
+step:1161/1670 train_time:113536ms step_avg:97.79ms
+step:1162/1670 train_time:113633ms step_avg:97.79ms
+step:1163/1670 train_time:113729ms step_avg:97.79ms
+step:1164/1670 train_time:113826ms step_avg:97.79ms
+step:1165/1670 train_time:113924ms step_avg:97.79ms
+step:1166/1670 train_time:114022ms step_avg:97.79ms
+step:1167/1670 train_time:114120ms step_avg:97.79ms
+step:1168/1670 train_time:114218ms step_avg:97.79ms
+step:1169/1670 train_time:114316ms step_avg:97.79ms
+step:1170/1670 train_time:114413ms step_avg:97.79ms
+step:1171/1670 train_time:114510ms step_avg:97.79ms
+step:1172/1670 train_time:114608ms step_avg:97.79ms
+step:1173/1670 train_time:114706ms step_avg:97.79ms
+step:1174/1670 train_time:114803ms step_avg:97.79ms
+step:1175/1670 train_time:114902ms step_avg:97.79ms
+step:1176/1670 train_time:115001ms step_avg:97.79ms
+step:1177/1670 train_time:115097ms step_avg:97.79ms
+step:1178/1670 train_time:115195ms step_avg:97.79ms
+step:1179/1670 train_time:115293ms step_avg:97.79ms
+step:1180/1670 train_time:115391ms step_avg:97.79ms
+step:1181/1670 train_time:115488ms step_avg:97.79ms
+step:1182/1670 train_time:115586ms step_avg:97.79ms
+step:1183/1670 train_time:115683ms step_avg:97.79ms
+step:1184/1670 train_time:115780ms step_avg:97.79ms
+step:1185/1670 train_time:115879ms step_avg:97.79ms
+step:1186/1670 train_time:115976ms step_avg:97.79ms
+step:1187/1670 train_time:116074ms step_avg:97.79ms
+step:1188/1670 train_time:116170ms step_avg:97.79ms
+step:1189/1670 train_time:116269ms step_avg:97.79ms
+step:1190/1670 train_time:116367ms step_avg:97.79ms
+step:1191/1670 train_time:116465ms step_avg:97.79ms
+step:1192/1670 train_time:116565ms step_avg:97.79ms
+step:1193/1670 train_time:116663ms step_avg:97.79ms
+step:1194/1670 train_time:116759ms step_avg:97.79ms
+step:1195/1670 train_time:116857ms step_avg:97.79ms
+step:1196/1670 train_time:116955ms step_avg:97.79ms
+step:1197/1670 train_time:117052ms step_avg:97.79ms
+step:1198/1670 train_time:117150ms step_avg:97.79ms
+step:1199/1670 train_time:117247ms step_avg:97.79ms
+step:1200/1670 train_time:117345ms step_avg:97.79ms
+step:1201/1670 train_time:117444ms step_avg:97.79ms
+step:1202/1670 train_time:117543ms step_avg:97.79ms
+step:1203/1670 train_time:117643ms step_avg:97.79ms
+step:1204/1670 train_time:117740ms step_avg:97.79ms
+step:1205/1670 train_time:117838ms step_avg:97.79ms
+step:1206/1670 train_time:117935ms step_avg:97.79ms
+step:1207/1670 train_time:118033ms step_avg:97.79ms
+step:1208/1670 train_time:118131ms step_avg:97.79ms
+step:1209/1670 train_time:118228ms step_avg:97.79ms
+step:1210/1670 train_time:118326ms step_avg:97.79ms
+step:1211/1670 train_time:118424ms step_avg:97.79ms
+step:1212/1670 train_time:118522ms step_avg:97.79ms
+step:1213/1670 train_time:118620ms step_avg:97.79ms
+step:1214/1670 train_time:118718ms step_avg:97.79ms
+step:1215/1670 train_time:118815ms step_avg:97.79ms
+step:1216/1670 train_time:118913ms step_avg:97.79ms
+step:1217/1670 train_time:119010ms step_avg:97.79ms
+step:1218/1670 train_time:119107ms step_avg:97.79ms
+step:1219/1670 train_time:119204ms step_avg:97.79ms
+step:1220/1670 train_time:119302ms step_avg:97.79ms
+step:1221/1670 train_time:119401ms step_avg:97.79ms
+step:1222/1670 train_time:119499ms step_avg:97.79ms
+step:1223/1670 train_time:119597ms step_avg:97.79ms
+step:1224/1670 train_time:119695ms step_avg:97.79ms
+step:1225/1670 train_time:119791ms step_avg:97.79ms
+step:1226/1670 train_time:119889ms step_avg:97.79ms
+step:1227/1670 train_time:119986ms step_avg:97.79ms
+step:1228/1670 train_time:120085ms step_avg:97.79ms
+step:1229/1670 train_time:120182ms step_avg:97.79ms
+step:1230/1670 train_time:120281ms step_avg:97.79ms
+step:1231/1670 train_time:120380ms step_avg:97.79ms
+step:1232/1670 train_time:120478ms step_avg:97.79ms
+step:1233/1670 train_time:120575ms step_avg:97.79ms
+step:1234/1670 train_time:120673ms step_avg:97.79ms
+step:1235/1670 train_time:120771ms step_avg:97.79ms
+step:1236/1670 train_time:120868ms step_avg:97.79ms
+step:1237/1670 train_time:120965ms step_avg:97.79ms
+step:1238/1670 train_time:121063ms step_avg:97.79ms
+step:1239/1670 train_time:121161ms step_avg:97.79ms
+step:1240/1670 train_time:121260ms step_avg:97.79ms
+step:1241/1670 train_time:121358ms step_avg:97.79ms
+step:1242/1670 train_time:121457ms step_avg:97.79ms
+step:1243/1670 train_time:121555ms step_avg:97.79ms
+step:1244/1670 train_time:121653ms step_avg:97.79ms
+step:1245/1670 train_time:121750ms step_avg:97.79ms
+step:1246/1670 train_time:121847ms step_avg:97.79ms
+step:1247/1670 train_time:121945ms step_avg:97.79ms
+step:1248/1670 train_time:122042ms step_avg:97.79ms
+step:1249/1670 train_time:122141ms step_avg:97.79ms
+step:1250/1670 train_time:122238ms step_avg:97.79ms
+step:1250/1670 val_loss:3.3828 train_time:122335ms step_avg:97.87ms
+step:1251/1670 train_time:122358ms step_avg:97.81ms
+step:1252/1670 train_time:122439ms step_avg:97.79ms
+step:1253/1670 train_time:122539ms step_avg:97.80ms
+step:1254/1670 train_time:122638ms step_avg:97.80ms
+step:1255/1670 train_time:122736ms step_avg:97.80ms
+step:1256/1670 train_time:122833ms step_avg:97.80ms
+step:1257/1670 train_time:122930ms step_avg:97.80ms
+step:1258/1670 train_time:123027ms step_avg:97.80ms
+step:1259/1670 train_time:123123ms step_avg:97.79ms
+step:1260/1670 train_time:123220ms step_avg:97.79ms
+step:1261/1670 train_time:123319ms step_avg:97.79ms
+step:1262/1670 train_time:123421ms step_avg:97.80ms
+step:1263/1670 train_time:123520ms step_avg:97.80ms
+step:1264/1670 train_time:123617ms step_avg:97.80ms
+step:1265/1670 train_time:123715ms step_avg:97.80ms
+step:1266/1670 train_time:123812ms step_avg:97.80ms
+step:1267/1670 train_time:123910ms step_avg:97.80ms
+step:1268/1670 train_time:124007ms step_avg:97.80ms
+step:1269/1670 train_time:124104ms step_avg:97.80ms
+step:1270/1670 train_time:124201ms step_avg:97.80ms
+step:1271/1670 train_time:124299ms step_avg:97.80ms
+step:1272/1670 train_time:124399ms step_avg:97.80ms
+step:1273/1670 train_time:124499ms step_avg:97.80ms
+step:1274/1670 train_time:124768ms step_avg:97.93ms
+step:1275/1670 train_time:124952ms step_avg:98.00ms
+step:1276/1670 train_time:125047ms step_avg:98.00ms
+step:1277/1670 train_time:125143ms step_avg:98.00ms
+step:1278/1670 train_time:125239ms step_avg:98.00ms
+step:1279/1670 train_time:125336ms step_avg:98.00ms
+step:1280/1670 train_time:125433ms step_avg:97.99ms
+step:1281/1670 train_time:125531ms step_avg:97.99ms
+step:1282/1670 train_time:125628ms step_avg:97.99ms
+step:1283/1670 train_time:125724ms step_avg:97.99ms
+step:1284/1670 train_time:125823ms step_avg:97.99ms
+step:1285/1670 train_time:125925ms step_avg:98.00ms
+step:1286/1670 train_time:126024ms step_avg:98.00ms
+step:1287/1670 train_time:126123ms step_avg:98.00ms
+step:1288/1670 train_time:126220ms step_avg:98.00ms
+step:1289/1670 train_time:126318ms step_avg:98.00ms
+step:1290/1670 train_time:126414ms step_avg:98.00ms
+step:1291/1670 train_time:126511ms step_avg:97.99ms
+step:1292/1670 train_time:126608ms step_avg:97.99ms
+step:1293/1670 train_time:126705ms step_avg:97.99ms
+step:1294/1670 train_time:126803ms step_avg:97.99ms
+step:1295/1670 train_time:126902ms step_avg:97.99ms
+step:1296/1670 train_time:127001ms step_avg:97.99ms
+step:1297/1670 train_time:127099ms step_avg:97.99ms
+step:1298/1670 train_time:127197ms step_avg:97.99ms
+step:1299/1670 train_time:127295ms step_avg:97.99ms
+step:1300/1670 train_time:127392ms step_avg:97.99ms
+step:1301/1670 train_time:127489ms step_avg:97.99ms
+step:1302/1670 train_time:127586ms step_avg:97.99ms
+step:1303/1670 train_time:127682ms step_avg:97.99ms
+step:1304/1670 train_time:127780ms step_avg:97.99ms
+step:1305/1670 train_time:127879ms step_avg:97.99ms
+step:1306/1670 train_time:127980ms step_avg:97.99ms
+step:1307/1670 train_time:128079ms step_avg:97.99ms
+step:1308/1670 train_time:128177ms step_avg:97.99ms
+step:1309/1670 train_time:128275ms step_avg:97.99ms
+step:1310/1670 train_time:128372ms step_avg:97.99ms
+step:1311/1670 train_time:128469ms step_avg:97.99ms
+step:1312/1670 train_time:128567ms step_avg:97.99ms
+step:1313/1670 train_time:128664ms step_avg:97.99ms
+step:1314/1670 train_time:128760ms step_avg:97.99ms
+step:1315/1670 train_time:128858ms step_avg:97.99ms
+step:1316/1670 train_time:128957ms step_avg:97.99ms
+step:1317/1670 train_time:129058ms step_avg:97.99ms
+step:1318/1670 train_time:129157ms step_avg:97.99ms
+step:1319/1670 train_time:129256ms step_avg:98.00ms
+step:1320/1670 train_time:129354ms step_avg:98.00ms
+step:1321/1670 train_time:129451ms step_avg:97.99ms
+step:1322/1670 train_time:129549ms step_avg:97.99ms
+step:1323/1670 train_time:129647ms step_avg:97.99ms
+step:1324/1670 train_time:129744ms step_avg:97.99ms
+step:1325/1670 train_time:129841ms step_avg:97.99ms
+step:1326/1670 train_time:129938ms step_avg:97.99ms
+step:1327/1670 train_time:130037ms step_avg:97.99ms
+step:1328/1670 train_time:130136ms step_avg:97.99ms
+step:1329/1670 train_time:130234ms step_avg:97.99ms
+step:1330/1670 train_time:130332ms step_avg:97.99ms
+step:1331/1670 train_time:130430ms step_avg:97.99ms
+step:1332/1670 train_time:130527ms step_avg:97.99ms
+step:1333/1670 train_time:130623ms step_avg:97.99ms
+step:1334/1670 train_time:130720ms step_avg:97.99ms
+step:1335/1670 train_time:130817ms step_avg:97.99ms
+step:1336/1670 train_time:130915ms step_avg:97.99ms
+step:1337/1670 train_time:131014ms step_avg:97.99ms
+step:1338/1670 train_time:131114ms step_avg:97.99ms
+step:1339/1670 train_time:131212ms step_avg:97.99ms
+step:1340/1670 train_time:131310ms step_avg:97.99ms
+step:1341/1670 train_time:131409ms step_avg:97.99ms
+step:1342/1670 train_time:131507ms step_avg:97.99ms
+step:1343/1670 train_time:131604ms step_avg:97.99ms
+step:1344/1670 train_time:131701ms step_avg:97.99ms
+step:1345/1670 train_time:131798ms step_avg:97.99ms
+step:1346/1670 train_time:131897ms step_avg:97.99ms
+step:1347/1670 train_time:131994ms step_avg:97.99ms
+step:1348/1670 train_time:132093ms step_avg:97.99ms
+step:1349/1670 train_time:132190ms step_avg:97.99ms
+step:1350/1670 train_time:132288ms step_avg:97.99ms
+step:1351/1670 train_time:132386ms step_avg:97.99ms
+step:1352/1670 train_time:132483ms step_avg:97.99ms
+step:1353/1670 train_time:132581ms step_avg:97.99ms
+step:1354/1670 train_time:132678ms step_avg:97.99ms
+step:1355/1670 train_time:132776ms step_avg:97.99ms
+step:1356/1670 train_time:132875ms step_avg:97.99ms
+step:1357/1670 train_time:132973ms step_avg:97.99ms
+step:1358/1670 train_time:133071ms step_avg:97.99ms
+step:1359/1670 train_time:133169ms step_avg:97.99ms
+step:1360/1670 train_time:133267ms step_avg:97.99ms
+step:1361/1670 train_time:133365ms step_avg:97.99ms
+step:1362/1670 train_time:133462ms step_avg:97.99ms
+step:1363/1670 train_time:133559ms step_avg:97.99ms
+step:1364/1670 train_time:133657ms step_avg:97.99ms
+step:1365/1670 train_time:133754ms step_avg:97.99ms
+step:1366/1670 train_time:133852ms step_avg:97.99ms
+step:1367/1670 train_time:133949ms step_avg:97.99ms
+step:1368/1670 train_time:134047ms step_avg:97.99ms
+step:1369/1670 train_time:134144ms step_avg:97.99ms
+step:1370/1670 train_time:134241ms step_avg:97.99ms
+step:1371/1670 train_time:134339ms step_avg:97.99ms
+step:1372/1670 train_time:134437ms step_avg:97.99ms
+step:1373/1670 train_time:134536ms step_avg:97.99ms
+step:1374/1670 train_time:134634ms step_avg:97.99ms
+step:1375/1670 train_time:134733ms step_avg:97.99ms
+step:1375/1670 val_loss:3.3445 train_time:134829ms step_avg:98.06ms
+step:1376/1670 train_time:134853ms step_avg:98.00ms +step:1377/1670 train_time:134934ms step_avg:97.99ms +step:1378/1670 train_time:135033ms step_avg:97.99ms +step:1379/1670 train_time:135130ms step_avg:97.99ms +step:1380/1670 train_time:135227ms step_avg:97.99ms +step:1381/1670 train_time:135324ms step_avg:97.99ms +step:1382/1670 train_time:135421ms step_avg:97.99ms +step:1383/1670 train_time:135518ms step_avg:97.99ms +step:1384/1670 train_time:135616ms step_avg:97.99ms +step:1385/1670 train_time:135714ms step_avg:97.99ms +step:1386/1670 train_time:135813ms step_avg:97.99ms +step:1387/1670 train_time:135913ms step_avg:97.99ms +step:1388/1670 train_time:136012ms step_avg:97.99ms +step:1389/1670 train_time:136110ms step_avg:97.99ms +step:1390/1670 train_time:136208ms step_avg:97.99ms +step:1391/1670 train_time:136305ms step_avg:97.99ms +step:1392/1670 train_time:136403ms step_avg:97.99ms +step:1393/1670 train_time:136500ms step_avg:97.99ms +step:1394/1670 train_time:136597ms step_avg:97.99ms +step:1395/1670 train_time:136695ms step_avg:97.99ms +step:1396/1670 train_time:136794ms step_avg:97.99ms +step:1397/1670 train_time:136893ms step_avg:97.99ms +step:1398/1670 train_time:136990ms step_avg:97.99ms +step:1399/1670 train_time:137089ms step_avg:97.99ms +step:1400/1670 train_time:137186ms step_avg:97.99ms +step:1401/1670 train_time:137284ms step_avg:97.99ms +step:1402/1670 train_time:137382ms step_avg:97.99ms +step:1403/1670 train_time:137479ms step_avg:97.99ms +step:1404/1670 train_time:137576ms step_avg:97.99ms +step:1405/1670 train_time:137674ms step_avg:97.99ms +step:1406/1670 train_time:137772ms step_avg:97.99ms +step:1407/1670 train_time:137870ms step_avg:97.99ms +step:1408/1670 train_time:137968ms step_avg:97.99ms +step:1409/1670 train_time:138066ms step_avg:97.99ms +step:1410/1670 train_time:138164ms step_avg:97.99ms +step:1411/1670 train_time:138263ms step_avg:97.99ms +step:1412/1670 train_time:138360ms step_avg:97.99ms +step:1413/1670 train_time:138457ms step_avg:97.99ms +step:1414/1670 train_time:138555ms step_avg:97.99ms +step:1415/1670 train_time:138652ms step_avg:97.99ms +step:1416/1670 train_time:138749ms step_avg:97.99ms +step:1417/1670 train_time:138847ms step_avg:97.99ms +step:1418/1670 train_time:138945ms step_avg:97.99ms +step:1419/1670 train_time:139043ms step_avg:97.99ms +step:1420/1670 train_time:139143ms step_avg:97.99ms +step:1421/1670 train_time:139241ms step_avg:97.99ms +step:1422/1670 train_time:139340ms step_avg:97.99ms +step:1423/1670 train_time:139438ms step_avg:97.99ms +step:1424/1670 train_time:139535ms step_avg:97.99ms +step:1425/1670 train_time:139632ms step_avg:97.99ms +step:1426/1670 train_time:139730ms step_avg:97.99ms +step:1427/1670 train_time:139828ms step_avg:97.99ms +step:1428/1670 train_time:139925ms step_avg:97.99ms +step:1429/1670 train_time:140024ms step_avg:97.99ms +step:1430/1670 train_time:140122ms step_avg:97.99ms +step:1431/1670 train_time:140221ms step_avg:97.99ms +step:1432/1670 train_time:140318ms step_avg:97.99ms +step:1433/1670 train_time:140416ms step_avg:97.99ms +step:1434/1670 train_time:140513ms step_avg:97.99ms +step:1435/1670 train_time:140610ms step_avg:97.99ms +step:1436/1670 train_time:140708ms step_avg:97.99ms +step:1437/1670 train_time:140806ms step_avg:97.99ms +step:1438/1670 train_time:140903ms step_avg:97.99ms +step:1439/1670 train_time:141002ms step_avg:97.99ms +step:1440/1670 train_time:141099ms step_avg:97.99ms +step:1441/1670 train_time:141198ms step_avg:97.99ms +step:1442/1670 train_time:141296ms step_avg:97.99ms 
+step:1443/1670 train_time:141393ms step_avg:97.99ms +step:1444/1670 train_time:141490ms step_avg:97.98ms +step:1445/1670 train_time:141588ms step_avg:97.98ms +step:1446/1670 train_time:141685ms step_avg:97.98ms +step:1447/1670 train_time:141784ms step_avg:97.99ms +step:1448/1670 train_time:141883ms step_avg:97.99ms +step:1449/1670 train_time:141981ms step_avg:97.99ms +step:1450/1670 train_time:142080ms step_avg:97.99ms +step:1451/1670 train_time:142178ms step_avg:97.99ms +step:1452/1670 train_time:142276ms step_avg:97.99ms +step:1453/1670 train_time:142373ms step_avg:97.99ms +step:1454/1670 train_time:142471ms step_avg:97.99ms +step:1455/1670 train_time:142568ms step_avg:97.99ms +step:1456/1670 train_time:142666ms step_avg:97.99ms +step:1457/1670 train_time:142765ms step_avg:97.99ms +step:1458/1670 train_time:142862ms step_avg:97.98ms +step:1459/1670 train_time:142960ms step_avg:97.98ms +step:1460/1670 train_time:143058ms step_avg:97.98ms +step:1461/1670 train_time:143156ms step_avg:97.98ms +step:1462/1670 train_time:143253ms step_avg:97.98ms +step:1463/1670 train_time:143350ms step_avg:97.98ms +step:1464/1670 train_time:143448ms step_avg:97.98ms +step:1465/1670 train_time:143546ms step_avg:97.98ms +step:1466/1670 train_time:143644ms step_avg:97.98ms +step:1467/1670 train_time:143743ms step_avg:97.98ms +step:1468/1670 train_time:143841ms step_avg:97.98ms +step:1469/1670 train_time:143939ms step_avg:97.98ms +step:1470/1670 train_time:144036ms step_avg:97.98ms +step:1471/1670 train_time:144133ms step_avg:97.98ms +step:1472/1670 train_time:144231ms step_avg:97.98ms +step:1473/1670 train_time:144328ms step_avg:97.98ms +step:1474/1670 train_time:144426ms step_avg:97.98ms +step:1475/1670 train_time:144525ms step_avg:97.98ms +step:1476/1670 train_time:144623ms step_avg:97.98ms +step:1477/1670 train_time:144722ms step_avg:97.98ms +step:1478/1670 train_time:144821ms step_avg:97.98ms +step:1479/1670 train_time:144918ms step_avg:97.98ms +step:1480/1670 train_time:145017ms step_avg:97.98ms +step:1481/1670 train_time:145114ms step_avg:97.98ms +step:1482/1670 train_time:145211ms step_avg:97.98ms +step:1483/1670 train_time:145309ms step_avg:97.98ms +step:1484/1670 train_time:145406ms step_avg:97.98ms +step:1485/1670 train_time:145675ms step_avg:98.10ms +step:1486/1670 train_time:145879ms step_avg:98.17ms +step:1487/1670 train_time:145976ms step_avg:98.17ms +step:1488/1670 train_time:146072ms step_avg:98.17ms +step:1489/1670 train_time:146168ms step_avg:98.17ms +step:1490/1670 train_time:146265ms step_avg:98.16ms +step:1491/1670 train_time:146361ms step_avg:98.16ms +step:1492/1670 train_time:146457ms step_avg:98.16ms +step:1493/1670 train_time:146554ms step_avg:98.16ms +step:1494/1670 train_time:146650ms step_avg:98.16ms +step:1495/1670 train_time:146750ms step_avg:98.16ms +step:1496/1670 train_time:146852ms step_avg:98.16ms +step:1497/1670 train_time:146951ms step_avg:98.16ms +step:1498/1670 train_time:147049ms step_avg:98.16ms +step:1499/1670 train_time:147146ms step_avg:98.16ms +step:1500/1670 train_time:147243ms step_avg:98.16ms +step:1500/1670 val_loss:3.3122 train_time:147340ms step_avg:98.23ms +step:1501/1670 train_time:147363ms step_avg:98.18ms +step:1502/1670 train_time:147444ms step_avg:98.17ms +step:1503/1670 train_time:147548ms step_avg:98.17ms +step:1504/1670 train_time:147646ms step_avg:98.17ms +step:1505/1670 train_time:147744ms step_avg:98.17ms +step:1506/1670 train_time:147841ms step_avg:98.17ms +step:1507/1670 train_time:147937ms step_avg:98.17ms +step:1508/1670 train_time:148034ms 
step_avg:98.17ms +step:1509/1670 train_time:148131ms step_avg:98.16ms +step:1510/1670 train_time:148228ms step_avg:98.16ms +step:1511/1670 train_time:148327ms step_avg:98.16ms +step:1512/1670 train_time:148430ms step_avg:98.17ms +step:1513/1670 train_time:148530ms step_avg:98.17ms +step:1514/1670 train_time:148630ms step_avg:98.17ms +step:1515/1670 train_time:148727ms step_avg:98.17ms +step:1516/1670 train_time:148825ms step_avg:98.17ms +step:1517/1670 train_time:148922ms step_avg:98.17ms +step:1518/1670 train_time:149019ms step_avg:98.17ms +step:1519/1670 train_time:149115ms step_avg:98.17ms +step:1520/1670 train_time:149211ms step_avg:98.17ms +step:1521/1670 train_time:149310ms step_avg:98.17ms +step:1522/1670 train_time:149411ms step_avg:98.17ms +step:1523/1670 train_time:149511ms step_avg:98.17ms +step:1524/1670 train_time:149610ms step_avg:98.17ms +step:1525/1670 train_time:149709ms step_avg:98.17ms +step:1526/1670 train_time:149808ms step_avg:98.17ms +step:1527/1670 train_time:149905ms step_avg:98.17ms +step:1528/1670 train_time:150003ms step_avg:98.17ms +step:1529/1670 train_time:150099ms step_avg:98.17ms +step:1530/1670 train_time:150196ms step_avg:98.17ms +step:1531/1670 train_time:150293ms step_avg:98.17ms +step:1532/1670 train_time:150392ms step_avg:98.17ms +step:1533/1670 train_time:150491ms step_avg:98.17ms +step:1534/1670 train_time:150590ms step_avg:98.17ms +step:1535/1670 train_time:150690ms step_avg:98.17ms +step:1536/1670 train_time:150788ms step_avg:98.17ms +step:1537/1670 train_time:150887ms step_avg:98.17ms +step:1538/1670 train_time:150985ms step_avg:98.17ms +step:1539/1670 train_time:151083ms step_avg:98.17ms +step:1540/1670 train_time:151180ms step_avg:98.17ms +step:1541/1670 train_time:151278ms step_avg:98.17ms +step:1542/1670 train_time:151376ms step_avg:98.17ms +step:1543/1670 train_time:151474ms step_avg:98.17ms +step:1544/1670 train_time:151570ms step_avg:98.17ms +step:1545/1670 train_time:151669ms step_avg:98.17ms +step:1546/1670 train_time:151767ms step_avg:98.17ms +step:1547/1670 train_time:151864ms step_avg:98.17ms +step:1548/1670 train_time:151963ms step_avg:98.17ms +step:1549/1670 train_time:152061ms step_avg:98.17ms +step:1550/1670 train_time:152159ms step_avg:98.17ms +step:1551/1670 train_time:152256ms step_avg:98.17ms +step:1552/1670 train_time:152354ms step_avg:98.17ms +step:1553/1670 train_time:152452ms step_avg:98.17ms +step:1554/1670 train_time:152550ms step_avg:98.17ms +step:1555/1670 train_time:152649ms step_avg:98.17ms +step:1556/1670 train_time:152747ms step_avg:98.17ms +step:1557/1670 train_time:152846ms step_avg:98.17ms +step:1558/1670 train_time:152943ms step_avg:98.17ms +step:1559/1670 train_time:153040ms step_avg:98.17ms +step:1560/1670 train_time:153138ms step_avg:98.17ms +step:1561/1670 train_time:153235ms step_avg:98.16ms +step:1562/1670 train_time:153332ms step_avg:98.16ms +step:1563/1670 train_time:153431ms step_avg:98.16ms +step:1564/1670 train_time:153529ms step_avg:98.16ms +step:1565/1670 train_time:153628ms step_avg:98.16ms +step:1566/1670 train_time:153726ms step_avg:98.16ms +step:1567/1670 train_time:153824ms step_avg:98.16ms +step:1568/1670 train_time:153921ms step_avg:98.16ms +step:1569/1670 train_time:154020ms step_avg:98.16ms +step:1570/1670 train_time:154116ms step_avg:98.16ms +step:1571/1670 train_time:154212ms step_avg:98.16ms +step:1572/1670 train_time:154310ms step_avg:98.16ms +step:1573/1670 train_time:154408ms step_avg:98.16ms +step:1574/1670 train_time:154507ms step_avg:98.16ms +step:1575/1670 train_time:154606ms 
step_avg:98.16ms +step:1576/1670 train_time:154704ms step_avg:98.16ms +step:1577/1670 train_time:154801ms step_avg:98.16ms +step:1578/1670 train_time:154898ms step_avg:98.16ms +step:1579/1670 train_time:154996ms step_avg:98.16ms +step:1580/1670 train_time:155094ms step_avg:98.16ms +step:1581/1670 train_time:155193ms step_avg:98.16ms +step:1582/1670 train_time:155290ms step_avg:98.16ms +step:1583/1670 train_time:155388ms step_avg:98.16ms +step:1584/1670 train_time:155486ms step_avg:98.16ms +step:1585/1670 train_time:155584ms step_avg:98.16ms +step:1586/1670 train_time:155682ms step_avg:98.16ms +step:1587/1670 train_time:155779ms step_avg:98.16ms +step:1588/1670 train_time:155876ms step_avg:98.16ms +step:1589/1670 train_time:155974ms step_avg:98.16ms +step:1590/1670 train_time:156071ms step_avg:98.16ms +step:1591/1670 train_time:156170ms step_avg:98.16ms +step:1592/1670 train_time:156268ms step_avg:98.16ms +step:1593/1670 train_time:156366ms step_avg:98.16ms +step:1594/1670 train_time:156463ms step_avg:98.16ms +step:1595/1670 train_time:156561ms step_avg:98.16ms +step:1596/1670 train_time:156658ms step_avg:98.16ms +step:1597/1670 train_time:156756ms step_avg:98.16ms +step:1598/1670 train_time:156854ms step_avg:98.16ms +step:1599/1670 train_time:156951ms step_avg:98.16ms +step:1600/1670 train_time:157050ms step_avg:98.16ms +step:1601/1670 train_time:157149ms step_avg:98.16ms +step:1602/1670 train_time:157246ms step_avg:98.16ms +step:1603/1670 train_time:157344ms step_avg:98.16ms +step:1604/1670 train_time:157441ms step_avg:98.16ms +step:1605/1670 train_time:157539ms step_avg:98.16ms +step:1606/1670 train_time:157637ms step_avg:98.15ms +step:1607/1670 train_time:157733ms step_avg:98.15ms +step:1608/1670 train_time:157831ms step_avg:98.15ms +step:1609/1670 train_time:157929ms step_avg:98.15ms +step:1610/1670 train_time:158028ms step_avg:98.15ms +step:1611/1670 train_time:158126ms step_avg:98.15ms +step:1612/1670 train_time:158224ms step_avg:98.15ms +step:1613/1670 train_time:158322ms step_avg:98.15ms +step:1614/1670 train_time:158420ms step_avg:98.15ms +step:1615/1670 train_time:158518ms step_avg:98.15ms +step:1616/1670 train_time:158616ms step_avg:98.15ms +step:1617/1670 train_time:158713ms step_avg:98.15ms +step:1618/1670 train_time:158812ms step_avg:98.15ms +step:1619/1670 train_time:158910ms step_avg:98.15ms +step:1620/1670 train_time:159008ms step_avg:98.15ms +step:1621/1670 train_time:159106ms step_avg:98.15ms +step:1622/1670 train_time:159204ms step_avg:98.15ms +step:1623/1670 train_time:159302ms step_avg:98.15ms +step:1624/1670 train_time:159399ms step_avg:98.15ms +step:1625/1670 train_time:159497ms step_avg:98.15ms +step:1625/1670 val_loss:3.2853 train_time:159593ms step_avg:98.21ms +step:1626/1670 train_time:159619ms step_avg:98.17ms +step:1627/1670 train_time:159701ms step_avg:98.16ms +step:1628/1670 train_time:159801ms step_avg:98.16ms +step:1629/1670 train_time:159899ms step_avg:98.16ms +step:1630/1670 train_time:159996ms step_avg:98.16ms +step:1631/1670 train_time:160093ms step_avg:98.16ms +step:1632/1670 train_time:160190ms step_avg:98.16ms +step:1633/1670 train_time:160286ms step_avg:98.15ms +step:1634/1670 train_time:160384ms step_avg:98.15ms +step:1635/1670 train_time:160481ms step_avg:98.15ms +step:1636/1670 train_time:160582ms step_avg:98.16ms +step:1637/1670 train_time:160683ms step_avg:98.16ms +step:1638/1670 train_time:160782ms step_avg:98.16ms +step:1639/1670 train_time:160879ms step_avg:98.16ms +step:1640/1670 train_time:160977ms step_avg:98.16ms +step:1641/1670 
train_time:161076ms step_avg:98.16ms +step:1642/1670 train_time:161175ms step_avg:98.16ms +step:1643/1670 train_time:161273ms step_avg:98.16ms +step:1644/1670 train_time:161371ms step_avg:98.16ms +step:1645/1670 train_time:161467ms step_avg:98.16ms +step:1646/1670 train_time:161564ms step_avg:98.16ms +step:1647/1670 train_time:161664ms step_avg:98.16ms +step:1648/1670 train_time:161763ms step_avg:98.16ms +step:1649/1670 train_time:161861ms step_avg:98.16ms +step:1650/1670 train_time:161959ms step_avg:98.16ms +step:1651/1670 train_time:162057ms step_avg:98.16ms +step:1652/1670 train_time:162156ms step_avg:98.16ms +step:1653/1670 train_time:162254ms step_avg:98.16ms +step:1654/1670 train_time:162351ms step_avg:98.16ms +step:1655/1670 train_time:162448ms step_avg:98.16ms +step:1656/1670 train_time:162547ms step_avg:98.16ms +step:1657/1670 train_time:162645ms step_avg:98.16ms +step:1658/1670 train_time:162743ms step_avg:98.16ms +step:1659/1670 train_time:162841ms step_avg:98.16ms +step:1660/1670 train_time:162939ms step_avg:98.16ms +step:1661/1670 train_time:163039ms step_avg:98.16ms +step:1662/1670 train_time:163136ms step_avg:98.16ms +step:1663/1670 train_time:163235ms step_avg:98.16ms +step:1664/1670 train_time:163333ms step_avg:98.16ms +step:1665/1670 train_time:163431ms step_avg:98.16ms +step:1666/1670 train_time:163529ms step_avg:98.16ms +step:1667/1670 train_time:163627ms step_avg:98.16ms +step:1668/1670 train_time:163725ms step_avg:98.16ms +step:1669/1670 train_time:163822ms step_avg:98.16ms +step:1670/1670 train_time:163919ms step_avg:98.15ms +step:1670/1670 val_loss:3.2774 train_time:164016ms step_avg:98.21ms +peak memory allocated: 34000 MiB reserved: 49496 MiB diff --git a/records/090325_FA3/f4f7b0aa-07a1-49a2-903f-97cd5277e73c.txt b/records/090325_FA3/f4f7b0aa-07a1-49a2-903f-97cd5277e73c.txt new file mode 100644 index 000000000..a0c9c3b35 --- /dev/null +++ b/records/090325_FA3/f4f7b0aa-07a1-49a2-903f-97cd5277e73c.txt @@ -0,0 +1,2814 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import numpy as np +import triton +import triton.language as tl +from flash_attn_interface import flash_attn_varlen_func +import torch._dynamo as dynamo +dynamo.config.recompile_limit = 64 + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, 
x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: 
tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, 
GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = 
torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). + """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. 
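+ # Sketch of the data flow below: parameters in each group are walked in slabs of world_size, with rank r owning params[base_i + r]. + # Slab gradients are averaged across ranks via reduce_scatter; each rank then runs the momentum update and Newton-Schulz + # orthogonalization only for its own parameter, and the updated parameters are redistributed to all ranks via all_gather.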
+ rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + 
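# note that Adam state is allocated for p_slice only (rows [rank * rank_size:(rank + 1) * rank_size] of p), + # so each rank stores just 1/world_size of the exp_avg / exp_avg_sq moments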
state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by 
given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + # sparse gated attention to enable context based no-op by @classiclarryd + self.attn_gate_dim = 12 + self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads) + self.attn_gate.weight.detach().zero_() + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "varlen sequences requires B == 1" + assert T % 16 == 0 + + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + + max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size)) + + # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng + y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len, + causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) + y = y.view(B, T, self.num_heads, self.head_dim) + y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 (the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, + seqlens: Tensor, bm_size: int): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: 
int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=2**-9, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 1.0 + self.scalars.lr_mul = 5.0 + + + def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size + bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(bm_sizes) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / 7.5) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +BOS_ID = 50256 + +class BOSFinder: + # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 + self.world_size = world_size + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) + starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + + idx = self.i + for r in range(self.world_size): + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of index {idx}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 + self.i = idx + + return starts, ends + +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each
sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps + + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") + + file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None + pos = 0 # for unaligned case + + while True: + num_tokens_local = num_tokens // world_size + max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400 + + if align_to_bos: + try: + seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len) + start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank]) + except StopIteration: + # This shard is exhausted, load the next one in the next loop iteration. + tokens = _load_data_shard(next(file_iter)) + finder = BOSFinder(tokens, world_size=world_size) + continue + + buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)]) + _inputs = buf[:-1] + _targets = buf[1:] + end_idxs[-1] -= 1 # last document was too long to account for _targets offset + cum_lengths = (end_idxs - start_idxs).cumsum(0) + + else: + if pos + num_tokens + 1 >= len(tokens): # should not occur for val data + tokens, pos = _load_data_shard(next(file_iter)), 0 + + pos_local = pos + rank * num_tokens_local + buf = tokens[pos_local: pos_local + num_tokens_local + 1] + _inputs = buf[:-1].view(num_tokens_local, ) + _targets = buf[1:].view(num_tokens_local, ) + + cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0] + pos += num_tokens + + + _cum_lengths = torch.full((max_num_docs,), num_tokens_local) + _cum_lengths[0] = 0 + _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths + + new_params = yield ( + _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True), + _targets.to(device="cuda", dtype=torch.int64, non_blocking=True), + _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True) + ) + + if new_params is not None: + # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send() + new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params + assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size" + num_tokens = new_num_tokens + max_seq_len = new_max_seq_len + grad_accum_steps = new_grad_accum_steps + + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens: int = 10485760 # how many tokens of validation data? 
it's important to keep this fixed for consistent comparisons + train_batch_size: int = 2048 * 24 * 8 + train_max_seq_len: int = 128 * 16 + val_batch_size: int = 4 * 64 * 1024 * 8 + # optimization + num_iterations: int = 1670 # number of iterations to run + cooldown_frac: float = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + run_id: str = str(uuid.uuid4()) + val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint: bool = False + # attention masking + block_size: int = 128 + ws_schedule: tuple = (3, 7, 11) + +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. + +# begin logging +logfile = None +if master_process: + run_id = args.run_id + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{run_id}.txt" + print(logfile) +def print0(s, console=False): + if master_process: + with open(logfile, "a") as f: + if console: + print(s) + print(s, file=f) + +# begin by printing this file (the Python code) +print0(code) +print0("="*100) +# log information about the hardware/software environment this is running on +print0(f"Running Python {sys.version}") +print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") + +def nvidia_smi(): + import subprocess # avoid top level import + return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout +print0(nvidia_smi()) +print0("="*100) + +model: nn.Module = GPT( + vocab_size=50257, + num_layers=12, + num_heads=6, + model_dim=768, + max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size) +).cuda() +for m in model.modules(): + if isinstance(m, nn.Embedding): + m.bfloat16() +for param in model.parameters(): + dist.broadcast(param.detach(), 0) + +# collect the parameters to optimize +hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n] +embed_params = [p for n, p in model.named_parameters() if "embed" in n] +scalar_params = [p for p in model.parameters() if p.ndim < 2] +head_params = [model.lm_head.weight] + +# init the optimizer(s) +# small adam epsilon by @YouJiacheng.
this is an alternate method of fixing the world_size dependence +# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094 +optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0) +optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0) +optimizers = [optimizer1, optimizer2] +for opt in optimizers: + for group in opt.param_groups: + group["initial_lr"] = group["lr"] + +# learning rate schedule: stable then decay +def get_lr(step: int): + x = step / args.num_iterations + assert 0 <= x < 1 + lr = 1.0 + if x >= 1 - args.cooldown_frac: + w = (1 - x) / args.cooldown_frac + lr = w * 1.0 + (1 - w) * 0.1 + return lr + +def get_ws(step: int): + x = step / (1 + args.num_iterations) + assert 0 <= x < 1 + ws_idx = int(len(args.ws_schedule) * x) + return args.ws_schedule[ws_idx] + +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) + +######################################## +# Warmup kernels # +######################################## + +# Warmup the training kernels, then re-initialize the state so we aren't cheating +warmup_steps = 30 +initial_state = dict(model=copy.deepcopy(model.state_dict()), + optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +for step in range(warmup_steps): + inputs, targets, cum_seqlens = next(train_loader) + ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each + model(inputs, targets, cum_seqlens, ws).backward() + for opt in optimizers: + opt.step() + model.zero_grad(set_to_none=True) +model.load_state_dict(initial_state["model"]) +for opt, opt_state in zip(optimizers, initial_state["optimizers"]): + opt.load_state_dict(opt_state) +del train_loader, initial_state + +######################################## +# Training and validation # +######################################## + +train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps) +training_time_ms = 0 +# start the clock +torch.cuda.synchronize() +t0 = time.perf_counter() +# begin training +train_steps = args.num_iterations +for step in range(train_steps + 1): + last_step = (step == train_steps) + ws = get_ws(step) + + # --------------- VALIDATION SECTION ----------------- + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # stop the clock + torch.cuda.synchronize() + training_time_ms += 1000 * (time.perf_counter() - t0) + model.eval() + assert args.val_tokens % args.val_batch_size == 0 + val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size + val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False) + val_loss = 0 + with torch.no_grad(): + for _ in range(val_steps): + inputs, targets, cum_seqlens = next(val_loader) + val_loss += model(inputs, targets, cum_seqlens, ws) + val_loss /= val_steps + del val_loader + dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) + print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True) + model.train() + # start the clock again + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if 
master_process and args.save_checkpoint: + log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers]) + os.makedirs(f"logs/{run_id}", exist_ok=True) + torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt") + # the last step only has the validation loop, so break to avoid training + break + + # --------------- TRAINING SECTION ----------------- + for _ in range(grad_accum_steps): + inputs, targets, cum_seqlens = next(train_loader) + model(inputs, targets, cum_seqlens, ws).backward() + # set optimization hyperparameters + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["initial_lr"] * get_lr(step) + for group in optimizer2.param_groups: + frac = min(step / 300, 1) # momentum warmup for muon + group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 + # step the optimizers + for opt in optimizers: + opt.step() + # null the gradients + model.zero_grad(set_to_none=True) + # logging + approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0) + print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True) + +print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) +dist.destroy_process_group() +==================================================================================================== +Running Python 3.10.12 (main, May 27 2025, 17:12:29) [GCC 11.4.0] +Running PyTorch 2.9.0.dev20250718+cu126 compiled for CUDA 12.6 +Running Triton version 3.4.0 +Wed Sep 3 20:04:19 2025 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
+|=========================================+========================+======================|
+|   0  NVIDIA H100 80GB HBM3          On  |   00000000:19:00.0 Off |                    0 |
+| N/A   32C    P0            117W /  700W |    5858MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   1  NVIDIA H100 80GB HBM3          On  |   00000000:3B:00.0 Off |                    0 |
+| N/A   30C    P0            118W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   2  NVIDIA H100 80GB HBM3          On  |   00000000:4C:00.0 Off |                    0 |
+| N/A   29C    P0            119W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   3  NVIDIA H100 80GB HBM3          On  |   00000000:5D:00.0 Off |                    0 |
+| N/A   33C    P0            123W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   4  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
+| N/A   32C    P0            117W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   5  NVIDIA H100 80GB HBM3          On  |   00000000:BB:00.0 Off |                    0 |
+| N/A   29C    P0            116W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   6  NVIDIA H100 80GB HBM3          On  |   00000000:CB:00.0 Off |                    0 |
+| N/A   32C    P0            119W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   7  NVIDIA H100 80GB HBM3          On  |   00000000:DB:00.0 Off |                    0 |
+| N/A   28C    P0            115W /  700W |    1520MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes:                                                                              |
+|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
+|        ID   ID                                                               Usage      |
+|=========================================================================================|
+|    0   N/A  N/A           53094      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A           53095      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A           53096      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A           53097      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A           53098      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A           53099      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A           53100      C   /usr/bin/python                           0MiB |
+|    0   N/A  N/A           53101      C   /usr/bin/python                           0MiB |
+|    1   N/A  N/A           53095      C   /usr/bin/python                           0MiB |
+|    2   N/A  N/A           53096      C   /usr/bin/python                           0MiB |
+|    3   N/A  N/A           53097      C   /usr/bin/python                           0MiB |
+|    4   N/A  N/A           53098      C   /usr/bin/python                           0MiB |
+|    5   N/A  N/A           53099      C   /usr/bin/python                           0MiB |
+|    6   N/A  N/A           53100      C   /usr/bin/python                           0MiB |
+|    7   N/A  N/A           53101      C   /usr/bin/python                           0MiB |
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+step:0/1670 val_loss:10.8258 train_time:0ms step_avg:0.02ms
+step:1/1670 train_time:375ms step_avg:374.99ms
+step:2/1670 train_time:398ms step_avg:199.22ms
+step:3/1670 train_time:469ms step_avg:156.43ms
+step:4/1670 train_time:563ms step_avg:140.76ms
+step:5/1670 train_time:658ms step_avg:131.56ms
+step:6/1670 train_time:752ms step_avg:125.32ms
+step:7/1670 train_time:847ms step_avg:120.96ms
+step:8/1670 train_time:942ms step_avg:117.75ms
+step:9/1670 train_time:1037ms
step_avg:115.24ms +step:10/1670 train_time:1133ms step_avg:113.32ms +step:11/1670 train_time:1228ms step_avg:111.62ms +step:12/1670 train_time:1326ms step_avg:110.52ms +step:13/1670 train_time:1424ms step_avg:109.56ms +step:14/1670 train_time:1520ms step_avg:108.58ms +step:15/1670 train_time:1616ms step_avg:107.72ms +step:16/1670 train_time:1711ms step_avg:106.91ms +step:17/1670 train_time:1806ms step_avg:106.22ms +step:18/1670 train_time:1901ms step_avg:105.63ms +step:19/1670 train_time:1996ms step_avg:105.06ms +step:20/1670 train_time:2091ms step_avg:104.57ms +step:21/1670 train_time:2187ms step_avg:104.15ms +step:22/1670 train_time:2283ms step_avg:103.78ms +step:23/1670 train_time:2380ms step_avg:103.49ms +step:24/1670 train_time:2477ms step_avg:103.21ms +step:25/1670 train_time:2572ms step_avg:102.89ms +step:26/1670 train_time:2669ms step_avg:102.66ms +step:27/1670 train_time:2764ms step_avg:102.39ms +step:28/1670 train_time:2860ms step_avg:102.14ms +step:29/1670 train_time:2955ms step_avg:101.91ms +step:30/1670 train_time:3051ms step_avg:101.70ms +step:31/1670 train_time:3147ms step_avg:101.51ms +step:32/1670 train_time:3243ms step_avg:101.34ms +step:33/1670 train_time:3338ms step_avg:101.16ms +step:34/1670 train_time:3434ms step_avg:101.00ms +step:35/1670 train_time:3530ms step_avg:100.85ms +step:36/1670 train_time:3627ms step_avg:100.75ms +step:37/1670 train_time:3723ms step_avg:100.61ms +step:38/1670 train_time:3818ms step_avg:100.46ms +step:39/1670 train_time:3913ms step_avg:100.33ms +step:40/1670 train_time:4008ms step_avg:100.19ms +step:41/1670 train_time:4104ms step_avg:100.10ms +step:42/1670 train_time:4199ms step_avg:99.98ms +step:43/1670 train_time:4295ms step_avg:99.88ms +step:44/1670 train_time:4391ms step_avg:99.79ms +step:45/1670 train_time:4487ms step_avg:99.72ms +step:46/1670 train_time:4583ms step_avg:99.63ms +step:47/1670 train_time:4679ms step_avg:99.56ms +step:48/1670 train_time:4775ms step_avg:99.48ms +step:49/1670 train_time:4871ms step_avg:99.40ms +step:50/1670 train_time:4966ms step_avg:99.32ms +step:51/1670 train_time:5061ms step_avg:99.24ms +step:52/1670 train_time:5156ms step_avg:99.16ms +step:53/1670 train_time:5252ms step_avg:99.10ms +step:54/1670 train_time:5348ms step_avg:99.04ms +step:55/1670 train_time:5444ms step_avg:98.97ms +step:56/1670 train_time:5540ms step_avg:98.92ms +step:57/1670 train_time:5635ms step_avg:98.86ms +step:58/1670 train_time:5731ms step_avg:98.80ms +step:59/1670 train_time:5826ms step_avg:98.75ms +step:60/1670 train_time:5923ms step_avg:98.71ms +step:61/1670 train_time:6018ms step_avg:98.66ms +step:62/1670 train_time:6113ms step_avg:98.59ms +step:63/1670 train_time:6208ms step_avg:98.55ms +step:64/1670 train_time:6305ms step_avg:98.51ms +step:65/1670 train_time:6401ms step_avg:98.48ms +step:66/1670 train_time:6497ms step_avg:98.43ms +step:67/1670 train_time:6593ms step_avg:98.40ms +step:68/1670 train_time:6688ms step_avg:98.35ms +step:69/1670 train_time:6783ms step_avg:98.31ms +step:70/1670 train_time:6879ms step_avg:98.28ms +step:71/1670 train_time:6975ms step_avg:98.24ms +step:72/1670 train_time:7070ms step_avg:98.20ms +step:73/1670 train_time:7166ms step_avg:98.17ms +step:74/1670 train_time:7262ms step_avg:98.14ms +step:75/1670 train_time:7358ms step_avg:98.11ms +step:76/1670 train_time:7454ms step_avg:98.08ms +step:77/1670 train_time:7551ms step_avg:98.06ms +step:78/1670 train_time:7647ms step_avg:98.04ms +step:79/1670 train_time:7743ms step_avg:98.02ms +step:80/1670 train_time:7839ms step_avg:97.98ms +step:81/1670 
train_time:7934ms step_avg:97.95ms +step:82/1670 train_time:8030ms step_avg:97.92ms +step:83/1670 train_time:8126ms step_avg:97.90ms +step:84/1670 train_time:8222ms step_avg:97.88ms +step:85/1670 train_time:8317ms step_avg:97.85ms +step:86/1670 train_time:8412ms step_avg:97.82ms +step:87/1670 train_time:8508ms step_avg:97.79ms +step:88/1670 train_time:8604ms step_avg:97.77ms +step:89/1670 train_time:8699ms step_avg:97.75ms +step:90/1670 train_time:8796ms step_avg:97.73ms +step:91/1670 train_time:8891ms step_avg:97.70ms +step:92/1670 train_time:8986ms step_avg:97.67ms +step:93/1670 train_time:9081ms step_avg:97.65ms +step:94/1670 train_time:9177ms step_avg:97.62ms +step:95/1670 train_time:9272ms step_avg:97.60ms +step:96/1670 train_time:9368ms step_avg:97.58ms +step:97/1670 train_time:9464ms step_avg:97.56ms +step:98/1670 train_time:9559ms step_avg:97.54ms +step:99/1670 train_time:9655ms step_avg:97.52ms +step:100/1670 train_time:9750ms step_avg:97.50ms +step:101/1670 train_time:9847ms step_avg:97.49ms +step:102/1670 train_time:9943ms step_avg:97.48ms +step:103/1670 train_time:10039ms step_avg:97.46ms +step:104/1670 train_time:10134ms step_avg:97.44ms +step:105/1670 train_time:10229ms step_avg:97.42ms +step:106/1670 train_time:10325ms step_avg:97.40ms +step:107/1670 train_time:10420ms step_avg:97.38ms +step:108/1670 train_time:10515ms step_avg:97.36ms +step:109/1670 train_time:10611ms step_avg:97.35ms +step:110/1670 train_time:10707ms step_avg:97.33ms +step:111/1670 train_time:10803ms step_avg:97.32ms +step:112/1670 train_time:10898ms step_avg:97.30ms +step:113/1670 train_time:10994ms step_avg:97.29ms +step:114/1670 train_time:11090ms step_avg:97.28ms +step:115/1670 train_time:11186ms step_avg:97.27ms +step:116/1670 train_time:11282ms step_avg:97.26ms +step:117/1670 train_time:11378ms step_avg:97.25ms +step:118/1670 train_time:11474ms step_avg:97.23ms +step:119/1670 train_time:11569ms step_avg:97.22ms +step:120/1670 train_time:11666ms step_avg:97.22ms +step:121/1670 train_time:11762ms step_avg:97.20ms +step:122/1670 train_time:11857ms step_avg:97.19ms +step:123/1670 train_time:11952ms step_avg:97.17ms +step:124/1670 train_time:12047ms step_avg:97.16ms +step:125/1670 train_time:12143ms step_avg:97.15ms +step:125/1670 val_loss:4.2921 train_time:12238ms step_avg:97.91ms +step:126/1670 train_time:12265ms step_avg:97.34ms +step:127/1670 train_time:12345ms step_avg:97.20ms +step:128/1670 train_time:12448ms step_avg:97.25ms +step:129/1670 train_time:12546ms step_avg:97.26ms +step:130/1670 train_time:12641ms step_avg:97.24ms +step:131/1670 train_time:12737ms step_avg:97.23ms +step:132/1670 train_time:12831ms step_avg:97.20ms +step:133/1670 train_time:12925ms step_avg:97.18ms +step:134/1670 train_time:13020ms step_avg:97.16ms +step:135/1670 train_time:13115ms step_avg:97.15ms +step:136/1670 train_time:13209ms step_avg:97.13ms +step:137/1670 train_time:13305ms step_avg:97.12ms +step:138/1670 train_time:13405ms step_avg:97.14ms +step:139/1670 train_time:13503ms step_avg:97.15ms +step:140/1670 train_time:13600ms step_avg:97.14ms +step:141/1670 train_time:13695ms step_avg:97.13ms +step:142/1670 train_time:13790ms step_avg:97.11ms +step:143/1670 train_time:13885ms step_avg:97.10ms +step:144/1670 train_time:13980ms step_avg:97.08ms +step:145/1670 train_time:14074ms step_avg:97.06ms +step:146/1670 train_time:14170ms step_avg:97.05ms +step:147/1670 train_time:14265ms step_avg:97.04ms +step:148/1670 train_time:14362ms step_avg:97.04ms +step:149/1670 train_time:14460ms step_avg:97.05ms +step:150/1670 
train_time:14557ms step_avg:97.05ms +step:151/1670 train_time:14654ms step_avg:97.04ms +step:152/1670 train_time:14749ms step_avg:97.03ms +step:153/1670 train_time:14844ms step_avg:97.02ms +step:154/1670 train_time:14939ms step_avg:97.01ms +step:155/1670 train_time:15034ms step_avg:96.99ms +step:156/1670 train_time:15129ms step_avg:96.98ms +step:157/1670 train_time:15224ms step_avg:96.97ms +step:158/1670 train_time:15320ms step_avg:96.96ms +step:159/1670 train_time:15417ms step_avg:96.96ms +step:160/1670 train_time:15514ms step_avg:96.96ms +step:161/1670 train_time:15610ms step_avg:96.96ms +step:162/1670 train_time:15705ms step_avg:96.95ms +step:163/1670 train_time:15801ms step_avg:96.94ms +step:164/1670 train_time:15896ms step_avg:96.93ms +step:165/1670 train_time:15991ms step_avg:96.91ms +step:166/1670 train_time:16085ms step_avg:96.90ms +step:167/1670 train_time:16181ms step_avg:96.89ms +step:168/1670 train_time:16277ms step_avg:96.89ms +step:169/1670 train_time:16373ms step_avg:96.88ms +step:170/1670 train_time:16469ms step_avg:96.88ms +step:171/1670 train_time:16565ms step_avg:96.87ms +step:172/1670 train_time:16660ms step_avg:96.86ms +step:173/1670 train_time:16758ms step_avg:96.87ms +step:174/1670 train_time:16853ms step_avg:96.85ms +step:175/1670 train_time:16948ms step_avg:96.85ms +step:176/1670 train_time:17043ms step_avg:96.83ms +step:177/1670 train_time:17139ms step_avg:96.83ms +step:178/1670 train_time:17234ms step_avg:96.82ms +step:179/1670 train_time:17330ms step_avg:96.82ms +step:180/1670 train_time:17426ms step_avg:96.81ms +step:181/1670 train_time:17522ms step_avg:96.81ms +step:182/1670 train_time:17617ms step_avg:96.80ms +step:183/1670 train_time:17713ms step_avg:96.79ms +step:184/1670 train_time:17809ms step_avg:96.79ms +step:185/1670 train_time:17904ms step_avg:96.78ms +step:186/1670 train_time:17999ms step_avg:96.77ms +step:187/1670 train_time:18094ms step_avg:96.76ms +step:188/1670 train_time:18189ms step_avg:96.75ms +step:189/1670 train_time:18285ms step_avg:96.74ms +step:190/1670 train_time:18381ms step_avg:96.74ms +step:191/1670 train_time:18477ms step_avg:96.74ms +step:192/1670 train_time:18574ms step_avg:96.74ms +step:193/1670 train_time:18669ms step_avg:96.73ms +step:194/1670 train_time:18766ms step_avg:96.73ms +step:195/1670 train_time:18861ms step_avg:96.72ms +step:196/1670 train_time:18957ms step_avg:96.72ms +step:197/1670 train_time:19052ms step_avg:96.71ms +step:198/1670 train_time:19146ms step_avg:96.70ms +step:199/1670 train_time:19242ms step_avg:96.69ms +step:200/1670 train_time:19337ms step_avg:96.69ms +step:201/1670 train_time:19432ms step_avg:96.68ms +step:202/1670 train_time:19528ms step_avg:96.67ms +step:203/1670 train_time:19624ms step_avg:96.67ms +step:204/1670 train_time:19719ms step_avg:96.66ms +step:205/1670 train_time:19816ms step_avg:96.66ms +step:206/1670 train_time:19911ms step_avg:96.65ms +step:207/1670 train_time:20006ms step_avg:96.64ms +step:208/1670 train_time:20101ms step_avg:96.64ms +step:209/1670 train_time:20196ms step_avg:96.63ms +step:210/1670 train_time:20292ms step_avg:96.63ms +step:211/1670 train_time:20387ms step_avg:96.62ms +step:212/1670 train_time:20483ms step_avg:96.62ms +step:213/1670 train_time:20759ms step_avg:97.46ms +step:214/1670 train_time:20853ms step_avg:97.44ms +step:215/1670 train_time:20947ms step_avg:97.43ms +step:216/1670 train_time:21041ms step_avg:97.41ms +step:217/1670 train_time:21136ms step_avg:97.40ms +step:218/1670 train_time:21231ms step_avg:97.39ms +step:219/1670 train_time:21325ms step_avg:97.37ms 
+step:220/1670 train_time:21420ms step_avg:97.36ms +step:221/1670 train_time:21514ms step_avg:97.35ms +step:222/1670 train_time:21610ms step_avg:97.34ms +step:223/1670 train_time:21708ms step_avg:97.35ms +step:224/1670 train_time:21805ms step_avg:97.34ms +step:225/1670 train_time:21901ms step_avg:97.34ms +step:226/1670 train_time:21996ms step_avg:97.33ms +step:227/1670 train_time:22091ms step_avg:97.32ms +step:228/1670 train_time:22186ms step_avg:97.31ms +step:229/1670 train_time:22281ms step_avg:97.30ms +step:230/1670 train_time:22376ms step_avg:97.29ms +step:231/1670 train_time:22471ms step_avg:97.28ms +step:232/1670 train_time:22565ms step_avg:97.27ms +step:233/1670 train_time:22663ms step_avg:97.27ms +step:234/1670 train_time:22760ms step_avg:97.26ms +step:235/1670 train_time:22857ms step_avg:97.26ms +step:236/1670 train_time:22954ms step_avg:97.26ms +step:237/1670 train_time:23048ms step_avg:97.25ms +step:238/1670 train_time:23143ms step_avg:97.24ms +step:239/1670 train_time:23238ms step_avg:97.23ms +step:240/1670 train_time:23333ms step_avg:97.22ms +step:241/1670 train_time:23428ms step_avg:97.21ms +step:242/1670 train_time:23523ms step_avg:97.20ms +step:243/1670 train_time:23619ms step_avg:97.20ms +step:244/1670 train_time:23714ms step_avg:97.19ms +step:245/1670 train_time:23810ms step_avg:97.18ms +step:246/1670 train_time:23906ms step_avg:97.18ms +step:247/1670 train_time:24002ms step_avg:97.17ms +step:248/1670 train_time:24098ms step_avg:97.17ms +step:249/1670 train_time:24193ms step_avg:97.16ms +step:250/1670 train_time:24288ms step_avg:97.15ms +step:250/1670 val_loss:3.9671 train_time:24382ms step_avg:97.53ms +step:251/1670 train_time:24406ms step_avg:97.24ms +step:252/1670 train_time:24482ms step_avg:97.15ms +step:253/1670 train_time:24581ms step_avg:97.16ms +step:254/1670 train_time:24677ms step_avg:97.15ms +step:255/1670 train_time:24772ms step_avg:97.14ms +step:256/1670 train_time:24866ms step_avg:97.13ms +step:257/1670 train_time:24960ms step_avg:97.12ms +step:258/1670 train_time:25055ms step_avg:97.11ms +step:259/1670 train_time:25150ms step_avg:97.10ms +step:260/1670 train_time:25245ms step_avg:97.10ms +step:261/1670 train_time:25341ms step_avg:97.09ms +step:262/1670 train_time:25439ms step_avg:97.10ms +step:263/1670 train_time:25536ms step_avg:97.10ms +step:264/1670 train_time:25633ms step_avg:97.09ms +step:265/1670 train_time:25729ms step_avg:97.09ms +step:266/1670 train_time:25825ms step_avg:97.09ms +step:267/1670 train_time:25919ms step_avg:97.08ms +step:268/1670 train_time:26014ms step_avg:97.07ms +step:269/1670 train_time:26110ms step_avg:97.06ms +step:270/1670 train_time:26203ms step_avg:97.05ms +step:271/1670 train_time:26299ms step_avg:97.04ms +step:272/1670 train_time:26395ms step_avg:97.04ms +step:273/1670 train_time:26492ms step_avg:97.04ms +step:274/1670 train_time:26588ms step_avg:97.04ms +step:275/1670 train_time:26683ms step_avg:97.03ms +step:276/1670 train_time:26779ms step_avg:97.02ms +step:277/1670 train_time:26875ms step_avg:97.02ms +step:278/1670 train_time:26971ms step_avg:97.02ms +step:279/1670 train_time:27067ms step_avg:97.01ms +step:280/1670 train_time:27161ms step_avg:97.00ms +step:281/1670 train_time:27256ms step_avg:97.00ms +step:282/1670 train_time:27351ms step_avg:96.99ms +step:283/1670 train_time:27447ms step_avg:96.99ms +step:284/1670 train_time:27544ms step_avg:96.98ms +step:285/1670 train_time:27638ms step_avg:96.98ms +step:286/1670 train_time:27735ms step_avg:96.97ms +step:287/1670 train_time:27832ms step_avg:96.97ms +step:288/1670 
train_time:27928ms step_avg:96.97ms +step:289/1670 train_time:28024ms step_avg:96.97ms +step:290/1670 train_time:28118ms step_avg:96.96ms +step:291/1670 train_time:28214ms step_avg:96.95ms +step:292/1670 train_time:28310ms step_avg:96.95ms +step:293/1670 train_time:28405ms step_avg:96.95ms +step:294/1670 train_time:28500ms step_avg:96.94ms +step:295/1670 train_time:28595ms step_avg:96.93ms +step:296/1670 train_time:28691ms step_avg:96.93ms +step:297/1670 train_time:28786ms step_avg:96.92ms +step:298/1670 train_time:28882ms step_avg:96.92ms +step:299/1670 train_time:28977ms step_avg:96.91ms +step:300/1670 train_time:29074ms step_avg:96.91ms +step:301/1670 train_time:29170ms step_avg:96.91ms +step:302/1670 train_time:29266ms step_avg:96.91ms +step:303/1670 train_time:29362ms step_avg:96.90ms +step:304/1670 train_time:29457ms step_avg:96.90ms +step:305/1670 train_time:29553ms step_avg:96.89ms +step:306/1670 train_time:29649ms step_avg:96.89ms +step:307/1670 train_time:29744ms step_avg:96.89ms +step:308/1670 train_time:29839ms step_avg:96.88ms +step:309/1670 train_time:29935ms step_avg:96.88ms +step:310/1670 train_time:30031ms step_avg:96.87ms +step:311/1670 train_time:30126ms step_avg:96.87ms +step:312/1670 train_time:30222ms step_avg:96.86ms +step:313/1670 train_time:30318ms step_avg:96.86ms +step:314/1670 train_time:30413ms step_avg:96.86ms +step:315/1670 train_time:30509ms step_avg:96.85ms +step:316/1670 train_time:30605ms step_avg:96.85ms +step:317/1670 train_time:30700ms step_avg:96.85ms +step:318/1670 train_time:30795ms step_avg:96.84ms +step:319/1670 train_time:30891ms step_avg:96.84ms +step:320/1670 train_time:30986ms step_avg:96.83ms +step:321/1670 train_time:31082ms step_avg:96.83ms +step:322/1670 train_time:31177ms step_avg:96.82ms +step:323/1670 train_time:31273ms step_avg:96.82ms +step:324/1670 train_time:31370ms step_avg:96.82ms +step:325/1670 train_time:31465ms step_avg:96.82ms +step:326/1670 train_time:31560ms step_avg:96.81ms +step:327/1670 train_time:31656ms step_avg:96.81ms +step:328/1670 train_time:31752ms step_avg:96.81ms +step:329/1670 train_time:31849ms step_avg:96.80ms +step:330/1670 train_time:31944ms step_avg:96.80ms +step:331/1670 train_time:32039ms step_avg:96.80ms +step:332/1670 train_time:32136ms step_avg:96.79ms +step:333/1670 train_time:32232ms step_avg:96.79ms +step:334/1670 train_time:32328ms step_avg:96.79ms +step:335/1670 train_time:32424ms step_avg:96.79ms +step:336/1670 train_time:32520ms step_avg:96.78ms +step:337/1670 train_time:32614ms step_avg:96.78ms +step:338/1670 train_time:32709ms step_avg:96.77ms +step:339/1670 train_time:32805ms step_avg:96.77ms +step:340/1670 train_time:32900ms step_avg:96.76ms +step:341/1670 train_time:32996ms step_avg:96.76ms +step:342/1670 train_time:33092ms step_avg:96.76ms +step:343/1670 train_time:33188ms step_avg:96.76ms +step:344/1670 train_time:33284ms step_avg:96.75ms +step:345/1670 train_time:33379ms step_avg:96.75ms +step:346/1670 train_time:33475ms step_avg:96.75ms +step:347/1670 train_time:33571ms step_avg:96.75ms +step:348/1670 train_time:33666ms step_avg:96.74ms +step:349/1670 train_time:33762ms step_avg:96.74ms +step:350/1670 train_time:33857ms step_avg:96.73ms +step:351/1670 train_time:33952ms step_avg:96.73ms +step:352/1670 train_time:34048ms step_avg:96.73ms +step:353/1670 train_time:34144ms step_avg:96.73ms +step:354/1670 train_time:34239ms step_avg:96.72ms +step:355/1670 train_time:34335ms step_avg:96.72ms +step:356/1670 train_time:34431ms step_avg:96.72ms +step:357/1670 train_time:34527ms step_avg:96.72ms 
+step:358/1670 train_time:34622ms step_avg:96.71ms +step:359/1670 train_time:34717ms step_avg:96.70ms +step:360/1670 train_time:34813ms step_avg:96.70ms +step:361/1670 train_time:34909ms step_avg:96.70ms +step:362/1670 train_time:35004ms step_avg:96.70ms +step:363/1670 train_time:35099ms step_avg:96.69ms +step:364/1670 train_time:35195ms step_avg:96.69ms +step:365/1670 train_time:35291ms step_avg:96.69ms +step:366/1670 train_time:35387ms step_avg:96.69ms +step:367/1670 train_time:35482ms step_avg:96.68ms +step:368/1670 train_time:35578ms step_avg:96.68ms +step:369/1670 train_time:35673ms step_avg:96.68ms +step:370/1670 train_time:35769ms step_avg:96.67ms +step:371/1670 train_time:35866ms step_avg:96.67ms +step:372/1670 train_time:35960ms step_avg:96.67ms +step:373/1670 train_time:36056ms step_avg:96.66ms +step:374/1670 train_time:36152ms step_avg:96.66ms +step:375/1670 train_time:36248ms step_avg:96.66ms +step:375/1670 val_loss:3.8203 train_time:36343ms step_avg:96.91ms +step:376/1670 train_time:36367ms step_avg:96.72ms +step:377/1670 train_time:36445ms step_avg:96.67ms +step:378/1670 train_time:36543ms step_avg:96.67ms +step:379/1670 train_time:36638ms step_avg:96.67ms +step:380/1670 train_time:36733ms step_avg:96.66ms +step:381/1670 train_time:36827ms step_avg:96.66ms +step:382/1670 train_time:36922ms step_avg:96.65ms +step:383/1670 train_time:37017ms step_avg:96.65ms +step:384/1670 train_time:37112ms step_avg:96.65ms +step:385/1670 train_time:37206ms step_avg:96.64ms +step:386/1670 train_time:37303ms step_avg:96.64ms +step:387/1670 train_time:37401ms step_avg:96.64ms +step:388/1670 train_time:37498ms step_avg:96.64ms +step:389/1670 train_time:37594ms step_avg:96.64ms +step:390/1670 train_time:37689ms step_avg:96.64ms +step:391/1670 train_time:37785ms step_avg:96.64ms +step:392/1670 train_time:37880ms step_avg:96.63ms +step:393/1670 train_time:37975ms step_avg:96.63ms +step:394/1670 train_time:38070ms step_avg:96.62ms +step:395/1670 train_time:38166ms step_avg:96.62ms +step:396/1670 train_time:38261ms step_avg:96.62ms +step:397/1670 train_time:38357ms step_avg:96.62ms +step:398/1670 train_time:38452ms step_avg:96.61ms +step:399/1670 train_time:38549ms step_avg:96.61ms +step:400/1670 train_time:38645ms step_avg:96.61ms +step:401/1670 train_time:38741ms step_avg:96.61ms +step:402/1670 train_time:38836ms step_avg:96.61ms +step:403/1670 train_time:38931ms step_avg:96.60ms +step:404/1670 train_time:39026ms step_avg:96.60ms +step:405/1670 train_time:39121ms step_avg:96.59ms +step:406/1670 train_time:39216ms step_avg:96.59ms +step:407/1670 train_time:39311ms step_avg:96.59ms +step:408/1670 train_time:39407ms step_avg:96.58ms +step:409/1670 train_time:39504ms step_avg:96.59ms +step:410/1670 train_time:39600ms step_avg:96.59ms +step:411/1670 train_time:39696ms step_avg:96.58ms +step:412/1670 train_time:39792ms step_avg:96.58ms +step:413/1670 train_time:39887ms step_avg:96.58ms +step:414/1670 train_time:39982ms step_avg:96.58ms +step:415/1670 train_time:40079ms step_avg:96.57ms +step:416/1670 train_time:40173ms step_avg:96.57ms +step:417/1670 train_time:40269ms step_avg:96.57ms +step:418/1670 train_time:40364ms step_avg:96.56ms +step:419/1670 train_time:40460ms step_avg:96.56ms +step:420/1670 train_time:40556ms step_avg:96.56ms +step:421/1670 train_time:40651ms step_avg:96.56ms +step:422/1670 train_time:40747ms step_avg:96.56ms +step:423/1670 train_time:40843ms step_avg:96.56ms +step:424/1670 train_time:40939ms step_avg:96.55ms +step:425/1670 train_time:41231ms step_avg:97.01ms +step:426/1670 
train_time:41321ms step_avg:97.00ms +step:427/1670 train_time:41414ms step_avg:96.99ms +step:428/1670 train_time:41508ms step_avg:96.98ms +step:429/1670 train_time:41604ms step_avg:96.98ms +step:430/1670 train_time:41698ms step_avg:96.97ms +step:431/1670 train_time:41793ms step_avg:96.97ms +step:432/1670 train_time:41887ms step_avg:96.96ms +step:433/1670 train_time:41981ms step_avg:96.95ms +step:434/1670 train_time:42076ms step_avg:96.95ms +step:435/1670 train_time:42174ms step_avg:96.95ms +step:436/1670 train_time:42271ms step_avg:96.95ms +step:437/1670 train_time:42369ms step_avg:96.95ms +step:438/1670 train_time:42464ms step_avg:96.95ms +step:439/1670 train_time:42560ms step_avg:96.95ms +step:440/1670 train_time:42655ms step_avg:96.94ms +step:441/1670 train_time:42749ms step_avg:96.94ms +step:442/1670 train_time:42844ms step_avg:96.93ms +step:443/1670 train_time:42939ms step_avg:96.93ms +step:444/1670 train_time:43034ms step_avg:96.92ms +step:445/1670 train_time:43130ms step_avg:96.92ms +step:446/1670 train_time:43227ms step_avg:96.92ms +step:447/1670 train_time:43324ms step_avg:96.92ms +step:448/1670 train_time:43421ms step_avg:96.92ms +step:449/1670 train_time:43516ms step_avg:96.92ms +step:450/1670 train_time:43611ms step_avg:96.91ms +step:451/1670 train_time:43706ms step_avg:96.91ms +step:452/1670 train_time:43801ms step_avg:96.90ms +step:453/1670 train_time:43895ms step_avg:96.90ms +step:454/1670 train_time:43990ms step_avg:96.89ms +step:455/1670 train_time:44086ms step_avg:96.89ms +step:456/1670 train_time:44182ms step_avg:96.89ms +step:457/1670 train_time:44280ms step_avg:96.89ms +step:458/1670 train_time:44377ms step_avg:96.89ms +step:459/1670 train_time:44472ms step_avg:96.89ms +step:460/1670 train_time:44568ms step_avg:96.89ms +step:461/1670 train_time:44663ms step_avg:96.88ms +step:462/1670 train_time:44758ms step_avg:96.88ms +step:463/1670 train_time:44853ms step_avg:96.88ms +step:464/1670 train_time:44948ms step_avg:96.87ms +step:465/1670 train_time:45043ms step_avg:96.87ms +step:466/1670 train_time:45140ms step_avg:96.87ms +step:467/1670 train_time:45235ms step_avg:96.86ms +step:468/1670 train_time:45331ms step_avg:96.86ms +step:469/1670 train_time:45428ms step_avg:96.86ms +step:470/1670 train_time:45523ms step_avg:96.86ms +step:471/1670 train_time:45620ms step_avg:96.86ms +step:472/1670 train_time:45714ms step_avg:96.85ms +step:473/1670 train_time:45809ms step_avg:96.85ms +step:474/1670 train_time:45905ms step_avg:96.85ms +step:475/1670 train_time:46001ms step_avg:96.84ms +step:476/1670 train_time:46095ms step_avg:96.84ms +step:477/1670 train_time:46191ms step_avg:96.84ms +step:478/1670 train_time:46287ms step_avg:96.84ms +step:479/1670 train_time:46384ms step_avg:96.84ms +step:480/1670 train_time:46480ms step_avg:96.83ms +step:481/1670 train_time:46575ms step_avg:96.83ms +step:482/1670 train_time:46671ms step_avg:96.83ms +step:483/1670 train_time:46766ms step_avg:96.82ms +step:484/1670 train_time:46860ms step_avg:96.82ms +step:485/1670 train_time:46955ms step_avg:96.82ms +step:486/1670 train_time:47050ms step_avg:96.81ms +step:487/1670 train_time:47146ms step_avg:96.81ms +step:488/1670 train_time:47242ms step_avg:96.81ms +step:489/1670 train_time:47339ms step_avg:96.81ms +step:490/1670 train_time:47435ms step_avg:96.81ms +step:491/1670 train_time:47530ms step_avg:96.80ms +step:492/1670 train_time:47626ms step_avg:96.80ms +step:493/1670 train_time:47722ms step_avg:96.80ms +step:494/1670 train_time:47817ms step_avg:96.80ms +step:495/1670 train_time:47913ms step_avg:96.79ms 
+step:496/1670 train_time:48007ms step_avg:96.79ms +step:497/1670 train_time:48103ms step_avg:96.79ms +step:498/1670 train_time:48199ms step_avg:96.79ms +step:499/1670 train_time:48295ms step_avg:96.78ms +step:500/1670 train_time:48391ms step_avg:96.78ms +step:500/1670 val_loss:3.7189 train_time:48485ms step_avg:96.97ms +step:501/1670 train_time:48511ms step_avg:96.83ms +step:502/1670 train_time:48587ms step_avg:96.79ms +step:503/1670 train_time:48684ms step_avg:96.79ms +step:504/1670 train_time:48780ms step_avg:96.79ms +step:505/1670 train_time:48875ms step_avg:96.78ms +step:506/1670 train_time:48970ms step_avg:96.78ms +step:507/1670 train_time:49065ms step_avg:96.77ms +step:508/1670 train_time:49159ms step_avg:96.77ms +step:509/1670 train_time:49253ms step_avg:96.77ms +step:510/1670 train_time:49348ms step_avg:96.76ms +step:511/1670 train_time:49443ms step_avg:96.76ms +step:512/1670 train_time:49540ms step_avg:96.76ms +step:513/1670 train_time:49638ms step_avg:96.76ms +step:514/1670 train_time:49736ms step_avg:96.76ms +step:515/1670 train_time:49832ms step_avg:96.76ms +step:516/1670 train_time:49927ms step_avg:96.76ms +step:517/1670 train_time:50022ms step_avg:96.75ms +step:518/1670 train_time:50117ms step_avg:96.75ms +step:519/1670 train_time:50212ms step_avg:96.75ms +step:520/1670 train_time:50307ms step_avg:96.74ms +step:521/1670 train_time:50402ms step_avg:96.74ms +step:522/1670 train_time:50499ms step_avg:96.74ms +step:523/1670 train_time:50595ms step_avg:96.74ms +step:524/1670 train_time:50693ms step_avg:96.74ms +step:525/1670 train_time:50790ms step_avg:96.74ms +step:526/1670 train_time:50886ms step_avg:96.74ms +step:527/1670 train_time:50980ms step_avg:96.74ms +step:528/1670 train_time:51075ms step_avg:96.73ms +step:529/1670 train_time:51170ms step_avg:96.73ms +step:530/1670 train_time:51265ms step_avg:96.73ms +step:531/1670 train_time:51360ms step_avg:96.72ms +step:532/1670 train_time:51456ms step_avg:96.72ms +step:533/1670 train_time:51552ms step_avg:96.72ms +step:534/1670 train_time:51648ms step_avg:96.72ms +step:535/1670 train_time:51745ms step_avg:96.72ms +step:536/1670 train_time:51840ms step_avg:96.72ms +step:537/1670 train_time:51936ms step_avg:96.71ms +step:538/1670 train_time:52031ms step_avg:96.71ms +step:539/1670 train_time:52126ms step_avg:96.71ms +step:540/1670 train_time:52221ms step_avg:96.71ms +step:541/1670 train_time:52316ms step_avg:96.70ms +step:542/1670 train_time:52413ms step_avg:96.70ms +step:543/1670 train_time:52509ms step_avg:96.70ms +step:544/1670 train_time:52605ms step_avg:96.70ms +step:545/1670 train_time:52702ms step_avg:96.70ms +step:546/1670 train_time:52797ms step_avg:96.70ms +step:547/1670 train_time:52893ms step_avg:96.70ms +step:548/1670 train_time:52989ms step_avg:96.70ms +step:549/1670 train_time:53084ms step_avg:96.69ms +step:550/1670 train_time:53179ms step_avg:96.69ms +step:551/1670 train_time:53274ms step_avg:96.69ms +step:552/1670 train_time:53370ms step_avg:96.68ms +step:553/1670 train_time:53465ms step_avg:96.68ms +step:554/1670 train_time:53561ms step_avg:96.68ms +step:555/1670 train_time:53656ms step_avg:96.68ms +step:556/1670 train_time:53753ms step_avg:96.68ms +step:557/1670 train_time:53849ms step_avg:96.68ms +step:558/1670 train_time:53944ms step_avg:96.67ms +step:559/1670 train_time:54040ms step_avg:96.67ms +step:560/1670 train_time:54137ms step_avg:96.67ms +step:561/1670 train_time:54235ms step_avg:96.68ms +step:562/1670 train_time:54332ms step_avg:96.68ms +step:563/1670 train_time:54429ms step_avg:96.68ms +step:564/1670 
train_time:54527ms step_avg:96.68ms +step:565/1670 train_time:54624ms step_avg:96.68ms +step:566/1670 train_time:54720ms step_avg:96.68ms +step:567/1670 train_time:54818ms step_avg:96.68ms +step:568/1670 train_time:54915ms step_avg:96.68ms +step:569/1670 train_time:55011ms step_avg:96.68ms +step:570/1670 train_time:55108ms step_avg:96.68ms +step:571/1670 train_time:55204ms step_avg:96.68ms +step:572/1670 train_time:55300ms step_avg:96.68ms +step:573/1670 train_time:55398ms step_avg:96.68ms +step:574/1670 train_time:55495ms step_avg:96.68ms +step:575/1670 train_time:55593ms step_avg:96.68ms +step:576/1670 train_time:55689ms step_avg:96.68ms +step:577/1670 train_time:55787ms step_avg:96.68ms +step:578/1670 train_time:55883ms step_avg:96.68ms +step:579/1670 train_time:55980ms step_avg:96.68ms +step:580/1670 train_time:56077ms step_avg:96.68ms +step:581/1670 train_time:56176ms step_avg:96.69ms +step:582/1670 train_time:56274ms step_avg:96.69ms +step:583/1670 train_time:56372ms step_avg:96.69ms +step:584/1670 train_time:56469ms step_avg:96.69ms +step:585/1670 train_time:56566ms step_avg:96.69ms +step:586/1670 train_time:56663ms step_avg:96.70ms +step:587/1670 train_time:56760ms step_avg:96.69ms +step:588/1670 train_time:56857ms step_avg:96.70ms +step:589/1670 train_time:56954ms step_avg:96.70ms +step:590/1670 train_time:57051ms step_avg:96.70ms +step:591/1670 train_time:57148ms step_avg:96.70ms +step:592/1670 train_time:57246ms step_avg:96.70ms +step:593/1670 train_time:57343ms step_avg:96.70ms +step:594/1670 train_time:57439ms step_avg:96.70ms +step:595/1670 train_time:57536ms step_avg:96.70ms +step:596/1670 train_time:57635ms step_avg:96.70ms +step:597/1670 train_time:57733ms step_avg:96.70ms +step:598/1670 train_time:57830ms step_avg:96.70ms +step:599/1670 train_time:57927ms step_avg:96.71ms +step:600/1670 train_time:58024ms step_avg:96.71ms +step:601/1670 train_time:58120ms step_avg:96.71ms +step:602/1670 train_time:58217ms step_avg:96.71ms +step:603/1670 train_time:58315ms step_avg:96.71ms +step:604/1670 train_time:58413ms step_avg:96.71ms +step:605/1670 train_time:58511ms step_avg:96.71ms +step:606/1670 train_time:58607ms step_avg:96.71ms +step:607/1670 train_time:58704ms step_avg:96.71ms +step:608/1670 train_time:58801ms step_avg:96.71ms +step:609/1670 train_time:58898ms step_avg:96.71ms +step:610/1670 train_time:58995ms step_avg:96.71ms +step:611/1670 train_time:59092ms step_avg:96.71ms +step:612/1670 train_time:59189ms step_avg:96.71ms +step:613/1670 train_time:59286ms step_avg:96.71ms +step:614/1670 train_time:59383ms step_avg:96.72ms +step:615/1670 train_time:59480ms step_avg:96.72ms +step:616/1670 train_time:59578ms step_avg:96.72ms +step:617/1670 train_time:59676ms step_avg:96.72ms +step:618/1670 train_time:59774ms step_avg:96.72ms +step:619/1670 train_time:59871ms step_avg:96.72ms +step:620/1670 train_time:59969ms step_avg:96.72ms +step:621/1670 train_time:60065ms step_avg:96.72ms +step:622/1670 train_time:60162ms step_avg:96.72ms +step:623/1670 train_time:60259ms step_avg:96.72ms +step:624/1670 train_time:60357ms step_avg:96.73ms +step:625/1670 train_time:60453ms step_avg:96.73ms +step:625/1670 val_loss:3.6175 train_time:60550ms step_avg:96.88ms +step:626/1670 train_time:60574ms step_avg:96.76ms +step:627/1670 train_time:60660ms step_avg:96.75ms +step:628/1670 train_time:60756ms step_avg:96.75ms +step:629/1670 train_time:60852ms step_avg:96.74ms +step:630/1670 train_time:60948ms step_avg:96.74ms +step:631/1670 train_time:61044ms step_avg:96.74ms +step:632/1670 train_time:61140ms 
step_avg:96.74ms +step:633/1670 train_time:61235ms step_avg:96.74ms +step:634/1670 train_time:61331ms step_avg:96.74ms +step:635/1670 train_time:61428ms step_avg:96.74ms +step:636/1670 train_time:61527ms step_avg:96.74ms +step:637/1670 train_time:61627ms step_avg:96.75ms +step:638/1670 train_time:61725ms step_avg:96.75ms +step:639/1670 train_time:62063ms step_avg:97.12ms +step:640/1670 train_time:62165ms step_avg:97.13ms +step:641/1670 train_time:62260ms step_avg:97.13ms +step:642/1670 train_time:62355ms step_avg:97.13ms +step:643/1670 train_time:62451ms step_avg:97.12ms +step:644/1670 train_time:62548ms step_avg:97.12ms +step:645/1670 train_time:62644ms step_avg:97.12ms +step:646/1670 train_time:62740ms step_avg:97.12ms +step:647/1670 train_time:62836ms step_avg:97.12ms +step:648/1670 train_time:62931ms step_avg:97.12ms +step:649/1670 train_time:63032ms step_avg:97.12ms +step:650/1670 train_time:63130ms step_avg:97.12ms +step:651/1670 train_time:63228ms step_avg:97.12ms +step:652/1670 train_time:63326ms step_avg:97.13ms +step:653/1670 train_time:63424ms step_avg:97.13ms +step:654/1670 train_time:63521ms step_avg:97.13ms +step:655/1670 train_time:63617ms step_avg:97.13ms +step:656/1670 train_time:63713ms step_avg:97.12ms +step:657/1670 train_time:63809ms step_avg:97.12ms +step:658/1670 train_time:63905ms step_avg:97.12ms +step:659/1670 train_time:64005ms step_avg:97.12ms +step:660/1670 train_time:64104ms step_avg:97.13ms +step:661/1670 train_time:64202ms step_avg:97.13ms +step:662/1670 train_time:64300ms step_avg:97.13ms +step:663/1670 train_time:64397ms step_avg:97.13ms +step:664/1670 train_time:64494ms step_avg:97.13ms +step:665/1670 train_time:64590ms step_avg:97.13ms +step:666/1670 train_time:64686ms step_avg:97.13ms +step:667/1670 train_time:64783ms step_avg:97.13ms +step:668/1670 train_time:64880ms step_avg:97.13ms +step:669/1670 train_time:64977ms step_avg:97.13ms +step:670/1670 train_time:65075ms step_avg:97.13ms +step:671/1670 train_time:65173ms step_avg:97.13ms +step:672/1670 train_time:65271ms step_avg:97.13ms +step:673/1670 train_time:65368ms step_avg:97.13ms +step:674/1670 train_time:65465ms step_avg:97.13ms +step:675/1670 train_time:65562ms step_avg:97.13ms +step:676/1670 train_time:65660ms step_avg:97.13ms +step:677/1670 train_time:65755ms step_avg:97.13ms +step:678/1670 train_time:65852ms step_avg:97.13ms +step:679/1670 train_time:65949ms step_avg:97.13ms +step:680/1670 train_time:66045ms step_avg:97.13ms +step:681/1670 train_time:66144ms step_avg:97.13ms +step:682/1670 train_time:66243ms step_avg:97.13ms +step:683/1670 train_time:66340ms step_avg:97.13ms +step:684/1670 train_time:66436ms step_avg:97.13ms +step:685/1670 train_time:66534ms step_avg:97.13ms +step:686/1670 train_time:66630ms step_avg:97.13ms +step:687/1670 train_time:66727ms step_avg:97.13ms +step:688/1670 train_time:66824ms step_avg:97.13ms +step:689/1670 train_time:66922ms step_avg:97.13ms +step:690/1670 train_time:67019ms step_avg:97.13ms +step:691/1670 train_time:67117ms step_avg:97.13ms +step:692/1670 train_time:67213ms step_avg:97.13ms +step:693/1670 train_time:67309ms step_avg:97.13ms +step:694/1670 train_time:67407ms step_avg:97.13ms +step:695/1670 train_time:67505ms step_avg:97.13ms +step:696/1670 train_time:67603ms step_avg:97.13ms +step:697/1670 train_time:67699ms step_avg:97.13ms +step:698/1670 train_time:67796ms step_avg:97.13ms +step:699/1670 train_time:67892ms step_avg:97.13ms +step:700/1670 train_time:67990ms step_avg:97.13ms +step:701/1670 train_time:68087ms step_avg:97.13ms +step:702/1670 
train_time:68185ms step_avg:97.13ms +step:703/1670 train_time:68283ms step_avg:97.13ms +step:704/1670 train_time:68381ms step_avg:97.13ms +step:705/1670 train_time:68478ms step_avg:97.13ms +step:706/1670 train_time:68575ms step_avg:97.13ms +step:707/1670 train_time:68672ms step_avg:97.13ms +step:708/1670 train_time:68767ms step_avg:97.13ms +step:709/1670 train_time:68864ms step_avg:97.13ms +step:710/1670 train_time:68962ms step_avg:97.13ms +step:711/1670 train_time:69061ms step_avg:97.13ms +step:712/1670 train_time:69159ms step_avg:97.13ms +step:713/1670 train_time:69255ms step_avg:97.13ms +step:714/1670 train_time:69352ms step_avg:97.13ms +step:715/1670 train_time:69448ms step_avg:97.13ms +step:716/1670 train_time:69545ms step_avg:97.13ms +step:717/1670 train_time:69644ms step_avg:97.13ms +step:718/1670 train_time:69740ms step_avg:97.13ms +step:719/1670 train_time:69836ms step_avg:97.13ms +step:720/1670 train_time:69933ms step_avg:97.13ms +step:721/1670 train_time:70029ms step_avg:97.13ms +step:722/1670 train_time:70127ms step_avg:97.13ms +step:723/1670 train_time:70224ms step_avg:97.13ms +step:724/1670 train_time:70322ms step_avg:97.13ms +step:725/1670 train_time:70420ms step_avg:97.13ms +step:726/1670 train_time:70516ms step_avg:97.13ms +step:727/1670 train_time:70613ms step_avg:97.13ms +step:728/1670 train_time:70709ms step_avg:97.13ms +step:729/1670 train_time:70805ms step_avg:97.13ms +step:730/1670 train_time:70902ms step_avg:97.13ms +step:731/1670 train_time:70999ms step_avg:97.13ms +step:732/1670 train_time:71096ms step_avg:97.13ms +step:733/1670 train_time:71194ms step_avg:97.13ms +step:734/1670 train_time:71291ms step_avg:97.13ms +step:735/1670 train_time:71388ms step_avg:97.13ms +step:736/1670 train_time:71486ms step_avg:97.13ms +step:737/1670 train_time:71583ms step_avg:97.13ms +step:738/1670 train_time:71680ms step_avg:97.13ms +step:739/1670 train_time:71777ms step_avg:97.13ms +step:740/1670 train_time:71873ms step_avg:97.13ms +step:741/1670 train_time:71970ms step_avg:97.12ms +step:742/1670 train_time:72067ms step_avg:97.13ms +step:743/1670 train_time:72165ms step_avg:97.13ms +step:744/1670 train_time:72263ms step_avg:97.13ms +step:745/1670 train_time:72362ms step_avg:97.13ms +step:746/1670 train_time:72458ms step_avg:97.13ms +step:747/1670 train_time:72555ms step_avg:97.13ms +step:748/1670 train_time:72652ms step_avg:97.13ms +step:749/1670 train_time:72749ms step_avg:97.13ms +step:750/1670 train_time:72846ms step_avg:97.13ms +step:750/1670 val_loss:3.5636 train_time:72942ms step_avg:97.26ms +step:751/1670 train_time:72967ms step_avg:97.16ms +step:752/1670 train_time:73047ms step_avg:97.14ms +step:753/1670 train_time:73145ms step_avg:97.14ms +step:754/1670 train_time:73242ms step_avg:97.14ms +step:755/1670 train_time:73339ms step_avg:97.14ms +step:756/1670 train_time:73435ms step_avg:97.14ms +step:757/1670 train_time:73531ms step_avg:97.13ms +step:758/1670 train_time:73627ms step_avg:97.13ms +step:759/1670 train_time:73723ms step_avg:97.13ms +step:760/1670 train_time:73820ms step_avg:97.13ms +step:761/1670 train_time:73920ms step_avg:97.13ms +step:762/1670 train_time:74019ms step_avg:97.14ms +step:763/1670 train_time:74118ms step_avg:97.14ms +step:764/1670 train_time:74215ms step_avg:97.14ms +step:765/1670 train_time:74312ms step_avg:97.14ms +step:766/1670 train_time:74409ms step_avg:97.14ms +step:767/1670 train_time:74505ms step_avg:97.14ms +step:768/1670 train_time:74602ms step_avg:97.14ms +step:769/1670 train_time:74699ms step_avg:97.14ms +step:770/1670 train_time:74795ms 
step_avg:97.14ms +step:771/1670 train_time:74892ms step_avg:97.14ms +step:772/1670 train_time:74990ms step_avg:97.14ms +step:773/1670 train_time:75087ms step_avg:97.14ms +step:774/1670 train_time:75185ms step_avg:97.14ms +step:775/1670 train_time:75282ms step_avg:97.14ms +step:776/1670 train_time:75380ms step_avg:97.14ms +step:777/1670 train_time:75478ms step_avg:97.14ms +step:778/1670 train_time:75574ms step_avg:97.14ms +step:779/1670 train_time:75670ms step_avg:97.14ms +step:780/1670 train_time:75767ms step_avg:97.14ms +step:781/1670 train_time:75863ms step_avg:97.14ms +step:782/1670 train_time:75961ms step_avg:97.14ms +step:783/1670 train_time:76061ms step_avg:97.14ms +step:784/1670 train_time:76158ms step_avg:97.14ms +step:785/1670 train_time:76255ms step_avg:97.14ms +step:786/1670 train_time:76352ms step_avg:97.14ms +step:787/1670 train_time:76448ms step_avg:97.14ms +step:788/1670 train_time:76545ms step_avg:97.14ms +step:789/1670 train_time:76643ms step_avg:97.14ms +step:790/1670 train_time:76741ms step_avg:97.14ms +step:791/1670 train_time:76838ms step_avg:97.14ms +step:792/1670 train_time:76935ms step_avg:97.14ms +step:793/1670 train_time:77032ms step_avg:97.14ms +step:794/1670 train_time:77130ms step_avg:97.14ms +step:795/1670 train_time:77227ms step_avg:97.14ms +step:796/1670 train_time:77324ms step_avg:97.14ms +step:797/1670 train_time:77421ms step_avg:97.14ms +step:798/1670 train_time:77518ms step_avg:97.14ms +step:799/1670 train_time:77615ms step_avg:97.14ms +step:800/1670 train_time:77711ms step_avg:97.14ms +step:801/1670 train_time:77808ms step_avg:97.14ms +step:802/1670 train_time:77906ms step_avg:97.14ms +step:803/1670 train_time:78003ms step_avg:97.14ms +step:804/1670 train_time:78101ms step_avg:97.14ms +step:805/1670 train_time:78199ms step_avg:97.14ms +step:806/1670 train_time:78296ms step_avg:97.14ms +step:807/1670 train_time:78392ms step_avg:97.14ms +step:808/1670 train_time:78489ms step_avg:97.14ms +step:809/1670 train_time:78585ms step_avg:97.14ms +step:810/1670 train_time:78683ms step_avg:97.14ms +step:811/1670 train_time:78780ms step_avg:97.14ms +step:812/1670 train_time:78878ms step_avg:97.14ms +step:813/1670 train_time:78975ms step_avg:97.14ms +step:814/1670 train_time:79072ms step_avg:97.14ms +step:815/1670 train_time:79168ms step_avg:97.14ms +step:816/1670 train_time:79266ms step_avg:97.14ms +step:817/1670 train_time:79363ms step_avg:97.14ms +step:818/1670 train_time:79460ms step_avg:97.14ms +step:819/1670 train_time:79557ms step_avg:97.14ms +step:820/1670 train_time:79654ms step_avg:97.14ms +step:821/1670 train_time:79751ms step_avg:97.14ms +step:822/1670 train_time:79848ms step_avg:97.14ms +step:823/1670 train_time:79945ms step_avg:97.14ms +step:824/1670 train_time:80043ms step_avg:97.14ms +step:825/1670 train_time:80141ms step_avg:97.14ms +step:826/1670 train_time:80238ms step_avg:97.14ms +step:827/1670 train_time:80336ms step_avg:97.14ms +step:828/1670 train_time:80432ms step_avg:97.14ms +step:829/1670 train_time:80528ms step_avg:97.14ms +step:830/1670 train_time:80625ms step_avg:97.14ms +step:831/1670 train_time:80722ms step_avg:97.14ms +step:832/1670 train_time:80820ms step_avg:97.14ms +step:833/1670 train_time:80917ms step_avg:97.14ms +step:834/1670 train_time:81014ms step_avg:97.14ms +step:835/1670 train_time:81112ms step_avg:97.14ms +step:836/1670 train_time:81208ms step_avg:97.14ms +step:837/1670 train_time:81306ms step_avg:97.14ms +step:838/1670 train_time:81404ms step_avg:97.14ms +step:839/1670 train_time:81501ms step_avg:97.14ms +step:840/1670 
train_time:81598ms step_avg:97.14ms +step:841/1670 train_time:81696ms step_avg:97.14ms +step:842/1670 train_time:81792ms step_avg:97.14ms +step:843/1670 train_time:81888ms step_avg:97.14ms +step:844/1670 train_time:81984ms step_avg:97.14ms +step:845/1670 train_time:82082ms step_avg:97.14ms +step:846/1670 train_time:82180ms step_avg:97.14ms +step:847/1670 train_time:82278ms step_avg:97.14ms +step:848/1670 train_time:82376ms step_avg:97.14ms +step:849/1670 train_time:82472ms step_avg:97.14ms +step:850/1670 train_time:82570ms step_avg:97.14ms +step:851/1670 train_time:82846ms step_avg:97.35ms +step:852/1670 train_time:83015ms step_avg:97.44ms +step:853/1670 train_time:83109ms step_avg:97.43ms +step:854/1670 train_time:83205ms step_avg:97.43ms +step:855/1670 train_time:83301ms step_avg:97.43ms +step:856/1670 train_time:83397ms step_avg:97.43ms +step:857/1670 train_time:83493ms step_avg:97.42ms +step:858/1670 train_time:83588ms step_avg:97.42ms +step:859/1670 train_time:83684ms step_avg:97.42ms +step:860/1670 train_time:83780ms step_avg:97.42ms +step:861/1670 train_time:83878ms step_avg:97.42ms +step:862/1670 train_time:83979ms step_avg:97.42ms +step:863/1670 train_time:84079ms step_avg:97.43ms +step:864/1670 train_time:84177ms step_avg:97.43ms +step:865/1670 train_time:84273ms step_avg:97.43ms +step:866/1670 train_time:84370ms step_avg:97.42ms +step:867/1670 train_time:84466ms step_avg:97.42ms +step:868/1670 train_time:84562ms step_avg:97.42ms +step:869/1670 train_time:84659ms step_avg:97.42ms +step:870/1670 train_time:84756ms step_avg:97.42ms +step:871/1670 train_time:84852ms step_avg:97.42ms +step:872/1670 train_time:84951ms step_avg:97.42ms +step:873/1670 train_time:85050ms step_avg:97.42ms +step:874/1670 train_time:85147ms step_avg:97.42ms +step:875/1670 train_time:85244ms step_avg:97.42ms +step:875/1670 val_loss:3.5198 train_time:85341ms step_avg:97.53ms +step:876/1670 train_time:85364ms step_avg:97.45ms +step:877/1670 train_time:85447ms step_avg:97.43ms +step:878/1670 train_time:85548ms step_avg:97.44ms +step:879/1670 train_time:85647ms step_avg:97.44ms +step:880/1670 train_time:85743ms step_avg:97.43ms +step:881/1670 train_time:85839ms step_avg:97.43ms +step:882/1670 train_time:85935ms step_avg:97.43ms +step:883/1670 train_time:86030ms step_avg:97.43ms +step:884/1670 train_time:86126ms step_avg:97.43ms +step:885/1670 train_time:86222ms step_avg:97.43ms +step:886/1670 train_time:86320ms step_avg:97.43ms +step:887/1670 train_time:86422ms step_avg:97.43ms +step:888/1670 train_time:86523ms step_avg:97.44ms +step:889/1670 train_time:86623ms step_avg:97.44ms +step:890/1670 train_time:86720ms step_avg:97.44ms +step:891/1670 train_time:86817ms step_avg:97.44ms +step:892/1670 train_time:86913ms step_avg:97.44ms +step:893/1670 train_time:87008ms step_avg:97.43ms +step:894/1670 train_time:87105ms step_avg:97.43ms +step:895/1670 train_time:87201ms step_avg:97.43ms +step:896/1670 train_time:87298ms step_avg:97.43ms +step:897/1670 train_time:87396ms step_avg:97.43ms +step:898/1670 train_time:87494ms step_avg:97.43ms +step:899/1670 train_time:87594ms step_avg:97.44ms +step:900/1670 train_time:87692ms step_avg:97.44ms +step:901/1670 train_time:87788ms step_avg:97.43ms +step:902/1670 train_time:87885ms step_avg:97.43ms +step:903/1670 train_time:87982ms step_avg:97.43ms +step:904/1670 train_time:88079ms step_avg:97.43ms +step:905/1670 train_time:88175ms step_avg:97.43ms +step:906/1670 train_time:88271ms step_avg:97.43ms +step:907/1670 train_time:88368ms step_avg:97.43ms +step:908/1670 train_time:88465ms 
+[per-step training log for steps 909-1669 elided: train_time grows steadily from 88564ms to 163423ms, with step_avg holding between 97.40ms and 97.92ms apart from brief stalls near steps 1062, 1274, and 1485; validation checkpoints retained below]
+step:1000/1670 val_loss:3.4791 train_time:97506ms step_avg:97.51ms
+step:1125/1670 val_loss:3.4256 train_time:109807ms step_avg:97.61ms
+step:1250/1670 val_loss:3.3835 train_time:122044ms step_avg:97.63ms
+step:1375/1670 val_loss:3.3452 train_time:134549ms step_avg:97.85ms
+step:1500/1670 val_loss:3.3124 train_time:146954ms step_avg:97.97ms
+step:1625/1670 val_loss:3.2859 train_time:159195ms step_avg:97.97ms
+step:1670/1670 train_time:163522ms step_avg:97.92ms
+step:1670/1670 val_loss:3.2780 train_time:163618ms step_avg:97.97ms
+peak memory allocated: 34000 MiB reserved: 49576 MiB
diff --git a/records/090325_FA3/media/attn_speed_vs_batch_s1024_ws384.png b/records/090325_FA3/media/attn_speed_vs_batch_s1024_ws384.png
new file mode 100644
index 000000000..eb60b4b85
Binary files /dev/null and b/records/090325_FA3/media/attn_speed_vs_batch_s1024_ws384.png differ
diff --git a/train_gpt.py b/train_gpt.py
index bbb431bed..8c5ea5788 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -20,7 +20,7 @@
 import numpy as np
 import triton
 import triton.language as tl
-from flash_attn_interface import flash_attn_func
+from flash_attn_interface import flash_attn_varlen_func
 import torch._dynamo as dynamo
 dynamo.config.recompile_limit = 64
@@ -600,8 +600,10 @@ def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
         self.attn_gate = CastedLinear(self.attn_gate_dim, num_heads)
         self.attn_gate.weight.detach().zero_()
 
-    def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int):
+    def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, seqlens: Tensor, bm_size: int):
         B, T = x.size(0), x.size(1) # batch size, sequence length
+        assert B == 1, "varlen sequences requires B == 1"
+        assert T % 16 == 0
         q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
         q, k = norm(q), norm(k) # QK norm @Grad62304977
@@ -611,7 +613,11 @@ def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, bm_size: int):
         else: # skip mid-layers token value embeddings by @YouJiacheng
             v = lambdas[0] * v
-        y = flash_attn_func(q, k, v, softmax_scale=self.attn_scale, window_size=(bm_size, 0)) # use flash_attn over flex_attn @varunneal
+        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
+
+        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
+        y = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens, max_seqlen_q=max_len, max_seqlen_k=max_len,
+                                   causal=True, softmax_scale=self.attn_scale, window_size=(bm_size, 0))
         y = y.view(B, T, self.num_heads, self.head_dim)
         y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate_dim])).view(B, T, self.num_heads, 1)
         y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
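For readers tracing the attention change above: flash_attn_varlen_func takes the packed documents as one batchless (T, num_heads, head_dim) tensor plus cumulative sequence boundaries, and restricts attention to within-document causal pairs. A minimal pure-PyTorch sketch of those semantics (illustration only, not from the patch; it ignores the sliding-window term and uses toy shapes):

import torch
import torch.nn.functional as F

# Three documents of lengths 3, 5, 2 packed into one length-10 sequence.
seq_lens = [3, 5, 2]
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)  # leading 0 + cumsum, as passed to cu_seqlens_q/k
T = int(cu_seqlens[-1])

# The block-diagonal causal mask that varlen attention implies (flash_attn never materializes it).
doc_id = torch.repeat_interleave(torch.arange(len(seq_lens)), torch.tensor(seq_lens))
pos = torch.arange(T)
mask = (doc_id[:, None] == doc_id[None, :]) & (pos[:, None] >= pos[None, :])

q = k = v = torch.randn(T, 4, 16)  # (tokens, num_heads, head_dim): batch dim dropped, hence q[0], k[0], v[0] above
out = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), attn_mask=mask,
).transpose(0, 1)
print(out.shape)  # torch.Size([10, 4, 16])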
@@ -645,10 +651,11 @@ def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
         self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
         self.mlp = MLP(dim)
 
-    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, bm_size: int):
+    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor,
+                seqlens: Tensor, bm_size: int):
         x = lambdas[0] * x + lambdas[1] * x0
         if self.attn is not None:
-            x = x + self.attn(norm(x), ve, sa_lambdas, bm_size)
+            x = x + self.attn(norm(x), ve, sa_lambdas, seqlens, bm_size)
         x = x + self.mlp(norm(x))
         return x
@@ -690,19 +697,19 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim:
         self.scalars.lr_mul = 5.0
 
-    def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short: int):
-        assert input_seq.ndim == 2
+    def forward(self, input_seq: Tensor, target_seq: Tensor, seqlens: Tensor, ws: int):
+        assert input_seq.ndim == 1
 
         ve = [value_embed(input_seq) for value_embed in self.value_embeds]
         # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
         ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
         assert len(ve) == len(self.blocks)
 
-        long_bm, short_bm = ws_long * args.bandwidth, ws_short * args.bandwidth
+        long_bm, short_bm = ws * args.block_size, (ws // 2) * args.block_size
         bm_sizes = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
         assert len(bm_sizes) == len(self.blocks)
 
-        x = x0 = norm(self.embed(input_seq)) # use of norm here by @Grad62304977
+        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
 
         # U-net design by @brendanh0gan
         skip_connections = []
@@ -715,7 +722,7 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short:
         for i in range(len(self.blocks)):
             if i >= n:
                 x = x + skip_weights[i - n] * skip_connections.pop()
-            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], bm_sizes[i])
+            x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], seqlens, bm_sizes[i])
             if i < n:
                 skip_connections.append(x)
@@ -723,8 +730,7 @@ def forward(self, input_seq: Tensor, target_seq: Tensor, ws_long: int, ws_short:
         logits = self.lm_head(x).float()
         # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
         logits = 30 * torch.sigmoid(logits / 7.5)
-        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq.view(-1),
-                               reduction="sum" if self.training else "mean")
+        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean")
         return loss
 
 # -----------------------------------------------------------------------------
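The softcap context lines above rely on an identity worth spelling out: since tanh(y) = 2*sigmoid(2*y) - 1, the expression 30*sigmoid(logits/7.5) equals 15*tanh(logits/15) + 15, i.e. the usual +/-15 tanh softcap shifted into the range (0, 30). A quick numerical check (illustration only, not from the patch):

import torch

x = torch.linspace(-100, 100, steps=1001)
assert torch.allclose(30 * torch.sigmoid(x / 7.5), 15 * torch.tanh(x / 15) + 15, atol=1e-5)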
header" return tokens -class EOSBatchFinder: +BOS_ID = 50256 + +class BOSFinder: # Helper for getting sequences that start at the beginning of documents by @varunneal based on work by @classiclarryd - def __init__(self, tokens: Tensor, world_size: int = 1, eos_id: int = 50256): - # Precompute EOS positions once per shard - self.eos_idx = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() - self.i = 0 # pointer into eos_idx (start EOS for next step) - self.pos = 0 # logical stream position within this shard + def __init__(self, tokens: Tensor, world_size: int = 1): + # Precompute BOS positions once per shard + self.size = tokens.numel() + self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy() + self.i = 0 self.world_size = world_size - def seek(self, pos: int): - # Set pointer to the first EOS >= pos - self.i = np.searchsorted(self.eos_idx, pos) - if self.i >= len(self.eos_idx): - raise StopIteration("Seek past last EOS.") - self.pos = pos - def next_batch(self, batch_size_local: int, seq_len: int): - n = len(self.eos_idx) - if self.i >= n: - raise StopIteration("No more EOS in this shard.") + + def next_batch(self, num_tokens_local: int, max_seq_len: int): + n = len(self.bos_idx) starts = [[] for _ in range(self.world_size)] + ends = [[] for _ in range(self.world_size)] + idx = self.i - cur = self.eos_idx[idx] # EOS that ends the "previous" document; next doc starts at cur+1 for r in range(self.world_size): - for _ in range(batch_size_local): - start = cur + 1 - target = start + seq_len # need seq_len tokens before next EOS - j = np.searchsorted(self.eos_idx, target) - if j >= n: - raise StopIteration("Insufficient EOS ahead; hit tail of shard.") - starts[r].append(start) - idx = j - cur = self.eos_idx[idx] # next seq must also start at a new doc - advance = self.eos_idx[idx] - self.pos # move stream to the last end - self.pos += advance + cur_len = 0 + while cur_len <= num_tokens_local: + if idx >= n: + raise StopIteration(f"Insufficient BOS ahead of position {cur}; hit tail of shard.") + cur = self.bos_idx[idx] + starts[r].append(cur) + end = min(self.bos_idx[idx + 1] if idx + 1 < n else self.size, + cur + max_seq_len, + cur + num_tokens_local - cur_len + 1) + ends[r].append(end) + cur_len += end - cur + idx += 1 + + assert cur_len == num_tokens_local + 1 self.i = idx - return starts, advance + return starts, ends -def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True): - # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap +def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True): + # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 - assert batch_size % world_size == 0, "Batch size must be divisible by world size" + assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size" + num_tokens = num_tokens // grad_accum_steps files = [Path(file) for file in sorted(glob.glob(filename_pattern))] if not files: raise FileNotFoundError(f"No files found for pattern: {filename_pattern}") file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training - tokens, pos = _load_data_shard(next(file_iter)), 0 - - finder = 
 
-def distributed_data_generator(filename_pattern: str, batch_size: int, seq_len: int, align_to_bos: bool = True):
-    # align_to_bos: each sequence begins with Beginning of Sequence token and sequences don't overlap
+def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
+    # align_to_bos: each sequence begins with Beginning of Sequence token, sequences truncated to max_seq_len
     rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
-    assert batch_size % world_size == 0, "Batch size must be divisible by world size"
+    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
+    num_tokens = num_tokens // grad_accum_steps
     files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
     if not files:
         raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
     file_iter = iter(files) # Use itertools.cycle(files) for multi-epoch training
-    tokens, pos = _load_data_shard(next(file_iter)), 0
-
-    finder = EOSBatchFinder(tokens, world_size=world_size) if align_to_bos else None
-    if align_to_bos: finder.seek(pos)
+    tokens = _load_data_shard(next(file_iter))
+    finder = BOSFinder(tokens, world_size=world_size) if align_to_bos else None
+    pos = 0 # for unaligned case
 
     while True:
-        batch_size_local = batch_size // world_size
-        num_tokens_global = batch_size * seq_len
-
-        if not align_to_bos and pos + num_tokens_global + 1 >= len(tokens):
-            tokens, pos = _load_data_shard(next(file_iter)), 0
+        num_tokens_local = num_tokens // world_size
+        max_num_docs = next_multiple_of_n(num_tokens_local // 300, n=128) # median doc length is ~400
 
         if align_to_bos:
             try:
-                batch_starts, batch_span = finder.next_batch(batch_size_local, seq_len)
-                start_idxs = batch_starts[rank]
+                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
+                start_idxs, end_idxs = torch.tensor(seq_starts[rank]), torch.tensor(seq_ends[rank])
             except StopIteration:
                 # This shard is exhausted, load the next one in the next loop iteration.
-                tokens, pos = _load_data_shard(next(file_iter)), 0
-                finder = EOSBatchFinder(tokens, world_size=world_size)
+                tokens = _load_data_shard(next(file_iter))
+                finder = BOSFinder(tokens, world_size=world_size)
                 continue
-            bufs = [tokens[s: s + seq_len + 1] for s in start_idxs]
-            buf = torch.stack(bufs, dim=0)
-            _inputs = buf[:, :-1]
-            _targets = buf[:, 1:]
+            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
+            _inputs = buf[:-1]
+            _targets = buf[1:]
+            end_idxs[-1] -= 1 # last document was too long to account for _targets offset
+            cum_lengths = (end_idxs - start_idxs).cumsum(0)
         else:
-            batch_span = num_tokens_global
-            start_pos_local = pos + rank * (batch_size_local * seq_len)
-            end_pos_local = start_pos_local + (batch_size_local * seq_len)
+            if pos + num_tokens + 1 >= len(tokens): # should not occur for val data
+                tokens, pos = _load_data_shard(next(file_iter)), 0
+
+            pos_local = pos + rank * num_tokens_local
-            buf = tokens[start_pos_local: end_pos_local + 1]
+            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
-            _inputs = buf[:-1].view(batch_size_local, seq_len)
-            _targets = buf[1:].view(batch_size_local, seq_len)
+            _inputs = buf[:-1].view(num_tokens_local, )
+            _targets = buf[1:].view(num_tokens_local, )
+            cum_lengths = torch.nonzero(_inputs == BOS_ID)[:, 0]
+            pos += num_tokens
+
+        _cum_lengths = torch.full((max_num_docs,), num_tokens_local)
+        _cum_lengths[0] = 0
+        _cum_lengths[1:len(cum_lengths) + 1] = cum_lengths
 
         new_params = yield (
             _inputs.to(device="cuda", dtype=torch.int32, non_blocking=True),
-            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True)
+            _targets.to(device="cuda", dtype=torch.int64, non_blocking=True),
+            _cum_lengths.to(device="cuda", dtype=torch.int32, non_blocking=True)
         )
-        pos += batch_span
-
         if new_params is not None:
-            # makes it possible for generator to recieve new (batch_size, seq_len) via .send()
-            new_batch_size, new_seq_len = new_params
-            assert new_batch_size % world_size == 0, "New batch size must be divisible by world size"
-            batch_size = new_batch_size
-            seq_len = new_seq_len
+            # makes it possible for generator to receive new (num_tokens, max_seq_len, grad_accum_steps) via .send()
+            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
+            assert new_num_tokens % (world_size * grad_accum_steps) == 0, "Num tokens must be divisible by world size"
+            num_tokens = new_num_tokens
+            max_seq_len = new_max_seq_len
+            grad_accum_steps = new_grad_accum_steps
 
 # -----------------------------------------------------------------------------
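The packing rule in BOSFinder.next_batch above is easy to miss in diff form: each sequence starts at a BOS token, runs to the next BOS (or the shard end), is truncated to max_seq_len, and the loop stops once exactly num_tokens_local + 1 tokens are collected (the +1 covers the input/target shift); _cum_lengths is then padded out to max_num_docs, presumably so the compiled graph sees a fixed shape. A toy re-implementation of the packing step (illustration only, not from the patch; BOS = -1 is a stand-in sentinel, the real BOS_ID is 50256):

import numpy as np

BOS = -1
tokens = np.array([BOS, 5, 6, 7, BOS, 8, 9, BOS, 1, 2, 3, 4, 5, BOS, 6, 7, 8])
bos_idx = np.flatnonzero(tokens == BOS)

def pack(bos_idx, size, num_tokens_local, max_seq_len, start=0):
    starts, ends, cur_len, idx = [], [], 0, start
    while cur_len <= num_tokens_local:  # collect exactly num_tokens_local + 1 tokens
        cur = bos_idx[idx]
        nxt = bos_idx[idx + 1] if idx + 1 < len(bos_idx) else size
        # document end, capped by max_seq_len and by the remaining token budget
        end = min(nxt, cur + max_seq_len, cur + num_tokens_local - cur_len + 1)
        starts.append(cur); ends.append(end)
        cur_len += end - cur
        idx += 1
    assert cur_len == num_tokens_local + 1
    return starts, ends

print(pack(bos_idx, tokens.size, num_tokens_local=8, max_seq_len=4))  # ([0, 4, 7], [4, 7, 9])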
@@ -850,18 +861,18 @@ class Hyperparameters:
     train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
     val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
     val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
-    train_seq_len: int = 1024 * 2
-    train_batch_size: int = 24 * 8
-    val_seq_len: int = 4 * 64 * 1024 # Validation will be done with batch size = world_size.
+    train_batch_size: int = 2048 * 24 * 8
+    train_max_seq_len: int = 128 * 16
+    val_batch_size: int = 4 * 64 * 1024 * 8
     # optimization
-    num_iterations: int = 1695 # number of iterations to run
+    num_iterations: int = 1670 # number of iterations to run
     cooldown_frac: int = 0.45 # fraction of training spent cooling down the learning rate
     # evaluation and logging
     run_id: str = str(uuid.uuid4())
     val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
     save_checkpoint: bool = False
     # attention masking
-    bandwidth: int = 128
+    block_size: int = 128
     ws_schedule: tuple = (3, 7, 11)
 
 args = Hyperparameters()
@@ -915,7 +926,7 @@ def nvidia_smi():
     num_layers=12,
     num_heads=6,
     model_dim=768,
-    max_seq_len=max(args.train_seq_len, args.val_seq_len)
+    max_seq_len=max(args.train_batch_size, args.val_batch_size) // (grad_accum_steps * world_size)
 ).cuda()
 for m in model.modules():
     if isinstance(m, nn.Embedding):
@@ -940,15 +951,20 @@
     group["initial_lr"] = group["lr"]
 
 # learning rate schedule: stable then decay
-def get_lr_and_ws(step: int):
-    x = step / (1 + args.num_iterations) # progress in training
+def get_lr(step: int):
+    x = step / args.num_iterations
     assert 0 <= x < 1
     lr = 1.0
     if x >= 1 - args.cooldown_frac:
         w = (1 - x) / args.cooldown_frac
         lr = w * 1.0 + (1 - w) * 0.1
+    return lr
+
+def get_ws(step: int):
+    x = step / (1 + args.num_iterations)
+    assert 0 <= x < 1
     ws_idx = int(len(args.ws_schedule) * x)
-    return lr, args.ws_schedule[ws_idx]
+    return args.ws_schedule[ws_idx]
 
 model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)
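The split above separates the two schedules that get_lr_and_ws used to return together: get_lr holds the multiplier at 1.0 for the first 1 - cooldown_frac = 55% of training and then decays linearly toward 0.1, while get_ws steps the window size through (3, 7, 11) in thirds of training. Sample cooldown values (illustration only, with num_iterations = 1670 and cooldown_frac = 0.45 as set above):

def get_lr(step, num_iterations=1670, cooldown_frac=0.45):
    x = step / num_iterations
    if x < 1 - cooldown_frac:
        return 1.0
    w = (1 - x) / cooldown_frac
    return w * 1.0 + (1 - w) * 0.1

for s in (0, 918, 919, 1294, 1669):
    print(s, round(get_lr(s), 3))
# 0 and 918 -> 1.0 (stable phase); 919 -> ~0.999 (cooldown begins);
# 1294 -> ~0.55 (mid-cooldown); 1669 -> ~0.101 (end of training)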
@@ -957,14 +973,14 @@ def get_lr_and_ws(step: int):
 ########################################
 
 # Warmup the training kernels, then re-initialize the state so we aren't cheating
-warmup_steps = 60
+warmup_steps = 30
 initial_state = dict(model=copy.deepcopy(model.state_dict()),
                      optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
 for step in range(warmup_steps):
-    inputs, targets = next(train_loader)
-    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up eachZ
-    model(inputs, targets, ws, ws // 2).backward()
+    inputs, targets, cum_seqlens = next(train_loader)
+    ws = args.ws_schedule[step % len(args.ws_schedule)] # each window size is a new graph, need to warm up each
+    model(inputs, targets, cum_seqlens, ws).backward()
     for opt in optimizers:
         opt.step()
     model.zero_grad(set_to_none=True)
@@ -977,7 +993,7 @@ def get_lr_and_ws(step: int):
 # Training and validation #
 ########################################
 
-train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_seq_len)
+train_loader = distributed_data_generator(args.train_files, args.train_batch_size, args.train_max_seq_len, grad_accum_steps=grad_accum_steps)
 training_time_ms = 0
 # start the clock
 torch.cuda.synchronize()
@@ -986,7 +1002,7 @@ def get_lr_and_ws(step: int):
 train_steps = args.num_iterations
 for step in range(train_steps + 1):
     last_step = (step == train_steps)
-    lr, ws = get_lr_and_ws(step)
+    ws = get_ws(step)
 
     # --------------- VALIDATION SECTION -----------------
     if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
@@ -994,14 +1010,14 @@ def get_lr_and_ws(step: int):
         torch.cuda.synchronize()
         training_time_ms += 1000 * (time.perf_counter() - t0)
         model.eval()
-        assert args.val_tokens % (world_size * args.val_seq_len) == 0
-        val_steps = args.val_tokens // (world_size * args.val_seq_len)
-        val_loader = distributed_data_generator(args.val_files, world_size, args.val_seq_len, align_to_bos=False)
+        assert args.val_tokens % args.val_batch_size == 0
+        val_steps = grad_accum_steps * args.val_tokens // args.val_batch_size
+        val_loader = distributed_data_generator(args.val_files, args.val_batch_size, -1, grad_accum_steps=grad_accum_steps, align_to_bos=False)
         val_loss = 0
         with torch.no_grad():
             for _ in range(val_steps):
-                inputs, targets = next(val_loader)
-                val_loss += model(inputs, targets, ws, ws // 2)
+                inputs, targets, cum_seqlens = next(val_loader)
+                val_loss += model(inputs, targets, cum_seqlens, ws)
         val_loss /= val_steps
         del val_loader
         dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
@@ -1021,12 +1037,12 @@ def get_lr_and_ws(step: int):
     # --------------- TRAINING SECTION -----------------
     for _ in range(grad_accum_steps):
-        inputs, targets = next(train_loader)
-        model(inputs, targets, ws, ws // 2).backward()
+        inputs, targets, cum_seqlens = next(train_loader)
+        model(inputs, targets, cum_seqlens, ws).backward()
     # set optimization hyperparameters
     for opt in optimizers:
         for group in opt.param_groups:
-            group["lr"] = group["initial_lr"] * lr
+            group["lr"] = group["initial_lr"] * get_lr(step)
     for group in optimizer2.param_groups:
         frac = min(step / 300, 1) # momentum warmup for muon
         group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
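The final hunk's momentum warmup is a simple linear ramp from 0.85 to 0.95 over the first 300 steps. For reference (illustration only, not from the patch), the effective Muon momentum at a few steps:

def muon_momentum(step):
    frac = min(step / 300, 1)
    return (1 - frac) * 0.85 + frac * 0.95

print([round(muon_momentum(s), 3) for s in (0, 150, 300, 1000)])  # [0.85, 0.9, 0.95, 0.95]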