diff --git a/.dockerignore b/.dockerignore index 835a0a55..bc10c831 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,3 +5,4 @@ lightning_logs wandb **/test_data/**/**/*.tif +**/project_data diff --git a/.gitignore b/.gitignore index d058727b..485b206f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,17 @@ .env wandb/ +project_data +/helios/ +/rslearn/ rslp/__pycache__/ **__pycache__ +rslp/crop_type_mapping/csv/ +rslp/crop_type_mapping/geoparquets/ +rslp/mangrove/csv/ +log*.txt + +# for local finetuning runs +/config.yaml +lightning_logs/ +/tmp*/ +docker_build diff --git a/data/helios/v2_landsat_vessels/finetune_detector_attn.yaml b/data/helios/v2_landsat_vessels/finetune_detector_attn.yaml new file mode 100644 index 00000000..0ded7669 --- /dev/null +++ b/data/helios/v2_landsat_vessels/finetune_detector_attn.yaml @@ -0,0 +1,140 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: "{CHECKPOINT_PATH}" + selector: ["encoder"] + forward_kwargs: + patch_size: {PATCH_SIZE} + decoders: + vessel_detection: + - class_path: rslearn.models.pooling.AttentivePool + init_args: + n_channels: {ENCODER_EMBEDDING_SIZE} + height: 16 + width: 16 + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [{PATCH_SIZE}] + num_channels: {ENCODER_EMBEDDING_SIZE} + num_classes: 2 + anchor_sizes: [[32]] + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + inputs: + image: + data_type: "raster" + layers: ["landsat"] + bands: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + passthrough: true + dtype: 
FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "vessel"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + score_threshold: 0.7 + enable_map_metric: true + enable_f1_metric: true + f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]] + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + vessel_detection: + targets: "targets" + batch_size: 8 + num_workers: 16 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image: [] + output_selector: landsat + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + train_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "train" + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image: [] + output_selector: landsat + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + val_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "val" + test_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "val" 
+trainer: + max_epochs: 200 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_vessel_detection/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 2 +rslp_project: placeholder +rslp_experiment: placeholder diff --git a/data/helios/v2_landsat_vessels/finetune_detector_nofreeze.yaml b/data/helios/v2_landsat_vessels/finetune_detector_nofreeze.yaml new file mode 100644 index 00000000..6cd1a633 --- /dev/null +++ b/data/helios/v2_landsat_vessels/finetune_detector_nofreeze.yaml @@ -0,0 +1,131 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: "{CHECKPOINT_PATH}" + selector: ["encoder"] + forward_kwargs: + patch_size: {PATCH_SIZE} + decoders: + vessel_detection: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [{PATCH_SIZE}] + num_channels: {ENCODER_EMBEDDING_SIZE} + num_classes: 2 + anchor_sizes: [[32]] + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + inputs: + image: + data_type: "raster" + layers: ["landsat"] + bands: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + 
is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "vessel"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + score_threshold: 0.7 + enable_map_metric: true + enable_f1_metric: true + f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]] + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + vessel_detection: + targets: "targets" + batch_size: 8 + num_workers: 16 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image: [] + output_selector: landsat + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + train_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "train" + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image: [] + output_selector: landsat + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + val_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "val" + test_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "val" +trainer: + max_epochs: 200 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: 
lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_vessel_detection/mAP + mode: max +rslp_project: placeholder +rslp_experiment: placeholder diff --git a/data/helios/v2_landsat_vessels/finetune_detector_old.yaml b/data/helios/v2_landsat_vessels/finetune_detector_old.yaml new file mode 100644 index 00000000..6cd1a633 --- /dev/null +++ b/data/helios/v2_landsat_vessels/finetune_detector_old.yaml @@ -0,0 +1,131 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: "{CHECKPOINT_PATH}" + selector: ["encoder"] + forward_kwargs: + patch_size: {PATCH_SIZE} + decoders: + vessel_detection: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [{PATCH_SIZE}] + num_channels: {ENCODER_EMBEDDING_SIZE} + num_classes: 2 + anchor_sizes: [[32]] + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + inputs: + image: + data_type: "raster" + layers: ["landsat"] + bands: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "vessel"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true 
+ score_threshold: 0.7 + enable_map_metric: true + enable_f1_metric: true + f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]] + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + vessel_detection: + targets: "targets" + batch_size: 8 + num_workers: 16 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image: [] + output_selector: landsat + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + train_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "train" + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image: [] + output_selector: landsat + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + val_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "val" + test_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "val" +trainer: + max_epochs: 200 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_vessel_detection/mAP + mode: max +rslp_project: placeholder +rslp_experiment: placeholder diff --git a/data/helios/v2_landsat_vessels/finetune_detector_startv1.yaml b/data/helios/v2_landsat_vessels/finetune_detector_startv1.yaml new file mode 
100644 index 00000000..b7649786 --- /dev/null +++ b/data/helios/v2_landsat_vessels/finetune_detector_startv1.yaml @@ -0,0 +1,136 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + checkpoint_path: /weka/dfive-default/ryanp/rslearn_projects/project_data/projects/helios_concat_finetune/crop_classify_s1s2__vessel_detect_OLD_SINGLE/checkpoints/last.ckpt + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: "{CHECKPOINT_PATH}" + selector: ["encoder"] + forward_kwargs: + patch_size: {PATCH_SIZE} + decoders: + vessel_detection: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [{PATCH_SIZE}] + num_channels: {ENCODER_EMBEDDING_SIZE} + num_classes: 2 + anchor_sizes: [[32]] + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + inputs: + image: + data_type: "raster" + layers: ["landsat"] + bands: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "vessel"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + score_threshold: 0.7 + enable_map_metric: true + enable_f1_metric: true + f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]] + 
f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + vessel_detection: + targets: "targets" + batch_size: 8 + num_workers: 16 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image: [] + output_selector: landsat + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + train_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "train" + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image: [] + output_selector: landsat + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + val_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "val" + test_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "val" +trainer: + max_epochs: 200 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_vessel_detection/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 2 +rslp_project: placeholder +rslp_experiment: placeholder diff --git a/data/helios/v2_landsat_vessels/finetune_detector_startv1_constant_lr.yaml b/data/helios/v2_landsat_vessels/finetune_detector_startv1_constant_lr.yaml new 
file mode 100644 index 00000000..5cc805fe --- /dev/null +++ b/data/helios/v2_landsat_vessels/finetune_detector_startv1_constant_lr.yaml @@ -0,0 +1,132 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + checkpoint_path: /weka/dfive-default/ryanp/rslearn_projects/project_data/projects/helios_concat_finetune/crop_classify_s1s2__vessel_detect_OLD_SINGLE/checkpoints/last.ckpt + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: "{CHECKPOINT_PATH}" + selector: ["encoder"] + forward_kwargs: + patch_size: {PATCH_SIZE} + decoders: + vessel_detection: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [{PATCH_SIZE}] + num_channels: {ENCODER_EMBEDDING_SIZE} + num_classes: 2 + anchor_sizes: [[32]] + lr: 0.00005 + plateau: false + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + inputs: + image: + data_type: "raster" + layers: ["landsat"] + bands: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "vessel"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + score_threshold: 0.7 + enable_map_metric: true + enable_f1_metric: true + f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 
0.7, 0.8, 0.9, 0.95]] + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + vessel_detection: + targets: "targets" + batch_size: 8 + num_workers: 16 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image: [] + output_selector: landsat + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + train_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "train" + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image: [] + output_selector: landsat + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + landsat: ["B8", "B1", "B2", "B3", "B4", "B5", "B6", "B7", "B9", "B10", "B11"] + val_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "val" + test_config: + patch_size: 128 + load_all_patches: true + groups: ["labels_utm"] + tags: + split: "val" +trainer: + max_epochs: 200 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_vessel_detection/mAP + mode: max +rslp_project: placeholder +rslp_experiment: placeholder diff --git a/data/helios/v2_nandi_crop_type/finetune_s1_s2.yaml b/data/helios/v2_nandi_crop_type/finetune_s1_s2.yaml index f626b6f7..feedff2d 100644 --- a/data/helios/v2_nandi_crop_type/finetune_s1_s2.yaml +++ b/data/helios/v2_nandi_crop_type/finetune_s1_s2.yaml @@ -192,8 +192,8 @@ data: 
input_mapping: crop_type_classification: label: "targets" - batch_size: 8 - num_workers: 32 + batch_size: 32 + num_workers: 16 default_config: transforms: - class_path: rslearn.train.transforms.concatenate.Concatenate diff --git a/data/helios/v2_nandi_crop_type/finetune_s1_s2_attn.yaml b/data/helios/v2_nandi_crop_type/finetune_s1_s2_attn.yaml new file mode 100644 index 00000000..a6404a24 --- /dev/null +++ b/data/helios/v2_nandi_crop_type/finetune_s1_s2_attn.yaml @@ -0,0 +1,277 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: "{CHECKPOINT_PATH}" + selector: ["encoder"] + forward_kwargs: + patch_size: {PATCH_SIZE} + decoders: + crop_type_classification: + - class_path: rslearn.models.pooling.AttentivePool + init_args: + n_channels: {ENCODER_EMBEDDING_SIZE} + height: 1 + width: 1 + n_bandsets: 4 + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: {ENCODER_EMBEDDING_SIZE} + out_channels: 8 + - class_path: rslearn.train.tasks.classification.ClassificationHead + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/crop/kenya_nandi/20250625 + inputs: + sentinel2_0: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_1: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_2: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", 
"B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_3: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_4: + data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_5: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_6: + data_type: "raster" + layers: ["sentinel2.6"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_7: + data_type: "raster" + layers: ["sentinel2.7"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_8: + data_type: "raster" + layers: ["sentinel2.8"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_9: + data_type: "raster" + layers: ["sentinel2.9"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_10: + data_type: "raster" + layers: ["sentinel2.10"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_11: + data_type: "raster" + layers: ["sentinel2.11"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel1_0: + data_type: "raster" + layers: ["sentinel1_ascending"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_1: + data_type: "raster" + layers: 
["sentinel1_ascending.1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_2: + data_type: "raster" + layers: ["sentinel1_ascending.2"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_3: + data_type: "raster" + layers: ["sentinel1_ascending.3"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_4: + data_type: "raster" + layers: ["sentinel1_ascending.4"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_5: + data_type: "raster" + layers: ["sentinel1_ascending.5"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_6: + data_type: "raster" + layers: ["sentinel1_ascending.6"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_7: + data_type: "raster" + layers: ["sentinel1_ascending.7"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_8: + data_type: "raster" + layers: ["sentinel1_ascending.8"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_9: + data_type: "raster" + layers: ["sentinel1_ascending.9"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_10: + data_type: "raster" + layers: ["sentinel1_ascending.10"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_11: + data_type: "raster" + layers: ["sentinel1_ascending.11"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + label: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + crop_type_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "category" + classes: ["Coffee", "Trees", "Grassland", "Maize", "Sugarcane", "Tea", "Water", "Built-up"] + enable_f1_metric: true + metric_kwargs: + average: "micro" + input_mapping: + crop_type_classification: + label: "targets" + batch_size: 32 + num_workers: 16 + default_config: + transforms: + - 
class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + sentinel2_0: [] + sentinel2_1: [] + sentinel2_2: [] + sentinel2_3: [] + sentinel2_4: [] + sentinel2_5: [] + sentinel2_6: [] + sentinel2_7: [] + sentinel2_8: [] + sentinel2_9: [] + sentinel2_10: [] + sentinel2_11: [] + output_selector: sentinel2_l2a + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + sentinel1_0: [] + sentinel1_1: [] + sentinel1_2: [] + sentinel1_3: [] + sentinel1_4: [] + sentinel1_5: [] + sentinel1_6: [] + sentinel1_7: [] + sentinel1_8: [] + sentinel1_9: [] + sentinel1_10: [] + sentinel1_11: [] + output_selector: sentinel1 + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + sentinel1: ["vv", "vh"] + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 8 + mode: "center" + image_selectors: ["sentinel2_l2a", "sentinel1"] + train_config: + groups: ["groundtruth_polygon_split_window_32", "worldcover_window_32"] + tags: + split: "train" + val_config: + groups: ["groundtruth_polygon_split_window_32", "worldcover_window_32"] + tags: + split: "val" + test_config: + groups: ["groundtruth_polygon_split_window_32", "worldcover_window_32"] + tags: + split: "val" +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 4 +rslp_project: placeholder +rslp_experiment: placeholder diff --git a/data/helios/v2_pastis/basecfg_helios_mm.yaml 
b/data/helios/v2_pastis/basecfg_helios_mm.yaml index 9b0d0fa2..c6020c8c 100644 --- a/data/helios/v2_pastis/basecfg_helios_mm.yaml +++ b/data/helios/v2_pastis/basecfg_helios_mm.yaml @@ -146,7 +146,7 @@ data: layers: ["label"] bands: ["class"] is_target: true - batch_size: 2 + batch_size: 16 default_config: transforms: - class_path: rslearn.train.transforms.concatenate.Concatenate diff --git a/data/helios/v2_pastis/basecfg_helios_mm_attn.yaml b/data/helios/v2_pastis/basecfg_helios_mm_attn.yaml new file mode 100644 index 00000000..cdd9d32e --- /dev/null +++ b/data/helios/v2_pastis/basecfg_helios_mm_attn.yaml @@ -0,0 +1,248 @@ +model: + init_args: + model: + init_args: + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: "{CHECKPOINT_PATH}" + selector: ["encoder"] + forward_kwargs: + patch_size: {PATCH_SIZE} + decoders: + segment: + - class_path: rslearn.models.pooling.AttentivePool + init_args: + n_channels: {ENCODER_EMBEDDING_SIZE} + height: 16 + width: 16 + n_bandsets: 4 + - class_path: rslearn.models.unet.UNetDecoder + init_args: + in_channels: [[{PATCH_SIZE}, {ENCODER_EMBEDDING_SIZE}]] + out_channels: 20 + conv_layers_per_resolution: 2 + num_channels: {8: 512, 4: 512, 2: 256, 1: 128} + - class_path: rslearn.train.tasks.segmentation.SegmentationHead +data: + init_args: + inputs: + sentinel2_0: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"] + passthrough: true + sentinel2_1: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"] + passthrough: true + sentinel2_2: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"] + passthrough: true + sentinel2_3: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"] + passthrough: true + sentinel2_4: + 
data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"] + passthrough: true + sentinel2_5: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"] + passthrough: true + sentinel2_6: + data_type: "raster" + layers: ["sentinel2.6"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"] + passthrough: true + sentinel2_7: + data_type: "raster" + layers: ["sentinel2.7"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"] + passthrough: true + sentinel2_8: + data_type: "raster" + layers: ["sentinel2.8"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"] + passthrough: true + sentinel2_9: + data_type: "raster" + layers: ["sentinel2.9"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"] + passthrough: true + sentinel2_10: + data_type: "raster" + layers: ["sentinel2.10"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"] + passthrough: true + sentinel2_11: + data_type: "raster" + layers: ["sentinel2.11"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12"] + passthrough: true + s1d_0: + data_type: "raster" + layers: ["s1d"] + bands: ["vv", "vh"] + passthrough: true + s1d_1: + data_type: "raster" + layers: ["s1d.1"] + bands: ["vv", "vh"] + passthrough: true + s1d_2: + data_type: "raster" + layers: ["s1d.2"] + bands: ["vv", "vh"] + passthrough: true + s1d_3: + data_type: "raster" + layers: ["s1d.3"] + bands: ["vv", "vh"] + passthrough: true + s1d_4: + data_type: "raster" + layers: ["s1d.4"] + bands: ["vv", "vh"] + passthrough: true + s1d_5: + data_type: "raster" + layers: ["s1d.5"] + bands: ["vv", "vh"] + passthrough: true + s1d_6: + data_type: "raster" + layers: ["s1d.6"] + bands: ["vv", "vh"] + passthrough: true + s1d_7: + data_type: "raster" + layers: ["s1d.7"] + bands: 
["vv", "vh"] + passthrough: true + s1d_8: + data_type: "raster" + layers: ["s1d.8"] + bands: ["vv", "vh"] + passthrough: true + s1d_9: + data_type: "raster" + layers: ["s1d.9"] + bands: ["vv", "vh"] + passthrough: true + s1d_10: + data_type: "raster" + layers: ["s1d.10"] + bands: ["vv", "vh"] + passthrough: true + s1d_11: + data_type: "raster" + layers: ["s1d.11"] + bands: ["vv", "vh"] + passthrough: true + targets: + data_type: "raster" + layers: ["label"] + bands: ["class"] + is_target: true + batch_size: 16 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + # PASTIS is missing B01 and B09. + # We use B02 to fill in B01 and B8A to fill in B09. + sentinel2_0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_2: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_3: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_4: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_5: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_6: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_7: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_8: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_9: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_10: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_11: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + output_selector: sentinel2_l2a + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + s1d_0: [] + s1d_1: [] + s1d_2: [] + s1d_3: [] + s1d_4: [] + s1d_5: [] + s1d_6: [] + s1d_7: [] + s1d_8: [] + s1d_9: [] + s1d_10: [] + s1d_11: [] + output_selector: sentinel1 + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + sentinel1: ["vv", "vh"] + train_config: + transforms: + - class_path: 
rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + # PASTIS is missing B01 and B09. + # We use B02 to fill in B01 and B8A to fill in B09. + sentinel2_0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_2: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_3: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_4: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_5: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_6: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_7: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_8: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_9: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_10: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + sentinel2_11: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7] + output_selector: sentinel2_l2a + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + s1d_0: [] + s1d_1: [] + s1d_2: [] + s1d_3: [] + s1d_4: [] + s1d_5: [] + s1d_6: [] + s1d_7: [] + s1d_8: [] + s1d_9: [] + s1d_10: [] + s1d_11: [] + output_selector: sentinel1 + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + sentinel1: ["vv", "vh"] + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["sentinel2_l2a", "sentinel1", "target/segment/classes", "target/segment/valid"] +trainer: + callbacks+: + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 4 diff --git a/data/helios/v2_shared/helios_apla.yaml b/data/helios/v2_shared/helios_apla.yaml new file mode 100644 index 00000000..46a46e08 --- /dev/null +++ b/data/helios/v2_shared/helios_apla.yaml @@ -0,0 +1,5 @@ +trainer: + callbacks+: + - class_path: rslearn.train.callbacks.peft.APLA + init_args: + 
r: 8 diff --git a/data/helios/v2_worldcereal_cropland/finetune_s1_s2.yaml b/data/helios/v2_worldcereal_cropland/finetune_s1_s2.yaml index b5811a80..4a3dca8e 100644 --- a/data/helios/v2_worldcereal_cropland/finetune_s1_s2.yaml +++ b/data/helios/v2_worldcereal_cropland/finetune_s1_s2.yaml @@ -192,8 +192,8 @@ data: input_mapping: cropland_classification: label: "targets" - batch_size: 16 - num_workers: 32 + batch_size: 32 + num_workers: 16 default_config: transforms: - class_path: rslearn.train.transforms.concatenate.Concatenate @@ -236,7 +236,7 @@ data: sentinel1: ["vv", "vh"] - class_path: rslearn.train.transforms.pad.Pad init_args: - size: 2 + size: 8 mode: "center" image_selectors: ["sentinel2_l2a", "sentinel1"] train_config: diff --git a/data/helios/v2_worldcereal_cropland/finetune_s1_s2_attn.yaml b/data/helios/v2_worldcereal_cropland/finetune_s1_s2_attn.yaml new file mode 100644 index 00000000..96630a6f --- /dev/null +++ b/data/helios/v2_worldcereal_cropland/finetune_s1_s2_attn.yaml @@ -0,0 +1,277 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: "{CHECKPOINT_PATH}" + selector: ["encoder"] + forward_kwargs: + patch_size: {PATCH_SIZE} + decoders: + cropland_classification: + - class_path: rslearn.models.pooling.AttentivePool + init_args: + n_channels: {ENCODER_EMBEDDING_SIZE} + height: 1 + width: 1 + n_bandsets: 4 + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: {ENCODER_EMBEDDING_SIZE} + out_channels: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: 
/weka/dfive-default/rslearn-eai/datasets/crop/worldcereal_cropland/20250626 + inputs: + sentinel2_0: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_1: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_2: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_3: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_4: + data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_5: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_6: + data_type: "raster" + layers: ["sentinel2.6"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_7: + data_type: "raster" + layers: ["sentinel2.7"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_8: + data_type: "raster" + layers: ["sentinel2.8"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_9: + data_type: "raster" + layers: ["sentinel2.9"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + 
passthrough: true + dtype: FLOAT32 + sentinel2_10: + data_type: "raster" + layers: ["sentinel2.10"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_11: + data_type: "raster" + layers: ["sentinel2.11"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel1_0: + data_type: "raster" + layers: ["sentinel1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_1: + data_type: "raster" + layers: ["sentinel1.1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_2: + data_type: "raster" + layers: ["sentinel1.2"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_3: + data_type: "raster" + layers: ["sentinel1.3"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_4: + data_type: "raster" + layers: ["sentinel1.4"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_5: + data_type: "raster" + layers: ["sentinel1.5"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_6: + data_type: "raster" + layers: ["sentinel1.6"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_7: + data_type: "raster" + layers: ["sentinel1.7"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_8: + data_type: "raster" + layers: ["sentinel1.8"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_9: + data_type: "raster" + layers: ["sentinel1.9"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_10: + data_type: "raster" + layers: ["sentinel1.10"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_11: + data_type: "raster" + layers: ["sentinel1.11"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + label: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: 
rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + cropland_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "category" + classes: ["Cropland", "Non-Cropland"] + enable_f1_metric: true + metric_kwargs: + average: "micro" + input_mapping: + cropland_classification: + label: "targets" + batch_size: 32 + num_workers: 16 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + sentinel2_0: [] + sentinel2_1: [] + sentinel2_2: [] + sentinel2_3: [] + sentinel2_4: [] + sentinel2_5: [] + sentinel2_6: [] + sentinel2_7: [] + sentinel2_8: [] + sentinel2_9: [] + sentinel2_10: [] + sentinel2_11: [] + output_selector: sentinel2_l2a + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + sentinel1_0: [] + sentinel1_1: [] + sentinel1_2: [] + sentinel1_3: [] + sentinel1_4: [] + sentinel1_5: [] + sentinel1_6: [] + sentinel1_7: [] + sentinel1_8: [] + sentinel1_9: [] + sentinel1_10: [] + sentinel1_11: [] + output_selector: sentinel1 + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + sentinel1: ["vv", "vh"] + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 8 + mode: "center" + image_selectors: ["sentinel2_l2a", "sentinel1"] + train_config: + groups: ["h3_sample100_66K"] + tags: + split: "train" + val_config: + groups: ["h3_sample100_66K"] + tags: + split: "val" + test_config: + groups: ["h3_sample100_66K"] + tags: + split: "val" +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss 
+ mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 4 +rslp_project: placeholder +rslp_experiment: placeholder diff --git a/data/helios/v2_worldcereal_cropland/finetune_s1_s2_startv1.yaml b/data/helios/v2_worldcereal_cropland/finetune_s1_s2_startv1.yaml new file mode 100644 index 00000000..4514be49 --- /dev/null +++ b/data/helios/v2_worldcereal_cropland/finetune_s1_s2_startv1.yaml @@ -0,0 +1,272 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + checkpoint_path: /weka/dfive-default/ryanp/rslearn_projects/project_data/projects/helios_concat_finetune/crop_classify_s1s2__vessel_detect_OLD_SINGLE/checkpoints/last.ckpt + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: "{CHECKPOINT_PATH}" + selector: ["encoder"] + forward_kwargs: + patch_size: {PATCH_SIZE} + decoders: + cropland_classification: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: {ENCODER_EMBEDDING_SIZE} + out_channels: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/crop/worldcereal_cropland/20250626 + inputs: + sentinel2_0: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_1: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_2: + data_type: "raster" + layers: 
["sentinel2.2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_3: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_4: + data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_5: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_6: + data_type: "raster" + layers: ["sentinel2.6"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_7: + data_type: "raster" + layers: ["sentinel2.7"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_8: + data_type: "raster" + layers: ["sentinel2.8"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_9: + data_type: "raster" + layers: ["sentinel2.9"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_10: + data_type: "raster" + layers: ["sentinel2.10"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_11: + data_type: "raster" + layers: ["sentinel2.11"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel1_0: + data_type: "raster" + layers: ["sentinel1"] + bands: ["vv", "vh"] + passthrough: true 
+ dtype: FLOAT32 + sentinel1_1: + data_type: "raster" + layers: ["sentinel1.1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_2: + data_type: "raster" + layers: ["sentinel1.2"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_3: + data_type: "raster" + layers: ["sentinel1.3"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_4: + data_type: "raster" + layers: ["sentinel1.4"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_5: + data_type: "raster" + layers: ["sentinel1.5"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_6: + data_type: "raster" + layers: ["sentinel1.6"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_7: + data_type: "raster" + layers: ["sentinel1.7"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_8: + data_type: "raster" + layers: ["sentinel1.8"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_9: + data_type: "raster" + layers: ["sentinel1.9"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_10: + data_type: "raster" + layers: ["sentinel1.10"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_11: + data_type: "raster" + layers: ["sentinel1.11"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + label: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + cropland_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "category" + classes: ["Cropland", "Non-Cropland"] + enable_f1_metric: true + metric_kwargs: + average: "micro" + input_mapping: + cropland_classification: + label: "targets" + batch_size: 16 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + sentinel2_0: [] + 
sentinel2_1: [] + sentinel2_2: [] + sentinel2_3: [] + sentinel2_4: [] + sentinel2_5: [] + sentinel2_6: [] + sentinel2_7: [] + sentinel2_8: [] + sentinel2_9: [] + sentinel2_10: [] + sentinel2_11: [] + output_selector: sentinel2_l2a + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + sentinel1_0: [] + sentinel1_1: [] + sentinel1_2: [] + sentinel1_3: [] + sentinel1_4: [] + sentinel1_5: [] + sentinel1_6: [] + sentinel1_7: [] + sentinel1_8: [] + sentinel1_9: [] + sentinel1_10: [] + sentinel1_11: [] + output_selector: sentinel1 + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + sentinel1: ["vv", "vh"] + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 8 + mode: "center" + image_selectors: ["sentinel2_l2a", "sentinel1"] + train_config: + groups: ["h3_sample100_66K"] + tags: + split: "train" + val_config: + groups: ["h3_sample100_66K"] + tags: + split: "val" + test_config: + groups: ["h3_sample100_66K"] + tags: + split: "val" +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 2 +rslp_project: placeholder +rslp_experiment: placeholder diff --git a/data/helios/v2_worldcereal_cropland/finetune_s1_s2_startv1_constant_lr.yaml b/data/helios/v2_worldcereal_cropland/finetune_s1_s2_startv1_constant_lr.yaml new file mode 100644 index 00000000..6e6cfefa --- /dev/null +++ b/data/helios/v2_worldcereal_cropland/finetune_s1_s2_startv1_constant_lr.yaml @@ -0,0 +1,268 
@@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + checkpoint_path: /weka/dfive-default/ryanp/rslearn_projects/project_data/projects/helios_concat_finetune/crop_classify_s1s2__vessel_detect_OLD_SINGLE/checkpoints/last.ckpt + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: "{CHECKPOINT_PATH}" + selector: ["encoder"] + forward_kwargs: + patch_size: {PATCH_SIZE} + decoders: + cropland_classification: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: {ENCODER_EMBEDDING_SIZE} + out_channels: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + lr: 0.00005 + plateau: false + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: /weka/dfive-default/rslearn-eai/datasets/crop/worldcereal_cropland/20250626 + inputs: + sentinel2_0: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_1: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_2: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_3: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_4: + data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + 
dtype: FLOAT32 + sentinel2_5: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_6: + data_type: "raster" + layers: ["sentinel2.6"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_7: + data_type: "raster" + layers: ["sentinel2.7"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_8: + data_type: "raster" + layers: ["sentinel2.8"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_9: + data_type: "raster" + layers: ["sentinel2.9"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_10: + data_type: "raster" + layers: ["sentinel2.10"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel2_11: + data_type: "raster" + layers: ["sentinel2.11"] + bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + passthrough: true + dtype: FLOAT32 + sentinel1_0: + data_type: "raster" + layers: ["sentinel1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_1: + data_type: "raster" + layers: ["sentinel1.1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_2: + data_type: "raster" + layers: ["sentinel1.2"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_3: + data_type: "raster" + layers: ["sentinel1.3"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_4: + data_type: "raster" + layers: ["sentinel1.4"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_5: 
+ data_type: "raster" + layers: ["sentinel1.5"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_6: + data_type: "raster" + layers: ["sentinel1.6"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_7: + data_type: "raster" + layers: ["sentinel1.7"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_8: + data_type: "raster" + layers: ["sentinel1.8"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_9: + data_type: "raster" + layers: ["sentinel1.9"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_10: + data_type: "raster" + layers: ["sentinel1.10"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sentinel1_11: + data_type: "raster" + layers: ["sentinel1.11"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + label: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + cropland_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + property_name: "category" + classes: ["Cropland", "Non-Cropland"] + enable_f1_metric: true + metric_kwargs: + average: "micro" + input_mapping: + cropland_classification: + label: "targets" + batch_size: 16 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + sentinel2_0: [] + sentinel2_1: [] + sentinel2_2: [] + sentinel2_3: [] + sentinel2_4: [] + sentinel2_5: [] + sentinel2_6: [] + sentinel2_7: [] + sentinel2_8: [] + sentinel2_9: [] + sentinel2_10: [] + sentinel2_11: [] + output_selector: sentinel2_l2a + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + sentinel1_0: [] + sentinel1_1: [] + sentinel1_2: [] + sentinel1_3: [] + sentinel1_4: [] + sentinel1_5: [] + sentinel1_6: [] + sentinel1_7: [] + sentinel1_8: [] + sentinel1_9: [] + sentinel1_10: [] + 
sentinel1_11: [] + output_selector: sentinel1 + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + config_fname: "/opt/helios/data/norm_configs/computed.json" + band_names: + sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"] + sentinel1: ["vv", "vh"] + - class_path: rslearn.train.transforms.pad.Pad + init_args: + size: 8 + mode: "center" + image_selectors: ["sentinel2_l2a", "sentinel1"] + train_config: + groups: ["h3_sample100_66K"] + tags: + split: "train" + val_config: + groups: ["h3_sample100_66K"] + tags: + split: "val" + test_config: + groups: ["h3_sample100_66K"] + tags: + split: "val" +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min +rslp_project: placeholder +rslp_experiment: placeholder diff --git a/data/helios/v3_multitask/OUT_medium.yaml b/data/helios/v3_multitask/OUT_medium.yaml new file mode 100644 index 00000000..4e400c78 --- /dev/null +++ b/data/helios/v3_multitask/OUT_medium.yaml @@ -0,0 +1,2023 @@ +data: + class_path: rslearn.train.data_module.MultiDatasetDataModule + init_args: + data_modules: + cropland_classification: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 16 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel2_l2a + selections: + sentinel2_0: [] + sentinel2_1: [] + sentinel2_10: [] + sentinel2_11: [] + sentinel2_2: [] + sentinel2_3: [] + sentinel2_4: [] + sentinel2_5: [] + sentinel2_6: [] + sentinel2_7: [] + sentinel2_8: [] + sentinel2_9: [] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel1 + selections: + sentinel1_0: [] + sentinel1_1: [] + sentinel1_10: [] + 
sentinel1_11: [] + sentinel1_2: [] + sentinel1_3: [] + sentinel1_4: [] + sentinel1_5: [] + sentinel1_6: [] + sentinel1_7: [] + sentinel1_8: [] + sentinel1_9: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel1: + - vv + - vh + sentinel2_l2a: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + config_fname: /opt/helios/data/norm_configs/computed.json + - class_path: rslearn.train.transforms.pad.Pad + init_args: + image_selectors: + - sentinel2_l2a + - sentinel1 + mode: center + size: 8 + inputs: + label: + data_type: vector + is_target: true + layers: + - label + sentinel1_0: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1 + passthrough: true + sentinel1_1: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.1 + passthrough: true + sentinel1_10: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.10 + passthrough: true + sentinel1_11: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.11 + passthrough: true + sentinel1_2: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.2 + passthrough: true + sentinel1_3: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.3 + passthrough: true + sentinel1_4: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.4 + passthrough: true + sentinel1_5: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.5 + passthrough: true + sentinel1_6: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.6 + passthrough: true + sentinel1_7: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.7 + passthrough: true + sentinel1_8: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.8 + passthrough: true + sentinel1_9: 
+ bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.9 + passthrough: true + sentinel2_0: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2 + passthrough: true + sentinel2_1: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.1 + passthrough: true + sentinel2_10: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.10 + passthrough: true + sentinel2_11: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.11 + passthrough: true + sentinel2_2: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.2 + passthrough: true + sentinel2_3: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.3 + passthrough: true + sentinel2_4: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.4 + passthrough: true + sentinel2_5: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.5 + passthrough: true + sentinel2_6: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.6 + passthrough: true + sentinel2_7: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 
+ - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.7 + passthrough: true + sentinel2_8: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.8 + passthrough: true + sentinel2_9: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.9 + passthrough: true + num_workers: 32 + path: /weka/dfive-default/rslearn-eai/datasets/crop/worldcereal_cropland/20250626 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + cropland_classification: + label: targets + tasks: + cropland_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + classes: + - Cropland + - Non-Cropland + enable_f1_metric: true + metric_kwargs: + average: micro + property_name: category + test_config: + groups: + - h3_sample100_66K + tags: + split: val + train_config: + groups: + - h3_sample100_66K + tags: + split: train + val_config: + groups: + - h3_sample100_66K + tags: + split: val + lfmc_estimation: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 16 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel2_l2a + selections: + sentinel2_0: [] + sentinel2_1: [] + sentinel2_10: [] + sentinel2_11: [] + sentinel2_2: [] + sentinel2_3: [] + sentinel2_4: [] + sentinel2_5: [] + sentinel2_6: [] + sentinel2_7: [] + sentinel2_8: [] + sentinel2_9: [] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel1 + selections: + sentinel1_0: [] + sentinel1_1: [] + sentinel1_10: [] + sentinel1_11: [] + sentinel1_2: [] + sentinel1_3: [] + sentinel1_4: [] + sentinel1_5: [] + sentinel1_6: [] + 
sentinel1_7: [] + sentinel1_8: [] + sentinel1_9: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel1: + - vv + - vh + sentinel2_l2a: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + config_fname: /opt/helios/data/norm_configs/computed.json + inputs: + label: + data_type: vector + is_target: true + layers: + - label + sentinel1_0: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1 + passthrough: true + sentinel1_1: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.1 + passthrough: true + sentinel1_10: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.10 + passthrough: true + sentinel1_11: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.11 + passthrough: true + sentinel1_2: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.2 + passthrough: true + sentinel1_3: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.3 + passthrough: true + sentinel1_4: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.4 + passthrough: true + sentinel1_5: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.5 + passthrough: true + sentinel1_6: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.6 + passthrough: true + sentinel1_7: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.7 + passthrough: true + sentinel1_8: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.8 + passthrough: true + sentinel1_9: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.9 + passthrough: true + sentinel2_0: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster 
+ dtype: FLOAT32 + layers: + - sentinel2 + passthrough: true + sentinel2_1: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.1 + passthrough: true + sentinel2_10: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.10 + passthrough: true + sentinel2_11: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.11 + passthrough: true + sentinel2_2: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.2 + passthrough: true + sentinel2_3: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.3 + passthrough: true + sentinel2_4: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.4 + passthrough: true + sentinel2_5: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.5 + passthrough: true + sentinel2_6: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.6 + passthrough: true + sentinel2_7: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.7 + passthrough: true + sentinel2_8: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - 
B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.8 + passthrough: true + sentinel2_9: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.9 + passthrough: true + srtm: + bands: + - srtm + data_type: raster + dtype: FLOAT32 + layers: + - srtm + passthrough: true + num_workers: 32 + path: /weka/dfive-default/rslearn-eai/datasets/lfmc/20250626 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + lfmc_estimation: + label: targets + tasks: + lfmc_estimation: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + allow_invalid: true + metric_mode: l1 + property_name: lfmc_value + test_config: + groups: + - global_lfmc + tags: + split: val + train_config: + groups: + - global_lfmc + tags: + split: train + val_config: + groups: + - global_lfmc + tags: + split: val + segment: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 16 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel2_l2a + selections: + sentinel2_0: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_1: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_10: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_11: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_2: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_3: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_4: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_5: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_6: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 
0 + - 7 + sentinel2_7: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_8: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_9: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel1 + selections: + s1d_0: [] + s1d_1: [] + s1d_10: [] + s1d_11: [] + s1d_2: [] + s1d_3: [] + s1d_4: [] + s1d_5: [] + s1d_6: [] + s1d_7: [] + s1d_8: [] + s1d_9: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel1: + - vv + - vh + sentinel2_l2a: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + config_fname: /opt/helios/data/norm_configs/computed.json + inputs: + s1d_0: + bands: + - vv + - vh + data_type: raster + layers: + - s1d + passthrough: true + s1d_1: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.1 + passthrough: true + s1d_10: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.10 + passthrough: true + s1d_11: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.11 + passthrough: true + s1d_2: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.2 + passthrough: true + s1d_3: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.3 + passthrough: true + s1d_4: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.4 + passthrough: true + s1d_5: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.5 + passthrough: true + s1d_6: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.6 + passthrough: true + s1d_7: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.7 + passthrough: true + s1d_8: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.8 + passthrough: true + s1d_9: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.9 + passthrough: true + sentinel2_0: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 
+ - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2 + passthrough: true + sentinel2_1: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.1 + passthrough: true + sentinel2_10: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.10 + passthrough: true + sentinel2_11: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.11 + passthrough: true + sentinel2_2: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.2 + passthrough: true + sentinel2_3: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.3 + passthrough: true + sentinel2_4: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.4 + passthrough: true + sentinel2_5: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.5 + passthrough: true + sentinel2_6: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.6 + passthrough: true + sentinel2_7: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.7 + passthrough: true + sentinel2_8: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.8 + passthrough: true + sentinel2_9: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.9 + passthrough: true + targets: + bands: + - class 
+ data_type: raster + is_target: true + layers: + - label + num_workers: 16 + path: /weka/dfive-default/rslearn-eai/datasets/pastis/rslearn_dataset/ + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + segment: + targets: targets + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + enable_miou_metric: true + metric_kwargs: + average: micro + num_classes: 20 + remap_values: + - - 0 + - 1 + - - 0 + - 255 + zero_is_invalid: true + test_config: + groups: + - fold5 + train_config: + groups: + - fold1 + - fold2 + - fold3 + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel2_l2a + selections: + sentinel2_0: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_1: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_10: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_11: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_2: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_3: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_4: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_5: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_6: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_7: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_8: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_9: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel1 + selections: + s1d_0: [] + s1d_1: [] + s1d_10: [] + s1d_11: [] + s1d_2: [] + s1d_3: [] + s1d_4: 
[] + s1d_5: [] + s1d_6: [] + s1d_7: [] + s1d_8: [] + s1d_9: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel1: + - vv + - vh + sentinel2_l2a: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + config_fname: /opt/helios/data/norm_configs/computed.json + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: + - sentinel2_l2a + - sentinel1 + - target/segment/classes + - target/segment/valid + val_config: + groups: + - fold4 + vessel_classification: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 16 + default_config: + transforms: + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + inputs: + label: + data_type: vector + is_target: true + layers: + - label + landsat: + bands: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + data_type: raster + dtype: FLOAT32 + layers: + - landsat + passthrough: true + num_workers: 32 + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/classifier/dataset_20250624 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + vessel_classification: + label: targets + tasks: + vessel_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + allow_invalid: true + classes: + - correct + - incorrect + enable_f1_metric: true + metric_kwargs: + average: micro + property_name: label + skip_unknown_categories: true + test_config: + groups: + - phase2a_completed + tags: + split: val + train_config: + groups: + - selected_copy + - phase2a_completed + - phase3a_selected + tags: + split: train + val_config: + groups: + - phase2a_completed + tags: + split: val + vessel_detection: + class_path: 
rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 16 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + inputs: + image: + bands: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + data_type: raster + dtype: FLOAT32 + layers: + - landsat + passthrough: true + mask: + bands: + - mask + data_type: raster + dtype: INT32 + is_target: true + layers: + - mask + passthrough: true + targets: + data_type: vector + is_target: true + layers: + - label + num_workers: 16 + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + vessel_detection: + targets: targets + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 + test_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val + train_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: train + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + 
output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + val_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val + num_workers: 16 +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + lr: 0.0001 + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + decoders: + cropland_classification: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 768 + out_channels: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + lfmc_estimation: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 768 + out_channels: 1 + - class_path: rslearn.train.tasks.regression.RegressionHead + segment: + - class_path: rslearn.models.unet.UNetDecoder + init_args: + conv_layers_per_resolution: 2 + in_channels: + - - 8 + - 768 + num_channels: + 1: 128 + 2: 256 + 4: 512 + 8: 512 + out_channels: 20 + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + vessel_classification: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 768 + out_channels: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + vessel_detection: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + anchor_sizes: + - - 32 + downsample_factors: + - 8 + num_channels: 768 + num_classes: 2 + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 + forward_kwargs: + patch_size: 8 + selector: + - encoder + lazy_decode: true + plateau: true + plateau_cooldown: 10 + plateau_factor: 0.2 + 
plateau_min_lr: 0 + plateau_patience: 2 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + cropland_classification: + label: targets + lfmc_estimation: + label: targets + segment: + targets: targets + vessel_classification: + label: targets + vessel_detection: + targets: targets + tasks: + cropland_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + classes: + - Cropland + - Non-Cropland + enable_f1_metric: true + metric_kwargs: + average: micro + property_name: category + lfmc_estimation: + class_path: rslearn.train.tasks.regression.RegressionTask + init_args: + allow_invalid: true + metric_mode: l1 + property_name: lfmc_value + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + enable_miou_metric: true + metric_kwargs: + average: micro + num_classes: 20 + remap_values: + - - 0 + - 1 + - - 0 + - 255 + zero_is_invalid: true + vessel_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + allow_invalid: true + classes: + - correct + - incorrect + enable_f1_metric: true + metric_kwargs: + average: micro + property_name: label + skip_unknown_categories: true + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 +rslp_experiment: placeholder +rslp_project: placeholder +trainer: + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint 
+ init_args: + mode: min + monitor: val_loss + save_last: true + save_top_k: 1 + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: + - model + - encoder + - 0 + unfreeze_at_epoch: 2 + max_epochs: 100 diff --git a/data/helios/v3_multitask/OUT_small_v1.yaml b/data/helios/v3_multitask/OUT_small_v1.yaml new file mode 100644 index 00000000..ed31b5cf --- /dev/null +++ b/data/helios/v3_multitask/OUT_small_v1.yaml @@ -0,0 +1,702 @@ +data: + class_path: rslearn.train.data_module.MultiDatasetDataModule + init_args: + data_modules: + cropland_classification: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel2_l2a + selections: + sentinel2_0: [] + sentinel2_1: [] + sentinel2_10: [] + sentinel2_11: [] + sentinel2_2: [] + sentinel2_3: [] + sentinel2_4: [] + sentinel2_5: [] + sentinel2_6: [] + sentinel2_7: [] + sentinel2_8: [] + sentinel2_9: [] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel1 + selections: + sentinel1_0: [] + sentinel1_1: [] + sentinel1_10: [] + sentinel1_11: [] + sentinel1_2: [] + sentinel1_3: [] + sentinel1_4: [] + sentinel1_5: [] + sentinel1_6: [] + sentinel1_7: [] + sentinel1_8: [] + sentinel1_9: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel1: + - vv + - vh + sentinel2_l2a: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + config_fname: /opt/helios/data/norm_configs/computed.json + - class_path: rslearn.train.transforms.pad.Pad + init_args: + image_selectors: + - sentinel2_l2a + - sentinel1 + mode: center + size: 8 + inputs: + label: + data_type: vector + is_target: true + layers: + - label + sentinel1_0: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1 + 
passthrough: true + sentinel1_1: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.1 + passthrough: true + sentinel1_10: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.10 + passthrough: true + sentinel1_11: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.11 + passthrough: true + sentinel1_2: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.2 + passthrough: true + sentinel1_3: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.3 + passthrough: true + sentinel1_4: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.4 + passthrough: true + sentinel1_5: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.5 + passthrough: true + sentinel1_6: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.6 + passthrough: true + sentinel1_7: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.7 + passthrough: true + sentinel1_8: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.8 + passthrough: true + sentinel1_9: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.9 + passthrough: true + sentinel2_0: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2 + passthrough: true + sentinel2_1: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.1 + passthrough: true + sentinel2_10: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.10 + passthrough: true + sentinel2_11: + bands: + - B02 + - 
B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.11 + passthrough: true + sentinel2_2: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.2 + passthrough: true + sentinel2_3: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.3 + passthrough: true + sentinel2_4: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.4 + passthrough: true + sentinel2_5: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.5 + passthrough: true + sentinel2_6: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.6 + passthrough: true + sentinel2_7: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.7 + passthrough: true + sentinel2_8: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.8 + passthrough: true + sentinel2_9: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.9 + passthrough: true + num_workers: 32 + path: /weka/dfive-default/rslearn-eai/datasets/crop/worldcereal_cropland/20250626 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + 
cropland_classification: + label: targets + tasks: + cropland_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + classes: + - Cropland + - Non-Cropland + enable_f1_metric: true + metric_kwargs: + average: micro + property_name: category + test_config: + groups: + - h3_sample100_66K + tags: + split: val + train_config: + groups: + - h3_sample100_66K + tags: + split: train + val_config: + groups: + - h3_sample100_66K + tags: + split: val + vessel_detection: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 32 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + inputs: + image: + bands: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + data_type: raster + dtype: FLOAT32 + layers: + - landsat + passthrough: true + mask: + bands: + - mask + data_type: raster + dtype: INT32 + is_target: true + layers: + - mask + passthrough: true + targets: + data_type: vector + is_target: true + layers: + - label + num_workers: 16 + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + vessel_detection: + targets: targets + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 
0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 + test_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val + train_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: train + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + val_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + lr: 0.0001 + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + decoders: + cropland_classification: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 768 + out_channels: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + vessel_detection: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + anchor_sizes: + - - 32 + downsample_factors: + - 8 + num_channels: 768 + num_classes: 2 + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 + forward_kwargs: + patch_size: 8 + selector: + - encoder + lazy_decode: true + plateau: true + plateau_cooldown: 10 + plateau_factor: 0.2 + plateau_min_lr: 0 + plateau_patience: 2 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + cropland_classification: + label: targets + 
vessel_detection: + targets: targets + tasks: + cropland_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + classes: + - Cropland + - Non-Cropland + enable_f1_metric: true + metric_kwargs: + average: micro + property_name: category + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 +rslp_experiment: placeholder +rslp_project: placeholder +trainer: + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + mode: min + monitor: val_loss + save_last: true + save_top_k: 1 + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: + - model + - encoder + - 0 + unfreeze_at_epoch: 2 + max_epochs: 100 diff --git a/data/helios/v3_multitask/OUT_small_v1_LARGE_LR.yaml b/data/helios/v3_multitask/OUT_small_v1_LARGE_LR.yaml new file mode 100644 index 00000000..81d335d7 --- /dev/null +++ b/data/helios/v3_multitask/OUT_small_v1_LARGE_LR.yaml @@ -0,0 +1,702 @@ +data: + class_path: rslearn.train.data_module.MultiDatasetDataModule + init_args: + data_modules: + cropland_classification: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel2_l2a + selections: + sentinel2_0: [] + sentinel2_1: [] + sentinel2_10: [] + sentinel2_11: [] + 
sentinel2_2: [] + sentinel2_3: [] + sentinel2_4: [] + sentinel2_5: [] + sentinel2_6: [] + sentinel2_7: [] + sentinel2_8: [] + sentinel2_9: [] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel1 + selections: + sentinel1_0: [] + sentinel1_1: [] + sentinel1_10: [] + sentinel1_11: [] + sentinel1_2: [] + sentinel1_3: [] + sentinel1_4: [] + sentinel1_5: [] + sentinel1_6: [] + sentinel1_7: [] + sentinel1_8: [] + sentinel1_9: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel1: + - vv + - vh + sentinel2_l2a: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + config_fname: /opt/helios/data/norm_configs/computed.json + - class_path: rslearn.train.transforms.pad.Pad + init_args: + image_selectors: + - sentinel2_l2a + - sentinel1 + mode: center + size: 8 + inputs: + label: + data_type: vector + is_target: true + layers: + - label + sentinel1_0: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1 + passthrough: true + sentinel1_1: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.1 + passthrough: true + sentinel1_10: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.10 + passthrough: true + sentinel1_11: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.11 + passthrough: true + sentinel1_2: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.2 + passthrough: true + sentinel1_3: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.3 + passthrough: true + sentinel1_4: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.4 + passthrough: true + sentinel1_5: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.5 + passthrough: true + sentinel1_6: + bands: + - vv + - vh + data_type: raster 
+ dtype: FLOAT32 + layers: + - sentinel1.6 + passthrough: true + sentinel1_7: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.7 + passthrough: true + sentinel1_8: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.8 + passthrough: true + sentinel1_9: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1.9 + passthrough: true + sentinel2_0: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2 + passthrough: true + sentinel2_1: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.1 + passthrough: true + sentinel2_10: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.10 + passthrough: true + sentinel2_11: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.11 + passthrough: true + sentinel2_2: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.2 + passthrough: true + sentinel2_3: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.3 + passthrough: true + sentinel2_4: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.4 + passthrough: true + sentinel2_5: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: 
FLOAT32 + layers: + - sentinel2.5 + passthrough: true + sentinel2_6: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.6 + passthrough: true + sentinel2_7: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.7 + passthrough: true + sentinel2_8: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.8 + passthrough: true + sentinel2_9: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.9 + passthrough: true + num_workers: 32 + path: /weka/dfive-default/rslearn-eai/datasets/crop/worldcereal_cropland/20250626 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + cropland_classification: + label: targets + tasks: + cropland_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + classes: + - Cropland + - Non-Cropland + enable_f1_metric: true + metric_kwargs: + average: micro + property_name: category + test_config: + groups: + - h3_sample100_66K + tags: + split: val + train_config: + groups: + - h3_sample100_66K + tags: + split: train + val_config: + groups: + - h3_sample100_66K + tags: + split: val + vessel_detection: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 32 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 
+ - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + inputs: + image: + bands: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + data_type: raster + dtype: FLOAT32 + layers: + - landsat + passthrough: true + mask: + bands: + - mask + data_type: raster + dtype: INT32 + is_target: true + layers: + - mask + passthrough: true + targets: + data_type: vector + is_target: true + layers: + - label + num_workers: 16 + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + vessel_detection: + targets: targets + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 + test_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val + train_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: train + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + val_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val +model: + class_path: 
rslearn.train.lightning_module.RslearnLightningModule + init_args: + lr: 0.0002 + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + decoders: + cropland_classification: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 768 + out_channels: 2 + - class_path: rslearn.train.tasks.classification.ClassificationHead + vessel_detection: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + anchor_sizes: + - - 32 + downsample_factors: + - 8 + num_channels: 768 + num_classes: 2 + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 + forward_kwargs: + patch_size: 8 + selector: + - encoder + lazy_decode: true + plateau: true + plateau_cooldown: 10 + plateau_factor: 0.2 + plateau_min_lr: 0 + plateau_patience: 2 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + cropland_classification: + label: targets + vessel_detection: + targets: targets + tasks: + cropland_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + classes: + - Cropland + - Non-Cropland + enable_f1_metric: true + metric_kwargs: + average: micro + property_name: category + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 +rslp_experiment: placeholder +rslp_project: placeholder +trainer: + callbacks: + - class_path: 
lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + mode: min + monitor: val_loss + save_last: true + save_top_k: 1 + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: + - model + - encoder + - 0 + unfreeze_at_epoch: 2 + max_epochs: 100 diff --git a/data/helios/v3_multitask/OUT_small_v2.yaml b/data/helios/v3_multitask/OUT_small_v2.yaml new file mode 100644 index 00000000..9b1888e3 --- /dev/null +++ b/data/helios/v3_multitask/OUT_small_v2.yaml @@ -0,0 +1,1007 @@ +data: + class_path: rslearn.train.data_module.MultiDatasetDataModule + init_args: + data_modules: + segment: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel2_l2a + selections: + sentinel2_0: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_1: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_10: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_11: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_2: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_3: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_4: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_5: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_6: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_7: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_8: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_9: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + 
- 7 + - 8 + - 9 + - 0 + - 7 + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel1 + selections: + s1d_0: [] + s1d_1: [] + s1d_10: [] + s1d_11: [] + s1d_2: [] + s1d_3: [] + s1d_4: [] + s1d_5: [] + s1d_6: [] + s1d_7: [] + s1d_8: [] + s1d_9: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel1: + - vv + - vh + sentinel2_l2a: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + config_fname: /opt/helios/data/norm_configs/computed.json + inputs: + s1d_0: + bands: + - vv + - vh + data_type: raster + layers: + - s1d + passthrough: true + s1d_1: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.1 + passthrough: true + s1d_10: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.10 + passthrough: true + s1d_11: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.11 + passthrough: true + s1d_2: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.2 + passthrough: true + s1d_3: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.3 + passthrough: true + s1d_4: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.4 + passthrough: true + s1d_5: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.5 + passthrough: true + s1d_6: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.6 + passthrough: true + s1d_7: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.7 + passthrough: true + s1d_8: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.8 + passthrough: true + s1d_9: + bands: + - vv + - vh + data_type: raster + layers: + - s1d.9 + passthrough: true + sentinel2_0: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2 + passthrough: true + sentinel2_1: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - 
sentinel2.1 + passthrough: true + sentinel2_10: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.10 + passthrough: true + sentinel2_11: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.11 + passthrough: true + sentinel2_2: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.2 + passthrough: true + sentinel2_3: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.3 + passthrough: true + sentinel2_4: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.4 + passthrough: true + sentinel2_5: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.5 + passthrough: true + sentinel2_6: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.6 + passthrough: true + sentinel2_7: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.7 + passthrough: true + sentinel2_8: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.8 + passthrough: true + sentinel2_9: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + data_type: raster + layers: + - sentinel2.9 + passthrough: true + targets: + bands: + - class + data_type: raster + is_target: true + layers: + - label + num_workers: 16 + path: /weka/dfive-default/rslearn-eai/datasets/pastis/rslearn_dataset/ + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + 
input_mapping: + segment: + targets: targets + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + enable_miou_metric: true + metric_kwargs: + average: micro + num_classes: 20 + remap_values: + - - 0 + - 1 + - - 0 + - 255 + zero_is_invalid: true + test_config: + groups: + - fold5 + train_config: + groups: + - fold1 + - fold2 + - fold3 + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel2_l2a + selections: + sentinel2_0: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_1: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_10: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_11: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_2: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_3: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_4: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_5: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_6: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_7: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_8: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + sentinel2_9: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 0 + - 7 + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel1 + selections: + s1d_0: [] + s1d_1: [] + s1d_10: [] + s1d_11: [] + s1d_2: [] + s1d_3: [] + s1d_4: [] + s1d_5: [] + s1d_6: [] + s1d_7: [] + s1d_8: [] + s1d_9: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel1: + - vv + - vh + sentinel2_l2a: + - B02 + - B03 + - B04 + - B08 + - B05 + - 
B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + config_fname: /opt/helios/data/norm_configs/computed.json + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: + - sentinel2_l2a + - sentinel1 + - target/segment/classes + - target/segment/valid + val_config: + groups: + - fold4 + vessel_detection: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 32 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + inputs: + image: + bands: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + data_type: raster + dtype: FLOAT32 + layers: + - landsat + passthrough: true + mask: + bands: + - mask + data_type: raster + dtype: INT32 + is_target: true + layers: + - mask + passthrough: true + targets: + data_type: vector + is_target: true + layers: + - label + num_workers: 16 + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + vessel_detection: + targets: targets + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 + 
test_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val + train_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: train + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + val_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + lr: 0.0001 + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + decoders: + segment: + - class_path: rslearn.models.unet.UNetDecoder + init_args: + conv_layers_per_resolution: 2 + in_channels: + - - 8 + - 768 + num_channels: + 1: 128 + 2: 256 + 4: 512 + 8: 512 + out_channels: 20 + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + vessel_detection: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + anchor_sizes: + - - 32 + downsample_factors: + - 8 + num_channels: 768 + num_classes: 2 + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 + forward_kwargs: + patch_size: 8 + selector: + - encoder + lazy_decode: true + plateau: true + plateau_cooldown: 10 + plateau_factor: 0.2 + plateau_min_lr: 0 + plateau_patience: 2 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + segment: + targets: targets + vessel_detection: + targets: targets + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + 
init_args: + enable_miou_metric: true + metric_kwargs: + average: micro + num_classes: 20 + remap_values: + - - 0 + - 1 + - - 0 + - 255 + zero_is_invalid: true + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 +rslp_experiment: placeholder +rslp_project: placeholder +trainer: + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + mode: min + monitor: val_loss + save_last: true + save_top_k: 1 + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: + - model + - encoder + - 0 + unfreeze_at_epoch: 2 + max_epochs: 100 diff --git a/data/helios/v3_multitask/README.md b/data/helios/v3_multitask/README.md new file mode 100644 index 00000000..b5f587e8 --- /dev/null +++ b/data/helios/v3_multitask/README.md @@ -0,0 +1,29 @@ +## Finetuning on concatenated datasets + +**Only one task per dataset is supported at the moment.** + +There are two stages to running a multi-dataset job: constructing a run configuration, and actually running the job. + +### 1. Constructing a run configuration + +First write a multi-dataset config like `v3_multitask/small_v1.yaml`. This will be used to generate a run-specific config passed to `rslearn`. + +Specify dataset-specific configs from `v2_*` in `dataset_cfgs`. If there are multiple config files, list them by override priority. 
Rather than specifying constants in a `launch_finetune` command, you must specify `patch_size`, `encoder_embedding_size`, and `helios_checkpoint_path` within the multi-dataset config itself. Be sure to also specify `output_path`, the path where the generated run configuration will be written. + +The `base_cfg` key points to `base.yaml` by default, which specifies the Helios encoder backbone structure, training callbacks, etc. It can generally be kept as is, unless you need to modify the callbacks. + +Once you have constructed this multi-dataset config, you can generate the run config via + +```bash +python make_multidataset_config.py --cfg [BASE_CONFIG] +``` + +### 2. Running a multi-dataset job + +After generating a run config (see `v3_multitask/OUT*` for examples), it's straightforward to launch the multi-dataset job with `launch_finetune`. An example command is below. Note that constants like `HELIOS_CHECKPOINT_PATH`, `ENCODER_EMBEDDING_SIZE`, etc. are already substituted in by `make_multidataset_config`. If you are running with a config not generated with `make_multidataset_config`, you will have to specify these constants yourself, as usual. + +```bash +python -m rslp.main helios launch_finetune --config_paths+=[CONFIG] --rslp_project [PROJECT] --experiment_id [ID] --cluster+=[CLUSTER] --image_name [IMAGE_NAME] +``` + +Optionally, specify `--local true` to run in the current Beaker session and `--do_eval true` for evaluation (only supported locally). If the `RSLP_PREFIX` environment variable is not set, it defaults to `project_data/` for local runs and `gs://rslearn-eai` otherwise. Please use the Beaker image `henryh/rslp_multidataset_stable`.
diff --git a/data/helios/v3_multitask/base.yaml b/data/helios/v3_multitask/base.yaml new file mode 100644 index 00000000..b57df98f --- /dev/null +++ b/data/helios/v3_multitask/base.yaml @@ -0,0 +1,46 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: "{CHECKPOINT_PATH}" + selector: ["encoder"] + forward_kwargs: + patch_size: {PATCH_SIZE} + decoders: # Filled in by make_multidataset_config.py + lazy_decode: true + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + +data: + class_path: rslearn.train.data_module.MultiDatasetDataModule + init_args: + data_modules: # Filled in by make_multidataset_config.py + +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 2 + +rslp_project: placeholder +rslp_experiment: placeholder diff --git a/data/helios/v3_multitask/medium.yaml b/data/helios/v3_multitask/medium.yaml new file mode 100644 index 00000000..112f8a86 --- /dev/null +++ b/data/helios/v3_multitask/medium.yaml @@ -0,0 +1,14 @@ +base_cfg: /weka/dfive-default/ryanp/rslearn_projects/data/helios/v3_multitask/base.yaml +dataset_cfgs: + - /weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_landsat_vessels/finetune_detector.yaml + - /weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_landsat_vessels/finetune_classifier.yaml + - /weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_worldcereal_cropland/finetune_s1_s2.yaml + - 
/weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_lfmc/finetune_s1_s2_srtm.yaml + - - /weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_pastis/basecfg.yaml + - /weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_pastis/basecfg_helios_mm.yaml +patch_size: 8 +encoder_embedding_size: 768 +helios_checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 +output_path: /weka/dfive-default/ryanp/rslearn_projects/data/helios/v3_multitask/OUT_medium.yaml +batch_size: 32 +num_workers: 16 diff --git a/data/helios/v3_multitask/small_v1.yaml b/data/helios/v3_multitask/small_v1.yaml new file mode 100644 index 00000000..4f98922c --- /dev/null +++ b/data/helios/v3_multitask/small_v1.yaml @@ -0,0 +1,10 @@ +# Small test config for multi-dataset training, landsat vessel detection and worldcereal cropland classification +base_cfg: /weka/dfive-default/ryanp/rslearn_projects/data/helios/v3_multitask/base.yaml +dataset_cfgs: + - /weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_worldcereal_cropland/finetune_s1_s2.yaml + - /weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_landsat_vessels/finetune_detector.yaml +patch_size: 8 +encoder_embedding_size: 768 +helios_checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 +output_path: /weka/dfive-default/ryanp/rslearn_projects/data/helios/v3_multitask/OUT_small_v1.yaml +batch_size: 32 diff --git a/data/helios/v3_multitask/small_v2.yaml b/data/helios/v3_multitask/small_v2.yaml new file mode 100644 index 00000000..8cd6c5e5 --- /dev/null +++ b/data/helios/v3_multitask/small_v2.yaml @@ -0,0 +1,11 @@ +# Landsat vessel detection and pastis segmentation +base_cfg: /weka/dfive-default/ryanp/rslearn_projects/data/helios/v3_multitask/base.yaml +dataset_cfgs: + - /weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_landsat_vessels/finetune_detector.yaml + - - 
/weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_pastis/basecfg.yaml + - /weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_pastis/basecfg_helios_mm.yaml +patch_size: 8 +encoder_embedding_size: 768 +helios_checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 +output_path: /weka/dfive-default/ryanp/rslearn_projects/data/helios/v3_multitask/OUT_small_v2.yaml +batch_size: 32 diff --git a/data/helios/v3_perf_benchmark/EVAL_small_multigpu_broken.yaml b/data/helios/v3_perf_benchmark/EVAL_small_multigpu_broken.yaml new file mode 100644 index 00000000..98bcf21d --- /dev/null +++ b/data/helios/v3_perf_benchmark/EVAL_small_multigpu_broken.yaml @@ -0,0 +1,238 @@ +data: + class_path: rslearn.train.data_module.MultiDatasetDataModule + init_args: + num_workers: 16 + data_modules: + vessel_detection: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 8 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + inputs: + image: + bands: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + data_type: raster + dtype: FLOAT32 + layers: + - landsat + passthrough: true + mask: + bands: + - mask + data_type: raster + dtype: INT32 + is_target: true + layers: + - mask + passthrough: true + targets: + data_type: vector + is_target: true + layers: + - label + num_workers: 16 + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + 
vessel_detection: + targets: targets + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 + test_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val + train_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: train + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + val_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + lr: 0.0001 + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + checkpoint_path: /weka/dfive-default/ryanp/rslearn_projects/project_data/projects/helios_concat_finetune/crop_classify_s1s2__vessel_detect__multigpu/checkpoints/last.ckpt + decoders: + vessel_detection: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + anchor_sizes: + - - 32 + downsample_factors: + - 8 + num_channels: 768 + num_classes: 2 + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: 
/weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 + forward_kwargs: + patch_size: 8 + selector: + - encoder + lazy_decode: true + plateau: true + plateau_cooldown: 10 + plateau_factor: 0.2 + plateau_min_lr: 0 + plateau_patience: 2 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + vessel_detection: + targets: targets + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 +rslp_experiment: placeholder +rslp_project: placeholder +trainer: + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + mode: min + monitor: val_loss + save_last: true + save_top_k: 1 + max_epochs: 100 diff --git a/data/helios/v3_perf_benchmark/EVAL_small_v1.yaml b/data/helios/v3_perf_benchmark/EVAL_small_v1.yaml new file mode 100644 index 00000000..837d2bc8 --- /dev/null +++ b/data/helios/v3_perf_benchmark/EVAL_small_v1.yaml @@ -0,0 +1,238 @@ +data: + class_path: rslearn.train.data_module.MultiDatasetDataModule + init_args: + num_workers: 16 + data_modules: + vessel_detection: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 8 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: 
rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + inputs: + image: + bands: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + data_type: raster + dtype: FLOAT32 + layers: + - landsat + passthrough: true + mask: + bands: + - mask + data_type: raster + dtype: INT32 + is_target: true + layers: + - mask + passthrough: true + targets: + data_type: vector + is_target: true + layers: + - label + num_workers: 16 + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + vessel_detection: + targets: targets + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 + test_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val + train_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: train + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + val_config: + 
groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + lr: 0.0001 + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + checkpoint_path: /weka/dfive-default/ryanp/rslearn_projects/project_data/projects/helios_concat_finetune/crop_classify_s1s2__vessel_detect/checkpoints/last.ckpt + decoders: + vessel_detection: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + anchor_sizes: + - - 32 + downsample_factors: + - 8 + num_channels: 768 + num_classes: 2 + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 + forward_kwargs: + patch_size: 8 + selector: + - encoder + lazy_decode: true + plateau: true + plateau_cooldown: 10 + plateau_factor: 0.2 + plateau_min_lr: 0 + plateau_patience: 2 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + vessel_detection: + targets: targets + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 +rslp_experiment: placeholder +rslp_project: placeholder +trainer: + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + mode: min + monitor: val_loss + save_last: true + save_top_k: 1 + max_epochs: 100 diff 
--git a/data/helios/v3_perf_benchmark/EVAL_vessel_detection.yaml b/data/helios/v3_perf_benchmark/EVAL_vessel_detection.yaml new file mode 100644 index 00000000..ba9d3989 --- /dev/null +++ b/data/helios/v3_perf_benchmark/EVAL_vessel_detection.yaml @@ -0,0 +1,238 @@ +data: + class_path: rslearn.train.data_module.MultiDatasetDataModule + init_args: + num_workers: 16 + data_modules: + vessel_detection: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 8 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + inputs: + image: + bands: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + data_type: raster + dtype: FLOAT32 + layers: + - landsat + passthrough: true + mask: + bands: + - mask + data_type: raster + dtype: INT32 + is_target: true + layers: + - mask + passthrough: true + targets: + data_type: vector + is_target: true + layers: + - label + num_workers: 16 + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + vessel_detection: + targets: targets + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + 
remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 + test_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val + train_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: train + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + val_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + lr: 0.0001 + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + checkpoint_path: /weka/dfive-default/ryanp/rslearn_projects/project_data/projects/helios_cross_finetuning_v4/v2_landsat_vessels_finetune_detector__CROSS__v2_base/checkpoints/last.ckpt + decoders: + vessel_detection: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + anchor_sizes: + - - 32 + downsample_factors: + - 8 + num_channels: 768 + num_classes: 2 + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 + forward_kwargs: + patch_size: 8 + selector: + - encoder + lazy_decode: true + plateau: true + plateau_cooldown: 10 + plateau_factor: 0.2 + plateau_min_lr: 0 + plateau_patience: 2 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + vessel_detection: + targets: targets + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + 
- vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 +rslp_experiment: placeholder +rslp_project: placeholder +trainer: + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + mode: min + monitor: val_loss + save_last: true + save_top_k: 1 + max_epochs: 100 diff --git a/data/helios/v3_perf_benchmark/OUT_crop_type.yaml b/data/helios/v3_perf_benchmark/OUT_crop_type.yaml new file mode 100644 index 00000000..8262162f --- /dev/null +++ b/data/helios/v3_perf_benchmark/OUT_crop_type.yaml @@ -0,0 +1,524 @@ +data: + class_path: rslearn.train.data_module.MultiDatasetDataModule + init_args: + data_modules: + crop_type_classification: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel2_l2a + selections: + sentinel2_0: [] + sentinel2_1: [] + sentinel2_10: [] + sentinel2_11: [] + sentinel2_2: [] + sentinel2_3: [] + sentinel2_4: [] + sentinel2_5: [] + sentinel2_6: [] + sentinel2_7: [] + sentinel2_8: [] + sentinel2_9: [] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: sentinel1 + selections: + sentinel1_0: [] + sentinel1_1: [] + sentinel1_10: [] + sentinel1_11: [] + sentinel1_2: [] + sentinel1_3: [] + sentinel1_4: [] + sentinel1_5: [] + sentinel1_6: [] + sentinel1_7: [] + sentinel1_8: [] + sentinel1_9: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + sentinel1: + - vv + - vh 
+ sentinel2_l2a: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + config_fname: /opt/helios/data/norm_configs/computed.json + - class_path: rslearn.train.transforms.pad.Pad + init_args: + image_selectors: + - sentinel2_l2a + - sentinel1 + mode: center + size: 8 + inputs: + label: + data_type: vector + is_target: true + layers: + - label + sentinel1_0: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1_ascending + passthrough: true + sentinel1_1: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1_ascending.1 + passthrough: true + sentinel1_10: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1_ascending.10 + passthrough: true + sentinel1_11: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1_ascending.11 + passthrough: true + sentinel1_2: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1_ascending.2 + passthrough: true + sentinel1_3: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1_ascending.3 + passthrough: true + sentinel1_4: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1_ascending.4 + passthrough: true + sentinel1_5: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1_ascending.5 + passthrough: true + sentinel1_6: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1_ascending.6 + passthrough: true + sentinel1_7: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1_ascending.7 + passthrough: true + sentinel1_8: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1_ascending.8 + passthrough: true + sentinel1_9: + bands: + - vv + - vh + data_type: raster + dtype: FLOAT32 + layers: + - sentinel1_ascending.9 + passthrough: true + sentinel2_0: + bands: + - B02 + - 
B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2 + passthrough: true + sentinel2_1: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.1 + passthrough: true + sentinel2_10: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.10 + passthrough: true + sentinel2_11: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.11 + passthrough: true + sentinel2_2: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.2 + passthrough: true + sentinel2_3: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.3 + passthrough: true + sentinel2_4: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.4 + passthrough: true + sentinel2_5: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.5 + passthrough: true + sentinel2_6: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.6 + passthrough: true + sentinel2_7: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.7 + passthrough: true + 
sentinel2_8: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.8 + passthrough: true + sentinel2_9: + bands: + - B02 + - B03 + - B04 + - B08 + - B05 + - B06 + - B07 + - B8A + - B11 + - B12 + - B01 + - B09 + data_type: raster + dtype: FLOAT32 + layers: + - sentinel2.9 + passthrough: true + num_workers: 32 + path: /weka/dfive-default/rslearn-eai/datasets/crop/kenya_nandi/20250625 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + crop_type_classification: + label: targets + tasks: + crop_type_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + classes: + - Coffee + - Trees + - Grassland + - Maize + - Sugarcane + - Tea + - Water + - Built-up + enable_f1_metric: true + metric_kwargs: + average: micro + property_name: category + test_config: + groups: + - groundtruth_polygon_split_window_32 + - worldcover_window_32 + tags: + split: val + train_config: + groups: + - groundtruth_polygon_split_window_32 + - worldcover_window_32 + tags: + split: train + val_config: + groups: + - groundtruth_polygon_split_window_32 + - worldcover_window_32 + tags: + split: val + num_workers: 32 +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + lr: 0.0001 + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + decoders: + crop_type_classification: + - class_path: rslearn.models.pooling_decoder.PoolingDecoder + init_args: + in_channels: 768 + out_channels: 8 + - class_path: rslearn.train.tasks.classification.ClassificationHead + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 + forward_kwargs: + patch_size: 8 + selector: + - encoder + lazy_decode: true + plateau: true + 
plateau_cooldown: 10 + plateau_factor: 0.2 + plateau_min_lr: 0 + plateau_patience: 2 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + crop_type_classification: + label: targets + tasks: + crop_type_classification: + class_path: rslearn.train.tasks.classification.ClassificationTask + init_args: + classes: + - Coffee + - Trees + - Grassland + - Maize + - Sugarcane + - Tea + - Water + - Built-up + enable_f1_metric: true + metric_kwargs: + average: micro + property_name: category +rslp_experiment: placeholder +rslp_project: placeholder +trainer: + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + mode: min + monitor: val_loss + save_last: true + save_top_k: 1 + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 2 + max_epochs: 100 diff --git a/data/helios/v3_perf_benchmark/OUT_vessel_detection.yaml b/data/helios/v3_perf_benchmark/OUT_vessel_detection.yaml new file mode 100644 index 00000000..18b8f011 --- /dev/null +++ b/data/helios/v3_perf_benchmark/OUT_vessel_detection.yaml @@ -0,0 +1,244 @@ +data: + class_path: rslearn.train.data_module.MultiDatasetDataModule + init_args: + data_modules: + vessel_detection: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + batch_size: 8 + default_config: + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + inputs: + image: + bands: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + 
- B7 + - B9 + - B10 + - B11 + data_type: raster + dtype: FLOAT32 + layers: + - landsat + passthrough: true + mask: + bands: + - mask + data_type: raster + dtype: INT32 + is_target: true + layers: + - mask + passthrough: true + targets: + data_type: vector + is_target: true + layers: + - label + num_workers: 16 + path: /weka/dfive-default/rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20250624 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + vessel_detection: + targets: targets + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 + test_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val + train_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: train + transforms: + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + output_selector: landsat + selections: + image: [] + - class_path: rslp.helios.norm.HeliosNormalize + init_args: + band_names: + landsat: + - B8 + - B1 + - B2 + - B3 + - B4 + - B5 + - B6 + - B7 + - B9 + - B10 + - B11 + config_fname: /opt/helios/data/norm_configs/computed.json + val_config: + groups: + - labels_utm + load_all_patches: true + patch_size: 128 + tags: + split: val + num_workers: 16 +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + lr: 0.0001 + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + 
decoders: + vessel_detection: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + anchor_sizes: + - - 32 + downsample_factors: + - 8 + num_channels: 768 + num_classes: 2 + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 + forward_kwargs: + patch_size: 8 + selector: + - encoder + lazy_decode: true + plateau: true + plateau_cooldown: 10 + plateau_factor: 0.2 + plateau_min_lr: 0 + plateau_patience: 2 + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + input_mapping: + vessel_detection: + targets: targets + tasks: + vessel_detection: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + box_size: 15 + classes: + - unknown + - vessel + enable_f1_metric: true + enable_map_metric: true + exclude_by_center: true + f1_metric_kwargs: + cmp_mode: distance + cmp_threshold: 15 + flatten_classes: true + f1_metric_thresholds: + - - 0.05 + - 0.1 + - 0.2 + - 0.3 + - 0.4 + - 0.5 + - 0.6 + - 0.7 + - 0.8 + - 0.9 + - 0.95 + property_name: category + remap_values: + - - 0 + - 1 + - - 0 + - 255 + score_threshold: 0.7 +rslp_experiment: placeholder +rslp_project: placeholder +trainer: + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + mode: min + monitor: val_loss + save_last: true + save_top_k: 1 + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: + - model + - encoder + - 0 + unfreeze_at_epoch: 2 + max_epochs: 100 diff --git a/data/helios/v3_perf_benchmark/base.yaml b/data/helios/v3_perf_benchmark/base.yaml new file mode 100644 index 00000000..b57df98f --- /dev/null +++ b/data/helios/v3_perf_benchmark/base.yaml @@ -0,0 +1,46 @@ +model: + class_path: 
rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslp.helios.model.Helios + init_args: + checkpoint_path: "{CHECKPOINT_PATH}" + selector: ["encoder"] + forward_kwargs: + patch_size: {PATCH_SIZE} + decoders: # Filled in by make_multidataset_config.py + lazy_decode: true + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + +data: + class_path: rslearn.train.data_module.MultiDatasetDataModule + init_args: + data_modules: # Filled in by make_multidataset_config.py + +trainer: + max_epochs: 100 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_loss + mode: min + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0] + unfreeze_at_epoch: 2 + +rslp_project: placeholder +rslp_experiment: placeholder diff --git a/data/helios/v3_perf_benchmark/crop_type.yaml b/data/helios/v3_perf_benchmark/crop_type.yaml new file mode 100644 index 00000000..e8b5404a --- /dev/null +++ b/data/helios/v3_perf_benchmark/crop_type.yaml @@ -0,0 +1,9 @@ +base_cfg: /weka/dfive-default/ryanp/rslearn_projects/data/helios/v3_perf_benchmark/base.yaml +dataset_cfgs: + - /weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_nandi_crop_type/finetune_s1_s2.yaml +patch_size: 8 +encoder_embedding_size: 768 +helios_checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 +output_path: /weka/dfive-default/ryanp/rslearn_projects/data/helios/v3_perf_benchmark/OUT_crop_type.yaml +num_workers: 32 +batch_size: 8 diff --git a/data/helios/v3_perf_benchmark/vessel_detection.yaml 
b/data/helios/v3_perf_benchmark/vessel_detection.yaml new file mode 100644 index 00000000..defc3aee --- /dev/null +++ b/data/helios/v3_perf_benchmark/vessel_detection.yaml @@ -0,0 +1,10 @@ +# Benchmark performance via single-task (vessel detection) training +base_cfg: /weka/dfive-default/ryanp/rslearn_projects/data/helios/v3_perf_benchmark/base.yaml +dataset_cfgs: + - /weka/dfive-default/ryanp/rslearn_projects/data/helios/v2_landsat_vessels/finetune_detector_old.yaml +patch_size: 8 +encoder_embedding_size: 768 +helios_checkpoint_path: /weka/dfive-default/helios/checkpoints/favyen/v0.2_base_latent_mim_128_alldata_random_fixed_modality_0.5/step320000 +output_path: /weka/dfive-default/ryanp/rslearn_projects/data/helios/v3_perf_benchmark/OUT_vessel_detection.yaml +batch_size: 8 +num_workers: 16 diff --git a/helios.Dockerfile b/helios.Dockerfile index 7e5f699d..d6a55e23 100644 --- a/helios.Dockerfile +++ b/helios.Dockerfile @@ -4,10 +4,12 @@ RUN apt update RUN apt install -y libpq-dev ffmpeg libsm6 libxext6 git wget # Install rslearn and helios (need to be in local directory). -COPY ./rslearn /opt/rslearn -COPY ./helios /opt/helios +COPY ./docker_build/rslearn /opt/rslearn +COPY ./docker_build/helios /opt/helios COPY requirements.txt /opt/rslearn_projects/requirements.txt -RUN pip install --no-cache-dir --upgrade /opt/rslearn[extra] /opt/helios -r /opt/rslearn_projects/requirements.txt +RUN pip install --no-cache-dir geobench==0.0.1 +RUN pip install --no-cache-dir --upgrade /opt/helios +RUN pip install --no-cache-dir --upgrade /opt/rslearn[extra] -r /opt/rslearn_projects/requirements.txt # Copy rslearn_projects and install it too. COPY . /opt/rslearn_projects/ diff --git a/rslp/helios/README.md b/rslp/helios/README.md index a243910e..5711bd68 100644 --- a/rslp/helios/README.md +++ b/rslp/helios/README.md @@ -13,8 +13,11 @@ the list of model configuration files that will be used as templates for the fine-tuning experiments). `--configs` is similarly optional. 
If you need to create a new image, first create a copy of `rslearn_projects` repository -with subfolders `rslearn` (containing https://github.com/allenai/rslearn) and -`helios` (containing https://github.com/allenai/helios). Then run: +with subfolders `docker_build/rslearn` (containing https://github.com/allenai/rslearn) and +`docker_build/helios` (containing https://github.com/allenai/helios). Then run: docker build -t rslphelios -f helios.Dockerfile . beaker image create --name rslphelios rslphelios + +You may need to remove the version specification on `beaker-py` in `helios/requirements.txt`; this is because `olmo-core` (and so +`helios` as well) requires an outdated version of `beaker-py` that is incompatible with `rslearn`. diff --git a/rslp/helios/launch_finetune.py b/rslp/helios/launch_finetune.py index 50041bd7..805329b0 100644 --- a/rslp/helios/launch_finetune.py +++ b/rslp/helios/launch_finetune.py @@ -6,44 +6,57 @@ import tempfile from pathlib import Path +import yaml + from rslp.log_utils import get_logger +DEFAULT_RSLP_PREFIX = "project_data/" +DEFAULT_RSLP_BUCKET = "gs://rslearn-eai" DEFAULT_RSLP_PROJECT = "helios_finetuning" CONFIG_BASE_DIR = Path("data/helios") +EVAL_BASE_DIR = "helios/eval_sweeps" logger = get_logger(__name__) def launch_finetune( - helios_checkpoint_path: str, experiment_id: str, - image_name: str, - encoder_embedding_size: int, - patch_size: int, - cluster: list[str], - config_paths: list[str], + config_paths: "list[str]", + image_name: str | None = None, + cluster: list[str] | None = None, + helios_checkpoint_path: str | None = None, + encoder_embedding_size: int | None = None, + patch_size: int | None = None, rslp_project: str = DEFAULT_RSLP_PROJECT, gpus: int = 1, priority: str = "high", retries: int = 0, mode: str = "fit", + profiler: str | None = None, + local: bool = False, + do_eval: bool = False, ) -> None: """Launch Helios fine-tuning experiments. Args: - helios_checkpoint_path: path to Helios checkpoint to fine-tune from.
experiment_id: the experiment name. - image_name: what Beaker image to use. - encoder_embedding_size: the embedding size of the encoder. - patch_size: the patch size to use. - cluster: see beaker_train. config_paths: list of configuration files to use. Later config files override earlier configs in the list. - rslp_project: optional override for W&B project to use. - gpus: how many GPUs to assign in the Beaker job. - priority: what priority to use. - retries: Beaker job retries. + image_name: what Beaker image to use. Must be specified if not local. + cluster: see beaker_train. Must be specified if not local. + helios_checkpoint_path: path to Helios checkpoint to fine-tune from. If none, assume + it's already specified in the config. + encoder_embedding_size: the embedding size of the encoder. If none, assume + it's already specified in the config. + patch_size: the patch size to use. If none, assume it's already specified in the config. + rslp_project: optional override for W&B project to use. By default, uses DEFAULT_RSLP_PROJECT. + gpus: how many GPUs to assign in the Beaker job. By default, uses 1. + priority: what priority to use. By default, uses "high". + retries: Beaker job retries. By default, uses 0. mode: Mode to run the model ('fit', 'validate', 'test', or 'predict'). + profiler: Profiler to use for training. Can be 'simple' or 'advanced' or None. + local: Whether to run the command locally instead of spawning a Beaker job. + do_eval: Whether to just run evals. """ # Go into each config file (including the base ones) and make replacements as # needed. @@ -51,60 +64,159 @@ def launch_finetune( # command-line since it appears in a list, so instead we create a copy # of all these configuration files in a temporary directory. 
with tempfile.TemporaryDirectory(dir=".") as tmp_dir: + weka_mounts = [ + dict(bucket_name="dfive-default", mount_path="/weka/dfive-default") + ] + full_eval_dir = os.path.join(weka_mounts[0]["mount_path"], EVAL_BASE_DIR) + os.makedirs(full_eval_dir, exist_ok=True) + # Need to use relative path from rslearn_projects folder since the config file # will be copied into the Beaker experiment's rslearn_projects copy. tmp_dir = os.path.relpath(tmp_dir) tmp_config_fnames: list[str] = [] for config_idx, cur_config_fname in enumerate(config_paths): + # Load the config file as string for template substitution with open(cur_config_fname) as f: config_str = f.read() - config_str = config_str.replace("{CHECKPOINT_PATH}", helios_checkpoint_path) - config_str = config_str.replace("{PATCH_SIZE}", str(patch_size)) - config_str = config_str.replace("{256/PATCH_SIZE}", str(256 // patch_size)) - config_str = config_str.replace("{128/PATCH_SIZE}", str(128 // patch_size)) - config_str = config_str.replace( - "{ENCODER_EMBEDDING_SIZE}", str(encoder_embedding_size) - ) + if helios_checkpoint_path is not None: + config_str = config_str.replace( + "{CHECKPOINT_PATH}", helios_checkpoint_path + ) + if patch_size is not None: + config_str = config_str.replace("{PATCH_SIZE}", str(patch_size)) + config_str = config_str.replace( + "{256/PATCH_SIZE}", str(256 // patch_size) + ) + config_str = config_str.replace( + "{128/PATCH_SIZE}", str(128 // patch_size) + ) + if encoder_embedding_size is not None: + config_str = config_str.replace( + "{ENCODER_EMBEDDING_SIZE}", str(encoder_embedding_size) + ) + + # String to yaml to add test metrics file key + config = yaml.safe_load(config_str) + if do_eval and "model" in config and "init_args" in config["model"]: + if helios_checkpoint_path is not None: + model_name = "_".join( + helios_checkpoint_path.split(os.path.sep)[-2:] + ) # "modelname_stepX" + else: + model_path = config["model"]["init_args"]["model"]["init_args"][ + "checkpoint_path" + ] + 
model_path_parts = model_path.split(os.path.sep) + model_name = ( + model_path_parts[-3] + + "_" + + model_path_parts[-1].replace(".ckpt", "") + ) + eval_task = "__".join(config_paths[0].split(os.path.sep)[-2:]).removesuffix( + ".yaml" + ) + path = os.path.join(full_eval_dir, f"{model_name}__{eval_task}.json") + config["model"]["init_args"]["metrics_file"] = path + logger.info(f"Saving test metrics to {path}") + + # Save the config file to the temporary directory tmp_config_fname = os.path.join( tmp_dir, f"{experiment_id}_{config_idx}.yaml" ) with open(tmp_config_fname, "w") as f: - f.write(config_str) + yaml.dump(config, f, default_flow_style=False) tmp_config_fnames.append(tmp_config_fname) - weka_mounts = [ - dict(bucket_name="dfive-default", mount_path="/weka/dfive-default") - ] + if local: + # If running locally, assume we are in a gpu session + # NOTE: assuming that all the args are passed through to the config file and do NOT get + # passed through the final call to rslp.rslearn_main (except for profiler) + if "RSLP_PREFIX" not in os.environ: + os.environ["RSLP_PREFIX"] = DEFAULT_RSLP_PREFIX + logger.info(f"Using {DEFAULT_RSLP_PREFIX} as default RSLP_PREFIX") + args = [ + "python", + "-m", + "rslp.rslearn_main", + "model", + "fit" if not do_eval else "validate", + ] + paths = [] + for i, _ in enumerate(config_paths): + args.append("--config") + path = f"{tmp_dir}/{experiment_id}_{i}.yaml" + paths.append(path) + args.append(path)
- args = [ - "python", - "-m", - "rslp.main", - "common", - "beaker_train", - "--mode", - mode, - "--config_paths", - json.dumps(tmp_config_fnames), - "--image_name", - image_name, - "--cluster", - json.dumps(cluster), - "--weka_mounts", - json.dumps(weka_mounts), - "--gpus", - str(gpus), - "--project_id", - rslp_project, - "--experiment_id", - experiment_id, - "--priority", - priority, - "--retries", - str(retries), - ] - logger.info(f"Launching job by running: {args}") - subprocess.check_call(args) # nosec + args.extend( + ["--rslp_experiment", experiment_id, "--rslp_project", rslp_project] + ) + + if profiler: + args.append("--profiler") + args.append(profiler) + args.append("--autoresume=true") + + print("=" * 80) + print("DEBUG: Command being spawned:") + print(" ".join(args)) + print("=" * 80) + + # Monkeypatch paths that are hardcoded in the config files + for path in paths: + with open(path) as f: + string = f.read() + string = string.replace("/opt/", "./docker_build/") + with open(path, "w") as f: + f.write(string) + + subprocess.check_call(args) # nosec + + else: + if do_eval: + raise NotImplementedError("Eval mode not supported for Beaker job") + if image_name is None: + raise ValueError("image_name must be specified if not local") + if cluster is None: + raise ValueError("cluster must be specified if not local") + if "RSLP_PREFIX" not in os.environ: + os.environ["RSLP_PREFIX"] = DEFAULT_RSLP_BUCKET + logger.info(f"Using {DEFAULT_RSLP_BUCKET} as default RSLP_PREFIX") + + extra_args = [] + if profiler: + extra_args.extend(["--profiler", profiler]) + + args = [ + "python", + "-m", + "rslp.main", + "common", + "beaker_train", + "--mode", + mode, + "--config_paths", + json.dumps(tmp_config_fnames), + "--image_name", + image_name, + "--cluster", + json.dumps(cluster), + "--weka_mounts", + json.dumps(weka_mounts), + "--gpus", + str(gpus), + "--project_id", + rslp_project, + "--experiment_id", + experiment_id, + "--priority", + priority, + "--retries", + 
str(retries), + ] + if extra_args: + args.extend(["--extra_args", json.dumps(extra_args)]) + logger.info(f"Launching job by running: {args}") + subprocess.check_call(args) # nosec diff --git a/rslp/helios/model.py b/rslp/helios/model.py index 8b159181..ea3a205e 100644 --- a/rslp/helios/model.py +++ b/rslp/helios/model.py @@ -1,6 +1,7 @@ """Helios model wrapper for fine-tuning in rslearn.""" import json +import os from contextlib import nullcontext from typing import Any @@ -11,6 +12,7 @@ from helios.train.masking import MaskedHeliosSample, MaskValue from olmo_core.config import Config from olmo_core.distributed.checkpoint import load_model_and_optim_state +from rslearn.train.lightning_module import RestoreConfig from rslp.log_utils import get_logger @@ -83,7 +85,27 @@ def __init__( # Load the checkpoint. if not random_initialization: train_module_dir = f"{checkpoint_path}/model_and_optim" - load_model_and_optim_state(train_module_dir, model) + if os.path.exists(train_module_dir): + load_model_and_optim_state(train_module_dir, model) + print( + f"INFO: loaded helios encoder from {train_module_dir}/model_and_optim" + ) + + else: + # If we load last.ckpt, we are loading from sft, so ignore decoder weights. + restore_config = RestoreConfig( + restore_path=os.path.join(checkpoint_path, "last.ckpt"), + selector=["state_dict"], + ignore_prefixes=["model.decoders."], + remap_prefixes=[("model.encoder.0.model.", "encoder.")], + ) + state_dict = restore_config.get_state_dict() + result = model.load_state_dict(state_dict, strict=False) + if result.missing_keys: + print(f"WARNING: missing keys: {result.missing_keys}") + if result.unexpected_keys: + print(f"WARNING: unexpected keys: {result.unexpected_keys}") + print(f"INFO: loaded helios encoder from {checkpoint_path}/last.ckpt") # Select just the portion of the model that we actually want to use. 
for part in selector: @@ -157,18 +179,13 @@ def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]: sample, always_pass_none_mask_to_transformer=True, **self.forward_kwargs )[0] - # Apply temporal/modality pooling so we just have one feature per patch. - features = [] - for modality in present_modalities: - modality_features = getattr(tokens_and_masks, modality) - # Pool over band sets and timesteps (BHWTSC -> BHWC). - pooled = modality_features.mean(dim=[3, 4]) - # We want BHWC -> BCHW. - pooled = rearrange(pooled, "b h w c -> b c h w") - features.append(pooled) - # Pool over the modalities, so we get one BCHW feature map. - pooled = torch.stack(features, dim=0).mean(dim=0) - return [pooled] + # Fuse modality and bandset dimensions, leave all other dimensions as is + features_list = [ + getattr(tokens_and_masks, modality) for modality in present_modalities + ] + features = torch.cat(features_list, dim=4) # B x H x W x T x M x C + features = features.permute(4, 0, 5, 1, 2, 3) # M x B x C x H x W x T + return [features] def get_backbone_channels(self) -> list: """Returns the output channels of this model when used as a backbone. diff --git a/rslp/helios/scripts/eval_cross_finetuning.py b/rslp/helios/scripts/eval_cross_finetuning.py new file mode 100644 index 00000000..20b2022a --- /dev/null +++ b/rslp/helios/scripts/eval_cross_finetuning.py @@ -0,0 +1,106 @@ +"""Evaluate cross-finetuned models on various tasks. + +This script evaluates models that have been cross-finetuned on different tasks +by running evaluation commands for each model-task combination. 
+""" + +import os + +ckpt_dir = "/weka/dfive-default/helios/checkpoints/ryanp/" +config_dir = "/weka/dfive-default/ryanp/rslearn_projects/data/helios" +root_dir = "/weka/dfive-default/ryanp/rslearn_projects" +token = "__CROSS__" # nosec +patch_size = 8 +encoder_embedding_size = 768 +image_name = "favyen/rslphelios3" +tasks = [ + ("v2_landsat_vessels", "finetune_classifier"), + ("v2_landsat_vessels", "finetune_detector"), + ("v2_nandi_crop_type", "finetune_s2"), + ("v2_worldcereal_cropland", "finetune_s2"), +] +task_name_map = { + "v2_landsat_vessel_classification": tasks[0], + "v2_landsat_vessel_detection": tasks[1], + "v2_crop_type_classification": tasks[2], + "v2_cropland_classification": tasks[3], +} +cmd = """ +python -m rslp.main helios launch_finetune \ + --helios_checkpoint_path {helios_checkpoint_path} \ + --patch_size {patch_size} \ + --encoder_embedding_size {encoder_embedding_size} \ + --image_name {image_name} \ + --cluster+=ai2/ceres-cirrascale \ + --config_paths+={config_path} \ + --experiment_id {run}__{task}__{cfg} \ + --rslp_project helios_evals \ + --local true \ + --do_eval true \ +""" + + +def find_task_and_cfg(task_and_cfg: str) -> tuple[tuple[str, str], str]: + """Find the task and configuration from a task_and_cfg string. + + Args: + task_and_cfg: String containing task and configuration information. + + Returns: + Tuple of (task_tuple, cfg) where task_tuple is (task_name, cfg_name). + + Raises: + ValueError: If no matching task is found. + """ + for task in tasks: + if task[0] in task_and_cfg and task[1] in task_and_cfg: + cfg = task_and_cfg.replace(f"{task[0]}_{task[1]}_", "") + return task, cfg + raise ValueError(f"No matching task found for {task_and_cfg}") + + +def get_base_task(run: str) -> str: + """Get base finetuning task from run name. + + Args: + run: The run name to extract the base task from. + + Returns: + The base task name. 
+ """ + s = run.split(token)[1] + search_str = "_helios_" + if search_str not in s: + return run + return s[: s.find(search_str)] + + +cmds = [] +for run in os.listdir(ckpt_dir): + if token in run: + ckpt_path = os.path.join(ckpt_dir, run) + for task, cfg in tasks: + task_matches_base = task_name_map.get(get_base_task(run)) == (task, cfg) + task_matches_finetune = f"{task}_{cfg}" == run.split(token)[0] + if task_matches_base or task_matches_finetune: + config_path = os.path.join(config_dir, task, cfg + ".yaml") + filled_cmd = cmd.format( + helios_checkpoint_path=ckpt_path, + patch_size=patch_size, + encoder_embedding_size=encoder_embedding_size, + image_name=image_name, + config_path=config_path, + run=run, + task=task, + cfg=cfg, + ) + print(filled_cmd) + print() + cmds.append(filled_cmd) + +print(len(cmds)) +input("Continue? ") + +os.chdir(root_dir) +for cmd in cmds: + os.system(cmd) # nosec diff --git a/rslp/helios/scripts/get_finetuned_models.py b/rslp/helios/scripts/get_finetuned_models.py new file mode 100644 index 00000000..c4c992c1 --- /dev/null +++ b/rslp/helios/scripts/get_finetuned_models.py @@ -0,0 +1,177 @@ +"""Get the best finetuned models from W&B and download them to a local directory. + +Usage: python3 get_finetuned_models.py --workspace ??? --project ??? --entity ??? +""" + +import argparse +import json +import os +import shutil +import subprocess # nosec +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any + +import wandb + + +def get_best_runs( + project: str, entity: str, download_all: bool = False +) -> dict[str, list[dict[str, Any]]]: + """Get the best runs from W&B for a given project and entity. + + Args: + project: W&B project name. + entity: W&B entity name. + download_all: Whether to download all models or just the best ones. + + Returns: + Dictionary mapping metric names to lists of run information. 
+ """ + api = wandb.Api() # type: ignore + runs = api.runs(f"{entity}/{project}") + + metric_names = set() + for run in runs: + for metric in run.summary.keys(): + try: + metric_k, metric_v = metric.split("/") + except ValueError: + continue + if ( + metric_k.startswith("val_") + and ( + metric_k.endswith("_detection") + or metric_k.endswith("_classification") + ) + and metric_v in ("accuracy", "F1") + ): + metric_names.add(metric) + + saved_runs: dict[str, list[dict[str, Any]]] = {} + for metric in metric_names: + for run in runs: + dirpath = run.config["trainer"]["callbacks"][1]["init_args"]["dirpath"] + pretrained = run.config["model"]["init_args"]["model"]["init_args"][ + "encoder" + ][0]["init_args"]["checkpoint_path"] + info = { + "metric": metric, + "run_name": run.name, + "dirpath": dirpath, + "pretrained": pretrained, + } + try: + saved_runs[metric].append(info) + except KeyError: + saved_runs[metric] = [info] + + if not download_all: + for metric, run_infos in saved_runs.items(): + best_run = max(run_infos, key=lambda x: x["metric"]) + saved_runs[metric] = [best_run] + + return saved_runs + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--project", type=str, required=True, help="WandB project") + parser.add_argument("--entity", type=str, required=True, help="WandB entity") + parser.add_argument( + "--download_all", + action="store_true", + help="Download all models, not just the best ones", + ) + args = parser.parse_args() + + user = ( + subprocess.check_output( # nosec + "beaker account whoami --format json | jq -r '.[0].name'", + shell=True, # nosec + ) + .decode("utf-8") + .strip() + ) + local_dir = f"/weka/dfive-default/helios/checkpoints/{user}/" + os.makedirs(local_dir, exist_ok=True) + print(f"Downloading finetuned models to {local_dir}") + + runs = get_best_runs(args.project, args.entity, args.download_all) + print(json.dumps(runs, indent=4)) + + def download_blob_async(blob_path: str, local_path: str, 
metric_name: str) -> bool: + """Download a single blob asynchronously.""" + try: + download_cmd = ["gcloud", "storage", "cp", blob_path, local_path] + print(f"[{metric_name}] Downloading: {' '.join(download_cmd)}") + subprocess.run(download_cmd, check=True) # nosec + print(f"[{metric_name}] Downloaded {blob_path} to {local_path}") + return True + except subprocess.CalledProcessError as e: + print(f"[{metric_name}] Error downloading {blob_path}: {e}") + return False + except Exception as e: + print(f"[{metric_name}] Unexpected error downloading {blob_path}: {e}") + return False + + # Collect all download tasks + download_tasks = [] + checkpoint = "last.ckpt" + for metric, run_infos in runs.items(): + for run_info in run_infos: + gs_path = run_info["dirpath"] + bucket_name = gs_path.split("/")[2] + blob_path = os.path.join("/".join(gs_path.split("/")[3:]), checkpoint) + + print(f"Processing {metric}: {gs_path}") + + try: + # Construct full GCS path + full_blob_path = f"gs://{bucket_name}/{blob_path}" + local_path = os.path.join(local_dir, run_info["run_name"], checkpoint) + os.makedirs(os.path.dirname(local_path), exist_ok=True) + + # Add to download tasks + download_tasks.append((full_blob_path, local_path, metric)) + + # Also, copy the config.json file from "pretrained" key in run_info + config_dest = local_path.replace(checkpoint, "config.json") + config_src = os.path.join(run_info["pretrained"], "config.json") + print(f"Copying config.json from {config_src} to {config_dest}") + shutil.copy(config_src, config_dest) + + except Exception as e: + print(f"Unexpected error processing {metric}: {e}") + + # Execute all downloads in parallel + print(f"\nStarting parallel download of {len(download_tasks)} files...") + max_workers = min(10, len(download_tasks)) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_task = { + executor.submit(download_blob_async, blob_path, local_path, metric): ( + blob_path, + local_path, + metric, + ) + for 
blob_path, local_path, metric in download_tasks + } + + completed = 0 + failed = 0 + for future in as_completed(future_to_task): + blob_path, local_path, metric = future_to_task[future] + try: + success = future.result() + if success: + completed += 1 + else: + failed += 1 + print( + f"Progress: {completed + failed}/{len(download_tasks)} completed (success: {completed}, failed: {failed})" + ) + except Exception as e: + failed += 1 + print(f"Exception in download task for {blob_path}: {e}") + + print(f"\nDownload complete! Success: {completed}, Failed: {failed}") diff --git a/rslp/helios/scripts/launch_cross_finetuning.py b/rslp/helios/scripts/launch_cross_finetuning.py new file mode 100644 index 00000000..7a2ce404 --- /dev/null +++ b/rslp/helios/scripts/launch_cross_finetuning.py @@ -0,0 +1,140 @@ +"""Launch cross-finetuning experiments for Helios models. + +This script launches multiple finetuning experiments by cross-training models +on different benchmarks and configurations. +""" + +import concurrent.futures +import os +import subprocess # nosec + +cmd1 = """ +RSLP_PREFIX=gs://rslearn-eai \ + python -m rslp.main helios launch_finetune \ + --helios_checkpoint_path /weka/dfive-default/helios/checkpoints/ryanp/{checkpoint} \ + --patch_size {patch_size} \ + --encoder_embedding_size {embedding_size} \ + --image_name {image_name} \ + --config_paths+=data/helios/{benchmark}/basecfg.yaml \ + --config_paths+=data/helios/{benchmark}/basecfg_helios_mm.yaml \ + --config_paths+=data/helios/v2_shared/helios_freeze_then_lowlr.yaml \ + --cluster+=ai2/ceres-cirrascale \ + --cluster+=ai2/saturn-cirrascale \ + --rslp_project {rslp_project} \ + --experiment_id {experiment_id} \ + --priority {priority} +""" + +cmd2 = """ +RSLP_PREFIX=gs://rslearn-eai \ + python -m rslp.main helios launch_finetune \ + --helios_checkpoint_path /weka/dfive-default/helios/checkpoints/ryanp/{checkpoint} \ + --patch_size {patch_size} \ + --encoder_embedding_size {embedding_size} \ + --image_name 
{image_name} \ + --config_paths+=data/helios/{benchmark}/{cfg}.yaml \ + --cluster+=ai2/ceres-cirrascale \ + --cluster+=ai2/saturn-cirrascale \ + --rslp_project {rslp_project} \ + --experiment_id {experiment_id} \ + --priority {priority} +""" + +char_limit = 90 # beaker strings can't be >128 chars +homepath = "/weka/dfive-default/ryanp/rslearn_projects" +priority = "high" +rslp_project = "helios_cross_finetuning_v4" +patch_size = 8 +embedding_size = 768 +image_name = "favyen/rslphelios3" +base_path = "/weka/dfive-default/ryanp/rslearn_projects/data/helios" +models = [ + "v2_crop_type_classification_helios_base_S2_ts_ws8_ps1", + "v2_landsat_vessel_classification_helios_base_ps4_add_prob_threshold", + "v2_landsat_vessel_detection_helios_base_ps4", + "v2_cropland_classification_helios_base_S2_ts_ws8_ps8", + "v2_base", +] +benchmarks = { + "v2_nandi_crop_type": ["finetune_s2"], + "v2_worldcereal_cropland": ["finetune_s2"], + "v2_landsat_vessels": ["finetune_classifier", "finetune_detector"], +} + +commands = [] +for benchmark, info in benchmarks.items(): + for model in models: + if not model.startswith("v2_") and not model.startswith("base"): + continue + if info is not None: + cmd_template = cmd2 + for cfg in info: + cmd = cmd_template.format( + checkpoint=model, + patch_size=patch_size, + embedding_size=embedding_size, + image_name=image_name, + benchmark=benchmark, + rslp_project=rslp_project, + experiment_id=f"{benchmark}_{cfg}__CROSS__{model}"[:char_limit], + priority=priority, + cfg=cfg, + ) + commands.append(cmd.strip()) + else: + cmd_template = cmd1 + cmd = cmd_template.format( + checkpoint=model, + patch_size=patch_size, + embedding_size=embedding_size, + image_name=image_name, + benchmark=benchmark, + rslp_project=rslp_project, + experiment_id=f"{benchmark}__CROSS__{model}"[:char_limit], + priority=priority, + ) + commands.append(cmd.strip()) + +print("=" * 80) +for cmd in commands: + print(cmd) + print("\n\n") +print("=" * 80) + +os.chdir(homepath) + 
+print("Number of models:", len(models)) +print("Number of benchmarks:", len(benchmarks)) +print("Number of commands:", len(commands)) +print("=" * 80) +input("Press Enter to continue...") + + +def run_command(cmd: str) -> bool: + """Run a shell command and return success status. + + Args: + cmd: The shell command to execute. + + Returns: + True if the command succeeded, False otherwise. + """ + print(f"Running: {cmd}\n\n") + try: + result = subprocess.run(cmd, shell=True) # nosec + if result.returncode == 0: + return True + else: + return False + except Exception as e: + print(f"Exception: {e}") + return False + + +max_workers = min(32, len(commands)) +with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + results = list(executor.map(run_command, commands)) + +successful = sum(results) +failed = len(results) - successful +print(f"\nSummary: {successful} successful, {failed} failed") diff --git a/rslp/helios/scripts/make_multidataset_config.py b/rslp/helios/scripts/make_multidataset_config.py new file mode 100644 index 00000000..61bc382b --- /dev/null +++ b/rslp/helios/scripts/make_multidataset_config.py @@ -0,0 +1,192 @@ +"""Create multi-dataset configuration files. + +This script merges multiple dataset configurations into a single multi-dataset +configuration file for training models on multiple tasks simultaneously. +""" + +import argparse +import json +import os +import tempfile +from copy import deepcopy +from ntpath import basename +from typing import Any + +import yaml + + +def apply_template(config_str: str, cfg: dict[str, Any]) -> str: + """Apply template substitutions to a configuration string. + + Args: + config_str: The configuration string with template placeholders. + cfg: Dictionary containing values to substitute. + + Returns: + Configuration string with placeholders replaced. 
+ """ + config_str = config_str.replace("{CHECKPOINT_PATH}", cfg["helios_checkpoint_path"]) + config_str = config_str.replace("{PATCH_SIZE}", str(cfg["patch_size"])) + config_str = config_str.replace("{256/PATCH_SIZE}", str(256 // cfg["patch_size"])) + config_str = config_str.replace("{128/PATCH_SIZE}", str(128 // cfg["patch_size"])) + config_str = config_str.replace( + "{ENCODER_EMBEDDING_SIZE}", str(cfg["encoder_embedding_size"]) + ) + return config_str + + +def deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: + """Deep merge two dictionaries, handling list extensions with '+' suffix. + + Args: + base: Base dictionary to merge into. + override: Dictionary with values to merge. + + Returns: + Merged dictionary. + """ + for k, v in override.items(): + v_copy = v.copy() if hasattr(v, "copy") else v + if k.endswith("+"): + k = k[:-1] + if k not in base: + base[k] = [] + base[k].extend(v_copy) + else: + if isinstance(v, dict): + base[k] = deep_merge(base.get(k, {}), v_copy) + else: + base[k] = v_copy + return base + + +def merge_configs(cfg_list: list[str], maker_cfg: dict[str, Any]) -> str: + """Merge multiple configuration files into a single YAML string. + + Args: + cfg_list: List of configuration file paths to merge. + maker_cfg: Configuration dictionary for template substitution. + + Returns: + YAML string containing merged configuration. 
+ """ + dicts = [] + for cfg in cfg_list: + with open(cfg) as f: + cfg_str = apply_template(f.read(), maker_cfg) + dicts.append(yaml.safe_load(cfg_str)) + merged_dict = dicts[0].copy() + for d in dicts[1:]: + merged_dict = deep_merge(merged_dict, d) + return yaml.dump(merged_dict) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--cfg", type=str, required=True, help="Path to multi-dataset maker config" + ) + args = parser.parse_args() + + with open(args.cfg) as f: + maker_cfg: dict[str, Any] = yaml.safe_load(f) + print(json.dumps(maker_cfg, indent=4)) + print() + print("=" * 80) + print() + + if maker_cfg["output_path"] is None: + s = "" + for task_cfg in maker_cfg["dataset_cfgs"]: + if isinstance(task_cfg, list): + task_cfg = task_cfg[0] + basename = os.path.basename(task_cfg).replace(".yaml", "") + s += f"{os.path.basename(os.path.dirname(task_cfg))}__{basename}__" + maker_cfg["output_path"] = maker_cfg["base_cfg"].replace( + ".yaml", f"__{s[:-2]}.yaml" + ) + + to_tmp = {} + dataset_cfgs_list = maker_cfg["dataset_cfgs"] + for i, cfg in enumerate([maker_cfg["base_cfg"]] + dataset_cfgs_list): + if isinstance(cfg, list): + cfg_key = "__".join(cfg) + to_tmp[cfg_key] = merge_configs(cfg, maker_cfg) + dataset_cfgs_list[i - 1] = cfg_key # type: ignore[index] + else: + with open(cfg) as f: + to_tmp[cfg] = apply_template(f.read(), maker_cfg) + + with tempfile.TemporaryDirectory() as tmpdir: + tmp_dataset_cfgs = { + cfg: os.path.join(tmpdir, f"{os.path.basename(cfg)}") for cfg in to_tmp + } + tmp_task_buffers = [open(fp, "w+") for fp in tmp_dataset_cfgs.values()] + try: + for cfg in to_tmp: + with open(tmp_dataset_cfgs[cfg], "w+") as f: + f.write(to_tmp[cfg]) + f.flush() + + with open(tmp_dataset_cfgs[maker_cfg["base_cfg"]]) as f: + base_cfg = yaml.safe_load(f) + + data_modules = {} + decoders = {} + task = { + "class_path": "rslearn.train.tasks.multi_task.MultiTask", + "init_args": {}, + } + for task_cfg in 
tmp_dataset_cfgs.values(): + if task_cfg == tmp_dataset_cfgs[maker_cfg["base_cfg"]]: + continue + with open(task_cfg) as f: + task_cfg = yaml.safe_load(f) + subtasks = list( + task_cfg["model"]["init_args"]["model"]["init_args"][ + "decoders" + ].keys() + ) + assert len(subtasks) == 1, "Only one subtask per task is supported" + + task_name = subtasks[0] + data_modules[task_name] = task_cfg["data"] + decoders.update( + task_cfg["model"]["init_args"]["model"]["init_args"]["decoders"] + ) + + if maker_cfg.get("max_train_patches") is not None: + task_cfg["data"]["init_args"]["train_config"]["num_patches"] = ( + maker_cfg["max_train_patches"] + ) + if maker_cfg.get("max_val_patches") is not None: + task_cfg["data"]["init_args"]["val_config"]["num_patches"] = ( + maker_cfg["max_val_patches"] + ) + if maker_cfg.get("batch_size") is not None: + task_cfg["data"]["init_args"]["batch_size"] = maker_cfg[ + "batch_size" + ] + + for k, v in task_cfg["data"]["init_args"]["task"][ + "init_args" + ].items(): + try: + task["init_args"][k].update(v.copy()) # type: ignore + except KeyError: + task["init_args"][k] = v.copy() # type: ignore + + if maker_cfg.get("num_workers") is not None: + base_cfg["data"]["init_args"]["num_workers"] = maker_cfg["num_workers"] + base_cfg["data"]["init_args"]["data_modules"] = data_modules + base_cfg["model"]["init_args"]["model"]["init_args"]["decoders"] = decoders + base_cfg["model"]["init_args"]["task"] = deepcopy(task) + + with open(maker_cfg["output_path"], "w") as f: # type: ignore + yaml.dump(base_cfg, f) + + print(json.dumps(base_cfg, indent=4)) + + finally: + for f in tmp_task_buffers: + f.close() diff --git a/rslp/launcher_lib.py b/rslp/launcher_lib.py index 6572506e..9a6d1cbb 100644 --- a/rslp/launcher_lib.py +++ b/rslp/launcher_lib.py @@ -23,6 +23,7 @@ "lightning_logs", "test_data", "wandb", + "project_data", ] logger = get_logger(__name__) diff --git a/rslp/lightning_cli.py b/rslp/lightning_cli.py index 8a5ca5fb..acf7355f 100644 --- 
a/rslp/lightning_cli.py +++ b/rslp/lightning_cli.py @@ -9,14 +9,13 @@ import fsspec import jsonargparse +import lightning as L import wandb from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import Callback from lightning.pytorch.cli import SaveConfigCallback from lightning.pytorch.utilities import rank_zero_only from rslearn.main import RslearnLightningCLI -from rslearn.train.data_module import RslearnDataModule -from rslearn.train.lightning_module import RslearnLightningModule from upath import UPath import rslp.utils.fs # noqa: F401 (imported but unused) @@ -200,6 +199,12 @@ def add_arguments_to_parser(self, parser: jsonargparse.ArgumentParser) -> None: help="Disable W&B logging for fit", default=False, ) + parser.add_argument( + "--profiler", + type=str, + help="Profiler to use for training. Can be 'simple' or 'advanced'", + default=None, + ) def _get_checkpoint_path( self, checkpoint_dir: UPath, load_best: bool = False, autoresume: bool = False @@ -331,6 +336,26 @@ def before_instantiate_classes(self) -> None: } ) + # Configure profiler if specified + if c.profiler: + max_steps = 100 + c.trainer.profiler = c.profiler + c.trainer.max_steps = max_steps + logger.info(f"Using profiler: {c.profiler}") + logger.info(f"Setting max_steps to {max_steps}") + + # If we are using multi dataset, we have a custom batch sampler + # and so don't need lightning to wrap it for us + if c.data.class_path == "rslearn.train.data_module.MultiDatasetDataModule": + # Usually, it's fine to leave this flag on (lightning will detect + # a distributed sampler and leave it alone), but since we have a custom + # one not subclassing DistributedSampler, we need to turn it off manually + logger.info("Using custom distributed sampler") + logger.info( + "Warnings about calling compute on an empty set of metrics may appear" + ) + c.trainer.use_distributed_sampler = False + if subcommand == "fit" and not c.no_log: # Set the checkpoint directory to canonical GCS 
location. checkpoint_callback = None @@ -403,9 +428,11 @@ def custom_model_handler() -> None: It also sets the save_config_callback. """ + # Decreased strictness of type checking for model and datamodule classes + # to allow for multiple dataset training tasks CustomLightningCLI( - model_class=RslearnLightningModule, - datamodule_class=RslearnDataModule, + model_class=L.LightningModule, + datamodule_class=L.LightningDataModule, args=sys.argv[2:], subclass_mode_model=True, subclass_mode_data=True,