Skip to content

Commit d167af2

Browse files
authored
Merge pull request #46 from chanzuckerberg/tomo-fix
fix: Make sure AMG Parameters are properly read for Tomography Workflow
2 parents b3bdd7a + 92fa280 commit d167af2

File tree

5 files changed

+11
-10
lines changed

5 files changed

+11
-10
lines changed

docs/tutorials/inference.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,11 @@ saber segment tomograms \
7474
--config copick_config.json \
7575
--model-config results/model_config.yaml \
7676
--model-weights results/best_model.pth \
77-
--segmentation-name organelles \
77+
--seg-name organelles \
7878
--target-class 2
7979
```
8080

81-
When no `--run-ids` are provided, SABER will segment the entire project and save the results under provided `--segmentation-name` and `--segmentation-session-id` flags (by default, the `user-id` will always be saber. )
81+
When no `--run-ids` are provided, SABER will segment the entire project and save the results under provided `--seg-name` and `--seg-session-id` flags (by default, the `user-id` will always be saber. )
8282

8383
**Parameters**:
8484

@@ -88,8 +88,8 @@ When no `--run-ids` are provided, SABER will segment the entire project and save
8888
| `--voxel-size` | Resolution of tomograms to process | `10` |
8989
| `--tomo-alg` | Reconstruction algorithm to query | `denoised` |
9090
| `--slab-thickness` | Thickness of slab for initial segmentation | `10` |
91-
| `--segmentation-name` | Name of segmentation session | `organelles` |
92-
| `--segmentation-session-id` | Session ID for segmentation mask | `1` |
91+
| `--seg-name` | Name of segmentated object | `organelles` |
92+
| `--seg-session-id` | Session ID for segmentation mask | `1` |
9393
| `--model-config` | Path to classifier model config | `results/model_config.yaml` |
9494
| `--model-weights` | Path to trained classifier weights | `results/best_model.pth` |
9595
| `--target-class` | Target class (-1 for semantic, N>0 for instance) | `2` |

saber/classifier/preprocess/tomogram_training_prep.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def extract_sam2_candidates(
4646

4747
# Get the Global Zarr Writer
4848
zwriter = zarr_writer.get_zarr_writer(output)
49-
zwriter.set_dict_attr('amg', segmenter.amg_params)
49+
zwriter.set_dict_attr('amg', segmenter.cfg)
5050

5151
# Get Tomogram
5252
vol = readers.tomogram(run, voxel_size, tomogram_algorithm)
@@ -55,6 +55,7 @@ def extract_sam2_candidates(
5555
return
5656

5757
# Hard coded conversion from Angstroms to nanometers
58+
# Copick tomograms are typically stored in Angstroms
5859
voxel_size /= 10
5960

6061
# Process Multiple Slabs or Single Slab at the Center of the Volume

saber/entry_points/inference_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def segment_micrograph_core(
102102

103103
# Get the Global Zarr Writer
104104
zwriter = zarr_writer.get_zarr_writer(output)
105-
zwriter.set_dict_attr('amg', segmenter.amg_params)
105+
zwriter.set_dict_attr('amg', segmenter.cfg)
106106

107107
# Ensure we're on the correct GPU
108108
torch.cuda.set_device(gpu_id)

saber/entry_points/run_micrograph_segment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def micrographs(
5151
if target_resolution is not None and scale_factor is not None:
5252
raise ValueError("Please provide either target_resolution OR scale_factor input, not both.")
5353

54-
# Get All Files in the Directory
54+
# Get All Files in the Directoryclear
5555
files = glob.glob(input)
5656
if len(files) == 0:
5757
raise ValueError(f"No files found in {input}")

saber/segmenters/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self,
6060

6161
# Use Default AMG Config
6262
self.cfg = amg.get_default()
63-
self.cfg['cfg'] = sam2_cfg
63+
self.cfg['sam2_cfg'] = sam2_cfg
6464

6565
# Build SAM2 Automatic Mask Generator
6666
self.mask_generator = amg.build_amg(
@@ -238,9 +238,9 @@ def __init__(self,
238238
super().__init__(sam2_cfg, deviceID, classifier, target_class, min_mask_area, min_rel_box_size)
239239

240240
# Build Tomogram Predictor (VOS Optimized)
241-
(cfg, checkpoint) = pretrained_weights.get_sam2_checkpoint(sam2_cfg)
241+
(cfg, checkpoint) = pretrained_weights.get_sam2_checkpoint(self.cfg['sam2_cfg'])
242242
self.video_predictor = tomogram_predictor.TomogramSAM2Adapter(cfg, checkpoint, self.device)
243-
243+
244244
# Initialize Inference State
245245
self.inference_state = None
246246

0 commit comments

Comments
 (0)