Description
Describe the bug:
The _predict_proba method in _SksurvAdapter (used by estimators like CoxPHSkSurv) contains three mathematical errors that result in improper probability distributions:
1) Initial Mass Loss: The adapter fails to capture the probability of events occurring between t=0 and the first observed event time t_1. It calculates masses using np.diff but does not prepend the initial survival value S(0) = 1.0, so the first drop in survival is discarded.
2) Tail Mass Loss: The adapter ignores the remaining survival probability at the end of the observed timeline. If S(t_{last}) > 0, that probability mass is discarded, resulting in an Empirical distribution whose weights sum to less than 1.0.
3) Temporal Alignment Error: When calculating the mass between t_i and t_{i+1}, the adapter assigns the weight to the timestamp t_i. In survival analysis, the drop in survival probability at t_{i+1} represents events occurring at that time; assigning it to t_i shifts the distribution backwards, making the model predict that events happen earlier than they do.
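As a sketch of what a proper discretisation could look like, the three points above can be written as one telescoping sum. The index n (so that t_n = t_{last}), the symbol p_i, and the choice to treat the tail as a separate mass p_tail are notation introduced here for illustration, not taken from the skpro code.

```latex
\begin{aligned}
t_0 &= 0, \qquad S(t_0) = 1, \\
p_i &= S(t_{i-1}) - S(t_i) \quad \text{(mass placed at } t_i\text{)}, \qquad i = 1, \dots, n, \\
p_{\text{tail}} &= S(t_n), \\
\sum_{i=1}^{n} p_i + p_{\text{tail}} &= \bigl(1 - S(t_n)\bigr) + S(t_n) = 1 .
\end{aligned}
```

With the mock values from the reproduction script, S(10)=0.8, S(20)=0.5, S(30)=0.5, this gives p_1 = 0.2 at t=10, p_2 = 0.3 at t=20, p_3 = 0.0 at t=30, and p_tail = 0.5. Taking only the interior differences yields 0.3 + 0.0 = 0.3, which is the shortfall described above.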
To Reproduce:
Install dependencies:
!pip install skpro
Run the following reproduction script:
import numpy as np
import pandas as pd
from unittest.mock import MagicMock

from skpro.survival.adapters.sksurv import _SksurvAdapter


class MockSksurvAdapter(_SksurvAdapter):
    def _get_sksurv_class(self):
        return MagicMock()

    def get_params(self, deep=True):
        return {}


# Define mock survival results: S(10)=0.8, S(20)=0.5, S(30)=0.5
mock_surv = np.array([[0.8, 0.5, 0.5]])
mock_times = np.array([10.0, 20.0, 30.0])
X = pd.DataFrame({"feature1": [1.0]})

adapter = MockSksurvAdapter()
adapter._estimator = MagicMock()
adapter._estimator.predict_survival_function = MagicMock(return_value=mock_surv)
adapter._estimator.unique_times_ = mock_times
adapter._y_cols = ["time"]

dist = adapter._predict_proba(X)

print(f"Total mass: {dist.weights.sum()}")
print(f"Times in Empirical dist: {dist.spl.values.flatten()}")
print(f"Weights in Empirical dist: {dist.weights.values}")

if dist.weights.sum() < 1.0:
    print("\nBUG CONFIRMED: Total mass < 1.0")
Expected behavior:
- The total weight of the predicted Empirical distribution should always be 1.0.
- Probability mass should be correctly aligned with the timestamps of the survival drops.
- The mass at $t=0$ or the first event time $t_1$ should be preserved.
Actual Output:
Environment:
OS: Windows
Python: 3.14.2
skpro: 2.11.0
Additional context:
Affected file: skpro/survival/adapters/sksurv.py
In _predict_proba, the line weights = -np.diff(sksurv_survf, axis=1).flatten() only looks at internal differences and ignores the boundaries.
Proposed Fix:
The adapter should use logic similar to _surv_diff in _common.py which prepends 1.0 and appends 0.0 (or handles the remainder at infinity) to ensure the total mass sums to 1.0 and is correctly aligned with the event times.
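A minimal sketch of the boundary handling described above, for a single survival curve. The helper name survival_to_masses and the choice to place the tail mass at np.inf are assumptions made for illustration; this is not a copy of _surv_diff, and the actual fix may need a different tail convention depending on what the Empirical distribution accepts as support.

```python
import numpy as np


def survival_to_masses(surv, times):
    """Convert one survival curve into point masses (hypothetical helper).

    surv  : 1D array of S(t_1), ..., S(t_n), assumed non-increasing in [0, 1]
    times : 1D array of t_1, ..., t_n, assumed strictly increasing
    Returns (support, weights), where weights sum to 1.
    """
    surv = np.asarray(surv, dtype=float)
    times = np.asarray(times, dtype=float)
    if surv.shape != times.shape or surv.ndim != 1:
        raise ValueError("surv and times must be 1D arrays of equal length")

    # Prepend S(0) = 1 and append S(inf) = 0 so no boundary mass is lost.
    extended = np.concatenate(([1.0], surv, [0.0]))

    # Drop in survival ending at each support point: S(t_{i-1}) - S(t_i).
    weights = extended[:-1] - extended[1:]

    # Each drop is assigned to the time at which it occurs; the remaining
    # mass S(t_n) is placed at infinity (an assumption of this sketch).
    support = np.concatenate((times, [np.inf]))
    return support, weights


# Mock values from the reproduction script in this issue.
support, weights = survival_to_masses([0.8, 0.5, 0.5], [10.0, 20.0, 30.0])
print(support)
print(weights)
print(weights.sum())
```

Zero-weight points such as t=30 here could be dropped before building the distribution. Whether the tail should sit at infinity, at the last observed time, or be handled some other way is a design decision for the maintainers.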
If verified, I would like to open a PR on these fixes.