Skip to content

Commit 7205b37

Browse files
authored
Revert to a default GaussianProcessRansac which always accepts (#179)
the candidate (instead of doing a chi squared acceptance test)
1 parent 06b5b18 commit 7205b37

File tree

3 files changed

+55
-27
lines changed

3 files changed

+55
-27
lines changed

include/albatross/src/core/declarations.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,11 @@ template <typename InlierMetric, typename ConsensusMetric,
133133
typename IndexingFunction>
134134
struct GenericRansacStrategy;
135135

136+
struct AlwaysAcceptCandidateMetric;
137+
136138
template <typename InlierMetric, typename ConsensusMetric,
137-
typename IndexingFunction>
139+
typename IndexingFunction,
140+
typename IsValidCandidateMetric = AlwaysAcceptCandidateMetric>
138141
struct GaussianProcessRansacStrategy;
139142

140143
/*

include/albatross/src/models/ransac_gp.hpp

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,18 @@ inline
4141
};
4242
}
4343

44-
template <typename ModelType, typename FeatureType>
44+
template <typename ModelType, typename FeatureType,
45+
typename IsValidCandidateMetric>
4546
inline typename RansacFunctions<
4647
FitAndIndices<ModelType, FeatureType>>::IsValidCandidate
4748
get_gp_ransac_is_valid_candidate(const RegressionDataset<FeatureType> &dataset,
4849
const FoldIndexer &indexer,
49-
const Eigen::MatrixXd &cov) {
50-
return [&, indexer, cov, dataset](const std::vector<FoldName> &groups) {
51-
auto inds = indices_from_names(indexer, groups);
52-
const auto train_dataset = subset(dataset, inds);
53-
const auto train_cov = symmetric_subset(cov, inds);
50+
const Eigen::MatrixXd &cov,
51+
const IsValidCandidateMetric &metric) {
5452

55-
const JointDistribution prior(Eigen::VectorXd::Zero(train_cov.rows()),
56-
train_cov);
57-
// These thresholds are under the assumption of a perfectly
58-
// representative prior.
59-
const double probability_prior_exceeded =
60-
chi_squared_cdf(prior, train_dataset.targets);
61-
const double skip_every_1000th_candidate = 0.999;
62-
return (probability_prior_exceeded < skip_every_1000th_candidate);
53+
return [&, indexer, cov, dataset](const std::vector<FoldName> &groups) {
54+
const auto inds = indices_from_names(indexer, groups);
55+
return metric(inds, dataset, cov);
6356
};
6457
}
6558

@@ -145,14 +138,43 @@ class ChiSquaredConsensusMetric {
145138
Eigen::MatrixXd cov_;
146139
};
147140

141+
struct ChiSquaredIsValidCandidateMetric {
142+
143+
template <typename FeatureType>
144+
bool operator()(const FoldIndices &inds,
145+
const RegressionDataset<FeatureType> &dataset,
146+
const Eigen::MatrixXd &cov) const {
147+
const auto train_dataset = subset(dataset, inds);
148+
const auto train_cov = symmetric_subset(cov, inds);
149+
150+
const JointDistribution prior(Eigen::VectorXd::Zero(train_cov.rows()),
151+
train_cov);
152+
// These thresholds are under the assumption of a perfectly
153+
// representative prior.
154+
const double probability_prior_exceeded =
155+
chi_squared_cdf(prior, train_dataset.targets);
156+
const double skip_every_1000th_candidate = 0.999;
157+
return (probability_prior_exceeded < skip_every_1000th_candidate);
158+
};
159+
};
160+
161+
struct AlwaysAcceptCandidateMetric {
162+
template <typename FeatureType>
163+
bool operator()(const FoldIndices &inds,
164+
const RegressionDataset<FeatureType> &dataset,
165+
const Eigen::MatrixXd &cov) const {
166+
return true;
167+
}
168+
};
169+
148170
template <typename ModelType, typename FeatureType, typename InlierMetric,
149-
typename ConsensusMetric>
171+
typename ConsensusMetric, typename IsValidCandidateMetric>
150172
inline RansacFunctions<FitAndIndices<ModelType, FeatureType>>
151-
get_gp_ransac_functions(const ModelType &model,
152-
const RegressionDataset<FeatureType> &dataset,
153-
const FoldIndexer &indexer,
154-
const InlierMetric &inlier_metric,
155-
const ConsensusMetric &consensus_metric) {
173+
get_gp_ransac_functions(
174+
const ModelType &model, const RegressionDataset<FeatureType> &dataset,
175+
const FoldIndexer &indexer, const InlierMetric &inlier_metric,
176+
const ConsensusMetric &consensus_metric,
177+
const IsValidCandidateMetric &is_valid_candidate_metric) {
156178

157179
static_assert(is_prediction_metric<InlierMetric>::value,
158180
"InlierMetric must be an PredictionMetric.");
@@ -170,16 +192,16 @@ get_gp_ransac_functions(const ModelType &model,
170192
full_cov);
171193

172194
const auto is_valid_candidate =
173-
get_gp_ransac_is_valid_candidate<ModelType, FeatureType>(dataset, indexer,
174-
full_cov);
195+
get_gp_ransac_is_valid_candidate<ModelType, FeatureType>(
196+
dataset, indexer, full_cov, is_valid_candidate_metric);
175197

176198
return RansacFunctions<FitAndIndices<ModelType, FeatureType>>(
177199
fitter, inlier_metric_from_group, consensus_metric_from_group,
178200
is_valid_candidate);
179201
};
180202

181203
template <typename InlierMetric, typename ConsensusMetric,
182-
typename IndexingFunction>
204+
typename IndexingFunction, typename IsValidCandidateMetric>
183205
struct GaussianProcessRansacStrategy {
184206

185207
GaussianProcessRansacStrategy() = default;
@@ -188,15 +210,15 @@ struct GaussianProcessRansacStrategy {
188210
const ConsensusMetric &consensus_metric,
189211
const IndexingFunction &indexing_function)
190212
: inlier_metric_(inlier_metric), consensus_metric_(consensus_metric),
191-
indexing_function_(indexing_function){};
213+
indexing_function_(indexing_function), is_valid_candidate_(){};
192214

193215
template <typename ModelType, typename FeatureType>
194216
RansacFunctions<FitAndIndices<ModelType, FeatureType>>
195217
operator()(const ModelType &model,
196218
const RegressionDataset<FeatureType> &dataset) const {
197219
const auto indexer = get_indexer(dataset);
198220
return get_gp_ransac_functions(model, dataset, indexer, inlier_metric_,
199-
consensus_metric_);
221+
consensus_metric_, is_valid_candidate_);
200222
}
201223

202224
template <typename FeatureType>
@@ -208,6 +230,7 @@ struct GaussianProcessRansacStrategy {
208230
InlierMetric inlier_metric_;
209231
ConsensusMetric consensus_metric_;
210232
IndexingFunction indexing_function_;
233+
IsValidCandidateMetric is_valid_candidate_;
211234
};
212235

213236
using DefaultGPRansacStrategy =

tests/test_models.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,10 @@ struct AdaptedRansacStrategy : public GaussianProcessRansacStrategy<
164164
adapted::convert_features(dataset.features), dataset.targets);
165165
const auto indexer = get_indexer(converted);
166166
const FeatureCountConsensusMetric consensus_metric;
167+
const AlwaysAcceptCandidateMetric always_accept;
167168
return get_gp_ransac_functions(model, converted, indexer,
168-
this->inlier_metric_, consensus_metric);
169+
this->inlier_metric_, consensus_metric,
170+
always_accept);
169171
}
170172
};
171173

0 commit comments

Comments
 (0)