Skip to content

Implement attentive pooling for finetuning #168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 111 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
111 commits
Select commit Hold shift + click to select a range
80d3b25
add git ignore
yawenzzzz May 8, 2025
cf541d3
ruff
yawenzzzz May 8, 2025
4c41530
ruff
yawenzzzz May 8, 2025
213903f
add argument description
yawenzzzz May 11, 2025
592caa2
add the reference points
yawenzzzz May 12, 2025
58078eb
add mangrove configs
yawenzzzz May 12, 2025
5eca1d0
add configs for mangrove
yawenzzzz May 13, 2025
a183ccf
add scripts for worldcereal
yawenzzzz Apr 28, 2025
e8ebd56
remove non-cropland classes
yawenzzzz May 13, 2025
8db2eb1
add back the window save
yawenzzzz May 13, 2025
5907d75
update task name
yawenzzzz May 13, 2025
057143e
Merge branch 'master' of github.com:allenai/rslearn_projects
yawenzzzz May 13, 2025
9a49755
update configs to remove wrong sampler
yawenzzzz May 15, 2025
183f3c0
update configs to remove the wrong samplers
yawenzzzz May 15, 2025
89fb8c9
Merge branch 'master' of github.com:allenai/rslearn_projects
yawenzzzz May 15, 2025
02b975d
remove incorrect sampler
yawenzzzz May 15, 2025
96fe57b
modify create windows
yawenzzzz Jun 9, 2025
24cf8da
Merge branch 'master' of github.com:allenai/rslearn_projects
yawenzzzz Jun 9, 2025
79088ef
add docstring
yawenzzzz Jun 12, 2025
94d5e00
Use Beaker queues instead of Google Cloud Pub/Sub for the worker system.
favyen2 May 27, 2025
0349092
fix
favyen2 May 27, 2025
487f7e8
Support specifying retries on the Beaker job.
favyen2 Jun 3, 2025
5cc3ec4
Add time series and multimodal configs for marine_infra and solar_farm.
favyen2 Jun 9, 2025
79c9eb5
add pastis, sentinel1/2 vessels, and sentinel2 vessel attribute
favyen2 Jun 9, 2025
5db2307
some config fixes
favyen2 Jun 10, 2025
e81e6e6
make v2 folder
yawenzzzz Jun 12, 2025
8e7fa21
add initial config
yawenzzzz Jun 12, 2025
2c9cc68
merge conflicts
yawenzzzz Jun 13, 2025
d09f631
add kenya/nandi crop type mapping model configs
yawenzzzz Jun 13, 2025
a8ba43d
udpate readme
yawenzzzz Jun 13, 2025
4dd0268
merge conflicst
yawenzzzz Jun 14, 2025
c24850f
fix ruff
yawenzzzz Jun 14, 2025
7834449
fix ruff
yawenzzzz Jun 14, 2025
75c048f
merge conflicts
yawenzzzz Jun 15, 2025
bbd3c21
ruff
yawenzzzz Jun 15, 2025
86c1162
add the new configs
yawenzzzz Jun 16, 2025
e8b0bb5
update launcher
yawenzzzz Jun 16, 2025
454be14
update readme
yawenzzzz Jun 16, 2025
48e5535
update gitignore
yawenzzzz Jun 16, 2025
b5cfcbc
update readme for mangrove
yawenzzzz Jun 16, 2025
30422d8
update readme
yawenzzzz Jun 17, 2025
0463469
update
yawenzzzz Jun 17, 2025
4b07c8c
add configs
yawenzzzz Jun 18, 2025
e4db56c
format
yawenzzzz Jun 26, 2025
79edb97
fix files
yawenzzzz Jun 27, 2025
37b18fe
update train and val samples
yawenzzzz Jun 27, 2025
dffcc73
upload configs
yawenzzzz Jun 27, 2025
72d117c
fix bug
yawenzzzz Jun 28, 2025
f1d5336
Fix code archive dir, add profiler
ryspark Jun 30, 2025
9744f4f
Update gitignore, typo
ryspark Jun 30, 2025
f2cf76c
Update gitignore, add new dockerfile
ryspark Jul 1, 2025
7dabe0b
Add profiler, allow local finetune
ryspark Jul 1, 2025
8a7af9a
Start peft (alpa)
ryspark Jul 1, 2025
119518b
Update gitignore, allow local eval saving
ryspark Jul 1, 2025
61c8bcc
Download best finetuned models
ryspark Jul 1, 2025
98e2e8d
Delete old dockerfile, remove alpa
ryspark Jul 2, 2025
098278f
Allow sft loading and move alpa to rslearn
ryspark Jul 2, 2025
f45403c
Add apla cfg, bug fixes to script
ryspark Jul 2, 2025
1a888a5
Support pretrained cfg copy
ryspark Jul 2, 2025
8331427
Add cross finetuning script
ryspark Jul 2, 2025
f465765
Fix nandi padding size, update experiments
ryspark Jul 2, 2025
b7b9bcf
Update scripts
ryspark Jul 3, 2025
e8f436d
Start concat datasets implementation
ryspark Jul 3, 2025
758ef46
Manual merge from master
ryspark Jul 3, 2025
48b232f
Add new multidataset configs
ryspark Jul 4, 2025
5c1d793
Fix multi dataset config
ryspark Jul 4, 2025
0ee57b4
Update configs, move multidataset to rslearn
ryspark Jul 4, 2025
aa94069
Clean up cfg, allow max patch limit
ryspark Jul 6, 2025
786267b
Allow batch size override on multidataset
ryspark Jul 7, 2025
1aef684
Exclude project_data from code archive
ryspark Jul 7, 2025
c2ea6fd
Fix docker build on beaker job
ryspark Jul 7, 2025
12db8fb
Update config to allow multi file setups
ryspark Jul 7, 2025
fec1024
Remove redundant multi config key
ryspark Jul 7, 2025
d8188a4
Remove unused config
ryspark Jul 7, 2025
f94567b
Update gitignore for tmp dir
ryspark Jul 8, 2025
6b0e274
Delete beaker_launcher.py
ryspark Jul 8, 2025
904e86d
Add perf benchmark configs
ryspark Jul 8, 2025
e84b827
Add attentive pooling after encoding
ryspark Jul 9, 2025
0ccbe90
Add new attn configs
ryspark Jul 9, 2025
2b3baf4
Update cfgs, specify num_workers
ryspark Jul 9, 2025
e461f3b
Merge branch 'ryanp/singletask' into ryanp/attn
ryspark Jul 9, 2025
a8cd137
Merge branch 'master' into ryanp/singletask
ryspark Jul 9, 2025
57f6bee
Merge branch 'ryanp/singletask' into ryanp/attn
ryspark Jul 9, 2025
95fb0be
Update cfgs for attn benchmarks
ryspark Jul 9, 2025
535232f
Add mha + attentive pool cfgs
ryspark Jul 9, 2025
ffca5e8
Pool over bandset dimension too
ryspark Jul 10, 2025
a8b9494
Freeze for longer with attnpool
ryspark Jul 10, 2025
c6ef4a9
Merge branch 'master' into ryanp/attn
ryspark Jul 10, 2025
bd5a8eb
Merge branch 'master' into ryanp/singletask
ryspark Jul 10, 2025
1961d4e
Testing new dataloader speed
ryspark Jul 10, 2025
df28469
Rename max_num_workers to num_workers
ryspark Jul 11, 2025
abe9a7a
Update configs
ryspark Jul 11, 2025
b247d5f
Merge branch 'ryanp/singletask' into ryanp/attn
ryspark Jul 11, 2025
0b43e77
Merge branch 'master' into ryanp/singletask
ryspark Jul 11, 2025
d670254
Merge branch 'ryanp/singletask' into ryanp/attn
ryspark Jul 11, 2025
00f3fd0
Add docs on multidataset
ryspark Jul 11, 2025
217e479
[Breaking] change rslearn/helios paths in dockerfile
ryspark Jul 11, 2025
64bd9cd
Fix linter issues
ryspark Jul 11, 2025
d178117
Merge branch 'master' into ryanp/singletask
ryspark Jul 11, 2025
c5e689c
Remove lightning auto sampler
ryspark Jul 13, 2025
a2b92e2
Add configs, update str sub
ryspark Jul 14, 2025
dfac90a
Apply pr suggestions
ryspark Jul 14, 2025
3a1020b
Config changes
ryspark Jul 14, 2025
7cb1421
Fix linter issues
ryspark Jul 14, 2025
c2010ed
Update helios README
ryspark Jul 14, 2025
d1fb967
Update scripts, configs
ryspark Jul 15, 2025
3e862dd
Merge branch 'ryanp/singletask' into ryanp/attn
ryspark Jul 15, 2025
ff2da84
Linter, merge fixes
ryspark Jul 15, 2025
154e647
Update medium config
ryspark Jul 15, 2025
72c3661
Linter fix
ryspark Jul 15, 2025
ae3e0ec
No temporal mean in encoder
ryspark Jul 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
lightning_logs
wandb
**/test_data/**/**/*.tif
**/project_data
13 changes: 13 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
.env
wandb/
project_data
/helios/
/rslearn/
rslp/__pycache__/
**__pycache__
rslp/crop_type_mapping/csv/
rslp/crop_type_mapping/geoparquets/
rslp/mangrove/csv/
log*.txt

