Commit 71d0c23

feat(cellguide): improve computational marker gene algorithm (#6003)
1 parent 845d776 commit 71d0c23

File tree: 11 files changed (+497, -347 lines)

backend/cellguide/pipeline/computational_marker_genes/__init__.py

Lines changed: 29 additions & 6 deletions
@@ -61,6 +61,29 @@ def get_computational_marker_genes(*, snapshot: WmgSnapshot, ontology_tree: Onto
         else:
             marker_genes[key] = marker_genes_per_tissue[key]
 
+    # convert all groupby_dims IDs to labels as required by CellGuide
+    organism_id_to_name = {k: v for d in snapshot.primary_filter_dimensions["organism_terms"] for k, v in d.items()}
+    tissue_id_to_name = {
+        k: v
+        for organism in snapshot.primary_filter_dimensions["tissue_terms"]
+        for i in snapshot.primary_filter_dimensions["tissue_terms"][organism]
+        for k, v in i.items()
+    }
+    for _, marker_gene_stats_list in marker_genes.items():
+        for marker_gene_stats in marker_gene_stats_list:
+            groupby_dims = marker_gene_stats.groupby_dims
+            groupby_terms = list(groupby_dims.keys())
+            groupby_term_labels = [term.rsplit("_", 1)[0] + "_label" for term in groupby_terms]
+            groupby_dims_new = dict(zip(groupby_term_labels, (groupby_dims[term] for term in groupby_terms)))
+
+            for key in groupby_dims_new:
+                if key == "tissue_ontology_term_label":
+                    groupby_dims_new[key] = tissue_id_to_name.get(groupby_dims_new[key], groupby_dims_new[key])
+                elif key == "organism_ontology_term_label":
+                    groupby_dims_new[key] = organism_id_to_name.get(groupby_dims_new[key], groupby_dims_new[key])
+
+            marker_gene_stats.groupby_dims = groupby_dims_new
+
     reformatted_marker_genes = {}
     for cell_type_id, marker_gene_stats_list in marker_genes.items():
         for marker_gene_stats in marker_gene_stats_list:

@@ -87,11 +110,11 @@ def get_computational_marker_genes(*, snapshot: WmgSnapshot, ontology_tree: Onto
             )
             reformatted_marker_genes[symbol][organism][tissue].append(data)
 
-    # assert that cell types do not appear multiple times in each gene, tissue, organism
-    for symbol in reformatted_marker_genes:
-        for organism in reformatted_marker_genes[symbol]:
-            for tissue in reformatted_marker_genes[symbol][organism]:
-                cell_type_ids = [i["cell_type_id"] for i in reformatted_marker_genes[symbol][organism][tissue]]
-                assert len(cell_type_ids) == len(list(set(cell_type_ids)))
+    # # assert that cell types do not appear multiple times in each gene, tissue, organism
+    # for symbol in reformatted_marker_genes:
+    #     for organism in reformatted_marker_genes[symbol]:
+    #         for tissue in reformatted_marker_genes[symbol][organism]:
+    #             cell_type_ids = [i["cell_type_id"] for i in reformatted_marker_genes[symbol][organism][tissue]]
+    #             assert len(cell_type_ids) == len(list(set(cell_type_ids)))
 
     return marker_genes, reformatted_marker_genes
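The added block above renames each *_ontology_term_id key in groupby_dims to *_ontology_term_label and replaces the ontology IDs with human-readable names, falling back to the raw ID when no mapping exists. A minimal standalone sketch of that transformation, using hypothetical IDs and labels (in the pipeline the real mappings come from snapshot.primary_filter_dimensions):

# Hypothetical mappings standing in for snapshot.primary_filter_dimensions
tissue_id_to_name = {"UBERON:0002048": "lung"}
organism_id_to_name = {"NCBITaxon:9606": "Homo sapiens"}

groupby_dims = {
    "tissue_ontology_term_id": "UBERON:0002048",
    "organism_ontology_term_id": "NCBITaxon:9606",
}

# rename "*_ontology_term_id" keys to "*_ontology_term_label" ...
groupby_terms = list(groupby_dims.keys())
groupby_term_labels = [term.rsplit("_", 1)[0] + "_label" for term in groupby_terms]
groupby_dims_new = dict(zip(groupby_term_labels, (groupby_dims[term] for term in groupby_terms)))

# ... then swap IDs for labels, keeping the ID if it is not in the mapping
for key in groupby_dims_new:
    if key == "tissue_ontology_term_label":
        groupby_dims_new[key] = tissue_id_to_name.get(groupby_dims_new[key], groupby_dims_new[key])
    elif key == "organism_ontology_term_label":
        groupby_dims_new[key] = organism_id_to_name.get(groupby_dims_new[key], groupby_dims_new[key])

print(groupby_dims_new)
# {'tissue_ontology_term_label': 'lung', 'organism_ontology_term_label': 'Homo sapiens'}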

backend/cellguide/pipeline/computational_marker_genes/computational_markers.py

Lines changed: 221 additions & 139 deletions
Large diffs are not rendered by default.

backend/cellguide/pipeline/computational_marker_genes/types.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@ class ComputationalMarkerGenes:
     me: float
     pc: float
     marker_score: float
+    specificity: float
+    gene_ontology_term_id: str
     symbol: str
     name: str
     groupby_dims: dict[str, str]
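For illustration only, a self-contained sketch of a populated record carrying the two new fields, specificity and gene_ontology_term_id. The @dataclass decorator and any fields outside this hunk are assumptions, and all values are hypothetical:

from dataclasses import dataclass

@dataclass
class ComputationalMarkerGenes:
    me: float
    pc: float
    marker_score: float
    specificity: float
    gene_ontology_term_id: str
    symbol: str
    name: str
    groupby_dims: dict[str, str]

# hypothetical, illustrative values
record = ComputationalMarkerGenes(
    me=2.1,
    pc=0.83,
    marker_score=1.4,
    specificity=0.97,
    gene_ontology_term_id="ENSG00000081237",
    symbol="PTPRC",
    name="protein tyrosine phosphatase receptor type C",
    groupby_dims={"tissue_ontology_term_label": "lung", "organism_ontology_term_label": "Homo sapiens"},
)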

backend/cellguide/pipeline/computational_marker_genes/utils.py

Lines changed: 82 additions & 120 deletions
@@ -2,72 +2,34 @@
 from typing import Tuple
 
 import numpy as np
-from numba import njit
-from scipy import stats
+from numba import njit, prange
 
 from backend.common.constants import DEPLOYMENT_STAGE_TO_API_URL
-from backend.common.utils.rollup import are_cell_types_colinear
 from backend.wmg.data.utils import setup_retry_session
 
 
-@njit
-def nanpercentile_2d(arr: np.ndarray, percentile: float, axis: int) -> np.ndarray:
-    """
-    Calculate the specified percentile of a 2D array along an axis, ignoring NaN values.
-
-    Arguments
-    ---------
-    arr - 2D array to calculate percentile of
-    percentile - percentile to calculate, as a number between 0 and 100
-    axis - axis along which to calculate percentile
-
-    Returns
-    -------
-    The specified percentile of the 2D array along the specified axis.
-    """
-    if axis == 0:
-        result = np.empty(arr.shape[1])
-        for i in range(arr.shape[1]):
-            arr_column = arr[:, i]
-            result[i] = nanpercentile(arr_column, percentile)
-        return result
-    else:
-        result = np.empty(arr.shape[0])
-        for i in range(arr.shape[0]):
-            arr_row = arr[i, :]
-            result[i] = nanpercentile(arr_row, percentile)
-        return result
-
-
-@njit
-def nanpercentile(arr: np.ndarray, percentile: float):
-    """
-    Calculate the specified percentile of an array, ignoring NaN values.
-
-    Arguments
-    ---------
-    arr - array to calculate percentile of
-    percentile - percentile to calculate, as a number between 0 and 100
-
-    Returns
-    -------
-    The specified percentile of the array.
-    """
-
-    arr_without_nan = arr[np.logical_not(np.isnan(arr))]
-    length = len(arr_without_nan)
+@njit(parallel=True)
+def calculate_specificity_excluding_nans(treatment, control):
+    treatment = treatment.flatten()
 
-    if length == 0:
-        return np.nan
+    specificities = np.zeros(treatment.size)
+    for i in prange(treatment.size):
+        if np.isnan(treatment[i]):
+            continue
+        col = control[:, i]
+        col = col[~np.isnan(col)]
+        if col.size == 0:
+            specificities[i] = 1
+        else:
+            specificities[i] = (treatment[i] > col).mean()
+    return specificities
 
-    return np.percentile(arr_without_nan, percentile)
 
-
-def run_ttest(
+def calculate_cohens_d(
     *, sum1: np.ndarray, sumsq1: np.ndarray, n1: np.ndarray, sum2: np.ndarray, sumsq2: np.ndarray, n2: np.ndarray
 ) -> Tuple[np.ndarray, np.ndarray]:
     """
-    Run a t-test on two sets of data, element-wise.
+    Calculates Cohen's d for two sets of data.
     Arrays "1" and "2" have to be broadcastable into each other.
 
     Arguments
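The new calculate_specificity_excluding_nans helper scores each gene by the fraction of non-NaN control values that the treatment (target cell type) value exceeds; NaN treatment entries keep a specificity of 0, and genes with no usable control values default to 1. A plain-NumPy restatement with a tiny hypothetical input (not from the diff) shows the idea without requiring numba:

import numpy as np

def specificity_reference(treatment: np.ndarray, control: np.ndarray) -> np.ndarray:
    """Plain-NumPy restatement of calculate_specificity_excluding_nans, for illustration."""
    treatment = treatment.flatten()
    specificities = np.zeros(treatment.size)
    for i in range(treatment.size):
        if np.isnan(treatment[i]):
            continue
        col = control[:, i]
        col = col[~np.isnan(col)]
        # no usable control values -> maximally specific by convention
        specificities[i] = 1.0 if col.size == 0 else (treatment[i] > col).mean()
    return specificities

# hypothetical input: one target cell type (columns are genes), rows of `control` are other cell types
treatment = np.array([3.0, 0.5, np.nan])
control = np.array([
    [1.0, 2.0, 0.3],
    [2.5, np.nan, 0.1],
    [np.nan, 4.0, 0.2],
])
print(specificity_reference(treatment, control))  # [1. 0. 0.]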
@@ -96,73 +58,9 @@ def run_ttest(
     var1[var1 < 0] = 0
     var2 = meansq2 - mean2**2
     var2[var2 < 0] = 0
-
-    var1_n = var1 / n1
-    var2_n = var2 / n2
-    sum_var_n = var1_n + var2_n
-    dof = sum_var_n**2 / (var1_n**2 / (n1 - 1) + var2_n**2 / (n2 - 1))
-    tscores = (mean1 - mean2) / np.sqrt(sum_var_n)
     effects = (mean1 - mean2) / np.sqrt(((n1 - 1) * var1 + (n2 - 1) * var2) / (n1 + n2 - 1))
 
-    pvals = stats.t.sf(tscores, dof)
-    return pvals, effects
-
-
-def post_process_stats(
-    *,
-    cell_type_target: str,
-    cell_types_context: np.ndarray,
-    genes: np.ndarray,
-    pvals: np.ndarray,
-    effects: np.ndarray,
-    percentile: float = 0.05,
-) -> dict[str, dict[str, float]]:
-    """
-    Post-process the statistical results to handle colinearity of cell types in the ontology and calculate percentiles.
-
-    Arguments
-    ---------
-    cell_type_target - The target cell type
-    cell_types_context - The context cell types
-    genes - The genes involved in the analysis
-    pvals - The p-values from the statistical test
-    effects - The effect sizes from the statistical test
-    percentile - The percentile to use for thresholding (default is 0.05)
-
-    Returns
-    -------
-    A dictionary mapping marker genes to their statistics.
-    """
-
-    # parent nodes msut not be compared to their children because they share expressions,
-    # since the expressions are rolled up across descendants
-    is_colinear = np.array([are_cell_types_colinear(cell_type, cell_type_target) for cell_type in cell_types_context])
-    effects[is_colinear] = np.nan
-    pvals[is_colinear] = np.nan
-
-    pvals[:, np.all(np.isnan(pvals), axis=0)] = 1
-    effects[:, np.all(np.isnan(effects), axis=0)] = 0
-
-    # aggregate
-    effects = nanpercentile_2d(effects, percentile * 100, 0)
-
-    effects[effects == 0] = np.nan
-
-    pvals = np.sort(pvals, axis=0)[int(np.round(0.05 * pvals.shape[0]))]
-
-    markers = np.array(genes)[np.argsort(-effects)]
-    p = pvals[np.argsort(-effects)]
-    effects = effects[np.argsort(-effects)]
-
-    statistics = []
-    final_markers = []
-    for i in range(len(p)):
-        pi = p[i]
-        ei = effects[i]
-        if ei is not np.nan and pi is not np.nan:
-            statistics.append({"p_value": pi, "effect_size": ei})
-            final_markers.append(markers[i])
-    return dict(zip(list(final_markers), statistics))
+    return effects
 
 
 def query_gene_info_for_gene_description(gene_id: str) -> str:
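Of the old run_ttest body, only the standardized mean difference survives in calculate_cohens_d. A small self-contained check with hypothetical samples (not from the diff) shows what that retained line computes when the per-group sums, sums of squares and counts are built from raw values:

import numpy as np

# hypothetical raw expression values for a target group (1) and a context group (2)
x1 = np.array([2.0, 3.0, 4.0, 5.0])
x2 = np.array([1.0, 1.5, 2.0])

sum1, sumsq1, n1 = x1.sum(), (x1**2).sum(), x1.size
sum2, sumsq2, n2 = x2.sum(), (x2**2).sum(), x2.size

mean1, meansq1 = sum1 / n1, sumsq1 / n1
mean2, meansq2 = sum2 / n2, sumsq2 / n2
var1 = max(meansq1 - mean1**2, 0)  # clipped at 0, as in the function above
var2 = max(meansq2 - mean2**2, 0)

# the single line retained from run_ttest: a pooled-variance standardized mean difference
effect = (mean1 - mean2) / np.sqrt(((n1 - 1) * var1 + (n2 - 1) * var2) / (n1 + n2 - 1))
print(round(float(effect), 3))  # approximately 2.424 for these made-up numbers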
@@ -193,3 +91,67 @@ def query_gene_info_for_gene_description(gene_id: str) -> str:
         return data["name"]
     else:
         return gene_id
+
+
+@njit(parallel=True)
+def bootstrap_rows_percentiles(
+    X: np.ndarray, random_indices: np.ndarray, num_replicates: int = 1000, num_samples: int = 100, percentile: float = 5
+):
+    """
+    This function bootstraps rows of a given matrix X.
+
+    Arguments
+    ---------
+    X : np.ndarray
+        The input matrix to bootstrap.
+    num_replicates : int, optional
+        The number of bootstrap replicates to generate, by default 1000.
+    num_samples : int, optional
+        The number of samples to draw in each bootstrap replicate, by default 100.
+    percentile : float, optional
+        The percentile of the bootstrapped samples for each replicate, by default 15.
+
+    Returns
+    -------
+    bootstrap_percentile : np.ndarray
+        The percentile of the bootstrapped samples for each replicate.
+    """
+
+    bootstrap_percentile = np.zeros((num_replicates, X.shape[1]), dtype="float")
+    # for each replicate
+    for n_i in prange(num_replicates):
+        bootstrap_percentile[n_i] = sort_matrix_columns(X[random_indices[n_i]], percentile, num_samples)
+
+    return bootstrap_percentile
+
+
+@njit
+def sort_matrix_columns(matrix, percentile, num_samples):
+    """
+    This function sorts the columns of a given matrix and returns the index associated with
+    the specified percentile of the sorted samples for each column. This approximates
+    np.nanpercentile(matrix, percentile, axis=0).
+
+    Arguments
+    ---------
+    matrix : np.ndarray
+        The input matrix to sort.
+    percentile : float
+        The percentile of the sorted samples for each column.
+    num_samples : int
+        The number of samples in each column.
+
+    Returns
+    -------
+    result : np.ndarray
+        The sorted columns of the input matrix.
+    """
+    num_cols = matrix.shape[1]
+    result = np.empty(num_cols)
+    for col in range(num_cols):
+        sorted_col = np.sort(matrix[:, col])
+        num_nans = np.isnan(sorted_col).sum()
+        num_non_nans = num_samples - num_nans
+        sample_index = int(np.round(percentile / 100 * num_non_nans))
+        result[col] = sorted_col[sample_index]
+    return result
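The caller that supplies random_indices lives in computational_markers.py, whose diff is not rendered above, so the driving code below is an assumption: it shows one plausible way to generate bootstrap row indices and compares the result with np.nanpercentile, which the docstring says the helper approximates. The two helpers are restated in plain NumPy (prange replaced by range) so the sketch runs without numba:

import numpy as np

def sort_matrix_columns_ref(matrix, percentile, num_samples):
    # plain-NumPy restatement of the numba helper: per-column percentile of the
    # non-NaN values (NaNs sort to the end and are skipped via the index)
    num_cols = matrix.shape[1]
    result = np.empty(num_cols)
    for col in range(num_cols):
        sorted_col = np.sort(matrix[:, col])
        num_non_nans = num_samples - np.isnan(sorted_col).sum()
        sample_index = int(np.round(percentile / 100 * num_non_nans))
        result[col] = sorted_col[sample_index]
    return result

def bootstrap_rows_percentiles_ref(X, random_indices, num_replicates=1000, num_samples=100, percentile=5):
    # plain-NumPy restatement of the numba helper (prange -> range)
    out = np.zeros((num_replicates, X.shape[1]))
    for n_i in range(num_replicates):
        out[n_i] = sort_matrix_columns_ref(X[random_indices[n_i]], percentile, num_samples)
    return out

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))  # hypothetical data: rows are cells/groups, columns are genes
num_replicates, num_samples = 200, 100
# assumed way of generating the bootstrap row indices; the real caller is not shown in this diff
random_indices = rng.integers(0, X.shape[0], size=(num_replicates, num_samples))

boot = bootstrap_rows_percentiles_ref(X, random_indices, num_replicates, num_samples, percentile=5)
print(boot.mean(axis=0))               # bootstrap estimate of the 5th percentile per column
print(np.nanpercentile(X, 5, axis=0))  # direct estimate it approximates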

backend/cellguide/pipeline/constants.py

Lines changed: 1 addition & 1 deletion
@@ -25,6 +25,6 @@
 # 24 CPUs was chosen to balance memory usage and speed on a c6i.32xlarge EC2 machine
 # In trial runs, the memory usage did not exceed 50% of the available memory, which provides
 # ample buffer.
-CELLGUIDE_PIPELINE_NUM_CPUS = min(os.cpu_count(), os.getenv("CELLGUIDE_PIPELINE_NUM_CPUS", 12))
+CELLGUIDE_PIPELINE_NUM_CPUS = min(os.cpu_count(), os.getenv("CELLGUIDE_PIPELINE_NUM_CPUS", 24))
 
 CELL_GUIDE_DATA_BUCKET_PATH_PREFIX = "s3://cellguide-data-public-"

backend/common/utils/rollup.py

Lines changed: 20 additions & 0 deletions
@@ -81,6 +81,26 @@ def are_cell_types_colinear(cell_type1, cell_type2):
     return len(set(descendants1).intersection(ancestors2)) > 0 or len(set(descendants2).intersection(ancestors1)) > 0
 
 
+def get_overlapping_cell_type_descendants(cell_type1, cell_type2):
+    """
+    Get overlapping cell type descendants
+
+    Arguments
+    ---------
+    cell_type1 : str
+        Cell type 1 (cell type ontology term id)
+    cell_type2 : str
+        Cell type 2 (cell type ontology term id)
+    Returns
+    -------
+    list[str]
+    """
+    descendants1 = descendants(cell_type1)
+    descendants2 = descendants(cell_type2)
+
+    return list(set(descendants1).intersection(descendants2))
+
+
 def rollup_across_cell_type_descendants(
     df, cell_type_col="cell_type_ontology_term_id", parallel=True, ignore_cols=None
 ) -> pd.DataFrame:
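get_overlapping_cell_type_descendants leans on the module's existing descendants() lookup over the cell type ontology. The toy sketch below substitutes a hand-built, hypothetical mini-hierarchy for that lookup purely to illustrate the set-intersection semantics; the CL terms and their listed children are illustrative subsets, not the real ontology:

# Toy stand-in for the module's descendants() lookup (illustrative subsets only)
toy_ontology = {
    "CL:0000542": ["CL:0000542", "CL:0000084", "CL:0000236"],  # lymphocyte -> T cell, B cell
    "CL:0000738": ["CL:0000738", "CL:0000084", "CL:0000576"],  # leukocyte -> T cell, monocyte (truncated)
}

def descendants(cell_type: str) -> list[str]:
    # stand-in for the real ontology-backed descendants() used by rollup.py
    return toy_ontology.get(cell_type, [cell_type])

def get_overlapping_cell_type_descendants(cell_type1, cell_type2):
    descendants1 = descendants(cell_type1)
    descendants2 = descendants(cell_type2)
    return list(set(descendants1).intersection(descendants2))

print(get_overlapping_cell_type_descendants("CL:0000542", "CL:0000738"))
# ['CL:0000084']  (the shared descendant in this toy hierarchy)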
