Draft
44 commits
e9f79b1
Initial sketch of DGN example
rickybalin Nov 10, 2025
e0907bf
Progress adding sampler and forward diffusion process
rickybalin Nov 10, 2025
87385f5
Update dist-gnn model to make the dgn model
rickybalin Nov 11, 2025
3e6401f
Fix batching issue
rickybalin Dec 1, 2025
78abec8
Sketched first part of inference pipeline for dist-dgn model
rickybalin Dec 2, 2025
afb2eda
Add ext_cyl example and progress on debugging dgn model
rickybalin Dec 17, 2025
16ceb50
Fix runtime bug during inference
rickybalin Jan 5, 2026
1aeb6e2
Add batching with respect to the noise step r
rickybalin Jan 8, 2026
35de73c
Add hybrid loss
rickybalin Jan 14, 2026
c0e6de8
Add exponential sampler for diffusion steps
rickybalin Jan 20, 2026
3949f52
Smaller updates to run scripts and post-processing
rickybalin Jan 26, 2026
d58e64b
Fix bug with loss computation, testing new step embeddings
rickybalin Jan 28, 2026
487ec57
Minor change
rickybalin Jan 28, 2026
fc966da
Add SNR loss scaling and option for x0 prediction
rickybalin Feb 10, 2026
7dc1e86
Add activation checkpointing to dist-dgn model
rickybalin Feb 10, 2026
4843de5
Add v prediction for the dist-dgn model
rickybalin Feb 10, 2026
e99e6d6
Fixed some bugs in dist-dgn model
rickybalin Feb 11, 2026
41cfdc9
Small fixes to trainer and postprocessing of dist-dgn model
rickybalin Feb 11, 2026
a332b9a
[WIP] Add conditional node features (d2wall and d2inflow) to gnn plug…
rickybalin Feb 16, 2026
f6966ea
Big fixes for conditional node features
rickybalin Feb 17, 2026
6437e6a
Modify plotting function in postprocess to plot contour instead of sc…
rickybalin Feb 18, 2026
7384800
(WIP) Add consistent MSE loss to dist-dgn model
rickybalin Feb 19, 2026
dfb45d7
(WIP) Debugging consistent MSE loss for dist-dgn model
rickybalin Feb 21, 2026
fbd263d
(WIP) Debugging consistency in dist-dgn model
rickybalin Feb 23, 2026
4fca38e
(WIP) Fix size=1 bug in dist-gnn model and more debugging of dist-dgn…
rickybalin Feb 24, 2026
b9b5178
Add coarse mesh for ext_cyl_dgn case
rickybalin Mar 9, 2026
2c412dc
Add coarse mesh for ext_cyl_dgn example and some minor fixes to dist-…
rickybalin Mar 10, 2026
65b079a
Add y coordinate as possible conditional feature
rickybalin Mar 10, 2026
1d3d9ff
Small updates
rickybalin Mar 12, 2026
d02d05f
Fix random seed issues
rickybalin Mar 16, 2026
ca6af9c
Remove some comments
rickybalin Mar 16, 2026
bfd572e
(WIP) Fix consistency in field_r
rickybalin Mar 16, 2026
9226cb6
Update nrsrun_aurora to run with a single rank for now
rickybalin Apr 13, 2026
5d6e292
Merge with main
rickybalin May 21, 2026
f88b3a5
Bring in consistency fixes from dist-gnn model
rickybalin May 21, 2026
711d1bf
Add halo sync for noise added to fields
rickybalin May 22, 2026
5c6b9de
Merge with remote
rickybalin May 22, 2026
9d9505a
Format code with ruff
rickybalin May 22, 2026
5bd04e1
Obtained consistency for dist-dgn model
rickybalin May 28, 2026
7e7679b
Format code with ruff
rickybalin May 28, 2026
e9808e6
Merge branch 'main' into dgn
rickybalin Jun 3, 2026
7145591
Format code with ruff
rickybalin Jun 3, 2026
f06c53a
Add elementwise transformer to DGN diffusion model (#75)
rickybalin Jun 4, 2026
e668312
Format code with ruff
rickybalin Jun 4, 2026
16 changes: 16 additions & 0 deletions 3rd_party/gnn/dist-dgn/client.py
@@ -0,0 +1,16 @@
import os
import sys
from typing import Optional, Union, Tuple
import logging
from omegaconf import DictConfig
import numpy as np
from time import sleep, perf_counter

log = logging.getLogger(__name__)


class OnlineClient:
    """Class for the online training client"""

    def __init__(self) -> None:
        self.client = None
95 changes: 95 additions & 0 deletions 3rd_party/gnn/dist-dgn/conf/config.yaml
@@ -0,0 +1,95 @@
# @package _global_
verbose: False
timers: False
postprocess: False
statistics: False
seed: 12
backend: null
num_threads: 1
logfreq: 10
ckptfreq: 500
batch_size: 1
val_batch_size: 1
precision: fp32
fp16_allreduce: False
restart: False
master_addr: none
master_port: 2345
device_skip: 0 # temporary workaround to skip GPU on a node

# DGN specific parameters
prediction_type: "epsilon" # {epsilon, x0, v}
num_diffusion_steps: 100
emb_width: 128
diffusion_process_schedule: "linear" # {linear, cosine}
diffusion_step_sampler: "uniform" # {uniform, adaptive_exponential}
num_gen_samples: 1
input_node_features: 3
learnable_variance: False
loss_weighting: "uniform" # {uniform, min_snr}
min_snr_gamma: 5.0 # gamma for min-SNR weighting (only used when loss_weighting=min_snr)
activation_checkpointing: False
cond_node_features: False

# Modeling task
model_task: "train"
consistency: True

# learning rate schedule and training steps
phase1_steps: 100
phase2_steps: 0
phase3_steps: 0
lr_phase12: 0.0001
lr_phase23: 0.0001

# model arch properties
model_name: "gnn" # {gnn, graph_transformer}
mlp_hidden_channels: 32
n_mlp_hidden_layers: 5
n_messagePassing_layers: 4
n_transformer_layers: 4
num_heads: 4
layer_norm: False
dropout_rate: 0.0

# Halo swap mode
halo_swap_mode : none

# transform directions (enforces periodicity in graph)
transform_x: false
transform_y: false
transform_z: false

# Correctness validation
target_loss: 0

# plotting the connectivity (can take some time)
plot_connectivity : False

# specify path to gnn_outputs here (case directory)
gnn_outputs_path: ${work_dir}/gnn_outputs_poly_7/
traj_data_path: ${work_dir}/traj_poly_1/

# Online training
online: False

hydra:
  job:
    chdir: false

#defaults:
# - override hydra/hydra_logging: colorlog
# - override hydra/job_logging: colorlog

# path to original working directory
# hydra hijacks the working directory by changing it to the new log directory,
# so it's useful to have this path as a special variable
# https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
work_dir: ${hydra:runtime.cwd}
# path to folder with data
data_dir: ${work_dir}/datasets/
# path to folder for checkpointing
ckpt_dir: ${work_dir}/ckpt/
# path to saved model directory
model_dir: ${work_dir}/saved_models/
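
To illustrate how the DGN options `num_diffusion_steps`, `diffusion_process_schedule`, `loss_weighting`, `min_snr_gamma`, and `prediction_type` in `config.yaml` might interact, here is a minimal pure-Python sketch. It is not taken from this PR: the function names, the beta endpoints (`1e-4`, `0.02`), and the cosine offset `s=0.008` are assumptions borrowed from common DDPM practice, and the min-SNR weights follow the usual published form for each prediction target.

```python
import math


def linear_betas(num_steps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances; the endpoints are assumed defaults."""
    if num_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cosine_betas(num_steps, s=0.008, max_beta=0.999):
    """Betas derived from a squared-cosine curve for the cumulative alpha."""

    def alpha_bar(t):
        return math.cos((t + s) / (1.0 + s) * math.pi / 2.0) ** 2

    betas = []
    for i in range(num_steps):
        t1 = i / num_steps
        t2 = (i + 1) / num_steps
        # Clip so the last step does not remove the signal entirely.
        betas.append(min(1.0 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return betas


def cumulative_alphas(betas):
    """Running product of (1 - beta), i.e. the signal fraction per step."""
    alpha_bars, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def min_snr_weights(alpha_bars, gamma=5.0, prediction_type="epsilon"):
    """Per-step loss weights with the signal-to-noise ratio clipped at gamma."""
    weights = []
    for ab in alpha_bars:
        snr = ab / (1.0 - ab)
        clipped = min(snr, gamma)
        if prediction_type == "epsilon":
            weights.append(clipped / snr)
        elif prediction_type == "x0":
            weights.append(clipped)
        elif prediction_type == "v":
            weights.append(clipped / (snr + 1.0))
        else:
            raise ValueError(f"unknown prediction_type: {prediction_type}")
    return weights


# Mirror the defaults in config.yaml: 100 steps, linear schedule, gamma = 5.
schedules = {"linear": linear_betas, "cosine": cosine_betas}
alpha_bars = cumulative_alphas(schedules["linear"](100))
weights = min_snr_weights(alpha_bars, gamma=5.0, prediction_type="epsilon")
```

With `loss_weighting: "uniform"` every step would simply get weight 1. Under the assumed min-SNR form, low-noise steps (large SNR) are down-weighted for epsilon prediction, while steps whose SNR is already below `min_snr_gamma` keep weight 1.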