# for local finetuning runs
/config.yaml
lightning_logs/
/tmp*/
docker_build
140 changes: 140 additions & 0 deletions data/helios/v2_landsat_vessels/finetune_detector_attn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
model:
class_path: rslearn.train.lightning_module.RslearnLightningModule
init_args:
model:
class_path: rslearn.models.multitask.MultiTaskModel
init_args:
encoder:
- class_path: rslp.helios.model.Helios
init_args:
checkpoint_path: "{CHECKPOINT_PATH}"
selector: ["encoder"]
forward_kwargs:
patch_size: {PATCH_SIZE}
decoders:
vessel_detection:
- class_path: rslearn.models.pooling.AttentivePool
init_args:
n_channels: {ENCODER_EMBEDDING_SIZE}
height: 16
width: 16
- class_path: rslearn.models.faster_rcnn.FasterRCNN
init_args:
downsample_factors: [{PATCH_SIZE}]
num_channels: {ENCODER_EMBEDDING_SIZE}
num_classes: 2
anchor_sizes: [[32]]
lr: 0.0001
plateau: true
plateau_factor: 0.2
plateau_patience: 2
plateau_min_lr: 0
plateau_cooldown: 10
data:
class_path: rslearn.train.data_module.RslearnDataModule
init_args:
path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624
inputs:
image:
data_type: "raster"
layers: ["landsat"]
bands: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"]
passthrough: true
dtype: FLOAT32
mask:
data_type: "raster"
layers: ["mask"]
bands: ["mask"]
passthrough: true
dtype: INT32
is_target: true
targets:
data_type: "vector"
layers: ["label"]
is_target: true
task:
class_path: rslearn.train.tasks.multi_task.MultiTask
init_args:
tasks:
vessel_detection:
class_path: rslearn.train.tasks.detection.DetectionTask
init_args:
property_name: "category"
classes: ["unknown", "vessel"]
box_size: 15
remap_values: [[0, 1], [0, 255]]
exclude_by_center: true
score_threshold: 0.7
enable_map_metric: true
enable_f1_metric: true
f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]]
f1_metric_kwargs:
cmp_mode: "distance"
cmp_threshold: 15
flatten_classes: true
input_mapping:
vessel_detection:
targets: "targets"
batch_size: 8
num_workers: 16
default_config:
transforms:
- class_path: rslp.transforms.mask.Mask
- class_path: rslearn.train.transforms.concatenate.Concatenate
init_args:
selections:
image: []
output_selector: landsat
- class_path: rslp.helios.norm.HeliosNormalize
init_args:
config_fname: "/opt/helios/data/norm_configs/computed.json"
band_names:
landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"]
train_config:
patch_size: 128
load_all_patches: true
groups: ["labels_utm"]
tags:
split: "train"
transforms:
- class_path: rslp.transforms.mask.Mask
- class_path: rslearn.train.transforms.concatenate.Concatenate
init_args:
selections:
image: []
output_selector: landsat
- class_path: rslp.helios.norm.HeliosNormalize
init_args:
config_fname: "/opt/helios/data/norm_configs/computed.json"
band_names:
landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"]
val_config:
patch_size: 128
load_all_patches: true
groups: ["labels_utm"]
tags:
split: "val"
test_config:
patch_size: 128
load_all_patches: true
groups: ["labels_utm"]
tags:
split: "val"
trainer:
max_epochs: 200
callbacks:
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: "epoch"
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
save_top_k: 1
save_last: true
monitor: val_vessel_detection/mAP
mode: max
- class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
init_args:
module_selector: ["model", "encoder", 0]
unfreeze_at_epoch: 2
rslp_project: placeholder
rslp_experiment: placeholder
131 changes: 131 additions & 0 deletions data/helios/v2_landsat_vessels/finetune_detector_nofreeze.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
model:
class_path: rslearn.train.lightning_module.RslearnLightningModule
init_args:
model:
class_path: rslearn.models.multitask.MultiTaskModel
init_args:
encoder:
- class_path: rslp.helios.model.Helios
init_args:
checkpoint_path: "{CHECKPOINT_PATH}"
selector: ["encoder"]
forward_kwargs:
patch_size: {PATCH_SIZE}
decoders:
vessel_detection:
- class_path: rslearn.models.faster_rcnn.FasterRCNN
init_args:
downsample_factors: [{PATCH_SIZE}]
num_channels: {ENCODER_EMBEDDING_SIZE}
num_classes: 2
anchor_sizes: [[32]]
lr: 0.0001
plateau: true
plateau_factor: 0.2
plateau_patience: 2
plateau_min_lr: 0
plateau_cooldown: 10
data:
class_path: rslearn.train.data_module.RslearnDataModule
init_args:
path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624
inputs:
image:
data_type: "raster"
layers: ["landsat"]
bands: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"]
passthrough: true
dtype: FLOAT32
mask:
data_type: "raster"
layers: ["mask"]
bands: ["mask"]
passthrough: true
dtype: INT32
is_target: true
targets:
data_type: "vector"
layers: ["label"]
is_target: true
task:
class_path: rslearn.train.tasks.multi_task.MultiTask
init_args:
tasks:
vessel_detection:
class_path: rslearn.train.tasks.detection.DetectionTask
init_args:
property_name: "category"
classes: ["unknown", "vessel"]
box_size: 15
remap_values: [[0, 1], [0, 255]]
exclude_by_center: true
score_threshold: 0.7
enable_map_metric: true
enable_f1_metric: true
f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]]
f1_metric_kwargs:
cmp_mode: "distance"
cmp_threshold: 15
flatten_classes: true
input_mapping:
vessel_detection:
targets: "targets"
batch_size: 8
num_workers: 16
default_config:
transforms:
- class_path: rslp.transforms.mask.Mask
- class_path: rslearn.train.transforms.concatenate.Concatenate
init_args:
selections:
image: []
output_selector: landsat
- class_path: rslp.helios.norm.HeliosNormalize
init_args:
config_fname: "/opt/helios/data/norm_configs/computed.json"
band_names:
landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"]
train_config:
patch_size: 128
load_all_patches: true
groups: ["labels_utm"]
tags:
split: "train"
transforms:
- class_path: rslp.transforms.mask.Mask
- class_path: rslearn.train.transforms.concatenate.Concatenate
init_args:
selections:
image: []
output_selector: landsat
- class_path: rslp.helios.norm.HeliosNormalize
init_args:
config_fname: "/opt/helios/data/norm_configs/computed.json"
band_names:
landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"]
val_config:
patch_size: 128
load_all_patches: true
groups: ["labels_utm"]
tags:
split: "val"
test_config:
patch_size: 128
load_all_patches: true
groups: ["labels_utm"]
tags:
split: "val"
trainer:
max_epochs: 200
callbacks:
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: "epoch"
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
save_top_k: 1
save_last: true
monitor: val_vessel_detection/mAP
mode: max
rslp_project: placeholder
rslp_experiment: placeholder
Loading
Loading