001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.math.model.fit;
031
032import java.util.ArrayList;
033import java.util.List;
034
035import org.openimaj.math.model.EstimatableModel;
036import org.openimaj.math.model.fit.residuals.ResidualCalculator;
037import org.openimaj.math.util.distance.DistanceCheck;
038import org.openimaj.math.util.distance.ThresholdDistanceCheck;
039import org.openimaj.util.CollectionSampler;
040import org.openimaj.util.UniformSampler;
041import org.openimaj.util.pair.IndependentPair;
042
043/**
044 * The RANSAC Algorithm (RANdom SAmple Consensus)
045 * <p>
046 * For fitting noisy data consisting of inliers and outliers to a model.
047 * </p>
048 * <p>
049 * Assume: M data items required to estimate parameter x N data items in total
050 * </p>
 * <p>
 * 1.) select M data items at random <br>
 * 2.) estimate parameter x <br>
 * 3.) find how many of the N data items fit x (i.e. have an error less than
 * a threshold, or pass some check) within tolerance tol; call this K <br>
 * 4.) if K is large enough (bigger than numItems) accept x and exit with
 * success <br>
 * 5.) repeat 1..4 nIter times <br>
 * 6.) fail - no good x fit of data
 * </p>
061 * <p>
062 * In this implementation, the conditions that control the iterations are
063 * configurable. In addition, the best matching model is always stored, even if
064 * the fitData() method returns false.
065 *
066 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
067 *
068 * @param <I>
069 *            type of independent data
070 * @param <D>
071 *            type of dependent data
072 * @param <M>
073 *            concrete type of model learned
074 */
075public class RANSAC<I, D, M extends EstimatableModel<I, D>> implements RobustModelFitting<I, D, M> {
076        /**
077         * Interface for classes that can control RANSAC iterations
078         */
079        public static interface StoppingCondition {
080                /**
081                 * Initialise the stopping condition if necessary. Return false if the
082                 * initialisation cannot be performed and RANSAC should fail
083                 *
084                 * @param data
085                 *            The data being fitted
086                 * @param model
087                 *            The model to fit
088                 * @return true if initialisation is successful, false otherwise.
089                 */
090                public abstract boolean init(final List<?> data, EstimatableModel<?, ?> model);
091
092                /**
093                 * Should we stop iterating and return the model?
094                 *
095                 * @param numInliers
096                 *            number of inliers in this iteration
097                 * @return true if the model is good and iterations should stop
098                 */
099                public abstract boolean shouldStopIterations(int numInliers);
100
101                /**
102                 * Should the model be considered to fit after the final iteration has
103                 * passed?
104                 *
105                 * @param numInliers
106                 *            number of inliers in the final model
107                 * @return true if the model fits, false otherwise
108                 */
109                public abstract boolean finalFitCondition(int numInliers);
110        }
111
112        /**
113         * Stopping condition that tests the number of matches against a threshold.
114         * If the number exceeds the threshold, then the model is considered to fit.
115         */
116        public static class NumberInliersStoppingCondition implements StoppingCondition {
117                int limit;
118
119                /**
120                 * Construct the stopping condition with the given threshold on the
121                 * number of data points which must match for a model to be considered a
122                 * fit.
123                 *
124                 * @param limit
125                 *            the threshold
126                 */
127                public NumberInliersStoppingCondition(int limit) {
128                        this.limit = limit;
129                }
130
131                @Override
132                public boolean init(List<?> data, EstimatableModel<?, ?> model) {
133                        if (limit < model.numItemsToEstimate()) {
134                                limit = model.numItemsToEstimate();
135                        }
136
137                        if (data.size() < limit)
138                                return false;
139                        return true;
140                }
141
142                @Override
143                public boolean shouldStopIterations(int numInliers) {
144                        return numInliers >= limit; // stop if there are more inliers than
145                        // our limit
146                }
147
148                @Override
149                public boolean finalFitCondition(int numInliers) {
150                        return numInliers >= limit;
151                }
152        }
153
154        /**
155         * Stopping condition that tests the number of matches against a percentage
156         * threshold of the whole data. If the number exceeds the threshold, then
157         * the model is considered to fit.
158         */
159        public static class PercentageInliersStoppingCondition extends NumberInliersStoppingCondition {
160                double percentageLimit;
161
162                /**
163                 * Construct the stopping condition with the given percentage threshold
164                 * on the number of data points which must match for a model to be
165                 * considered a fit.
166                 *
167                 * @param percentageLimit
168                 *            the percentage threshold
169                 */
170                public PercentageInliersStoppingCondition(double percentageLimit) {
171                        super(0);
172                        this.percentageLimit = percentageLimit;
173                }
174
175                @Override
176                public boolean init(List<?> data, EstimatableModel<?, ?> model) {
177                        this.limit = (int) Math.rint(percentageLimit * data.size());
178                        return super.init(data, model);
179                }
180        }
181
182        /**
183         * Stopping condition that tests the number of matches against a percentage
184         * threshold of the whole data. If the number exceeds the threshold, then
185         * the model is considered to fit.
186         */
187        public static class ProbabilisticMinInliersStoppingCondition implements StoppingCondition {
188                private static final double DEFAULT_INLIER_IS_BAD_PROBABILITY = 0.1;
189                private static final double DEFAULT_PERCENTAGE_INLIERS = 0.25;
190                private double inlierIsBadProbability;
191                private double desiredErrorProbability;
192                private double percentageInliers;
193
194                private int numItemsToEstimate;
195                private int iteration = 0;
196                private int limit;
197                private int maxInliers = 0;
198                private double currentProb;
199                private int numDataItems;
200
201                /**
202                 * Default constructor.
203                 *
204                 * @param desiredErrorProbability
205                 *            The desired error rate
206                 * @param inlierIsBadProbability
207                 *            The probability an inlier is bad
208                 * @param percentageInliers
209                 *            The percentage of inliers in the data
210                 */
211                public ProbabilisticMinInliersStoppingCondition(double desiredErrorProbability, double inlierIsBadProbability,
212                                double percentageInliers)
213                {
214                        this.desiredErrorProbability = desiredErrorProbability;
215                        this.inlierIsBadProbability = inlierIsBadProbability;
216                        this.percentageInliers = percentageInliers;
217                }
218
219                /**
220                 * Constructor with defaults for bad inlier probability and percentage
221                 * inliers.
222                 *
223                 * @param desiredErrorProbability
224                 *            The desired error rate
225                 */
226                public ProbabilisticMinInliersStoppingCondition(double desiredErrorProbability) {
227                        this(desiredErrorProbability, DEFAULT_INLIER_IS_BAD_PROBABILITY, DEFAULT_PERCENTAGE_INLIERS);
228                }
229
230                @Override
231                public boolean init(List<?> data, EstimatableModel<?, ?> model) {
232                        numItemsToEstimate = model.numItemsToEstimate();
233                        numDataItems = data.size();
234                        this.limit = calculateMinInliers();
235                        this.iteration = 0;
236                        this.currentProb = 1.0;
237                        this.maxInliers = 0;
238
239                        return true;
240                }
241
242                @Override
243                public boolean finalFitCondition(int numInliers) {
244                        return numInliers >= limit;
245                }
246
247                private int calculateMinInliers() {
248                        double pi, sum;
249                        int i, j;
250
251                        for (j = numItemsToEstimate + 1; j <= numDataItems; j++)
252                        {
253                                sum = 0;
254                                for (i = j; i <= numDataItems; i++)
255                                {
256                                        pi = (i - numItemsToEstimate) * Math.log(inlierIsBadProbability)
257                                                        + (numDataItems - i + numItemsToEstimate) * Math.log(1.0 - inlierIsBadProbability) +
258                                                        log_factorial(numDataItems - numItemsToEstimate) - log_factorial(i - numItemsToEstimate)
259                                                        - log_factorial(numDataItems - i);
260                                        /*
261                                         * Last three terms above are equivalent to log( n-m choose
262                                         * i-m )
263                                         */
264                                        sum += Math.exp(pi);
265                                }
266                                if (sum < desiredErrorProbability)
267                                        break;
268                        }
269                        return j;
270                }
271
272                private double log_factorial(int n) {
273                        double f = 0;
274                        int i;
275
276                        for (i = 1; i <= n; i++)
277                                f += Math.log(i);
278
279                        return f;
280                }
281
282                @Override
283                public boolean shouldStopIterations(int numInliers) {
284
285                        if (numInliers > maxInliers) {
286                                maxInliers = numInliers;
287                                percentageInliers = (double) maxInliers / numDataItems;
288
289                                // System.err.format("Updated maxInliers: %d\n", maxInliers);
290                        }
291                        currentProb = Math.pow(1.0 - Math.pow(percentageInliers, numItemsToEstimate), ++iteration);
292                        return currentProb <= this.desiredErrorProbability;
293                }
294        }
295
296        /**
297         * Stopping condition that allows the RANSAC algorithm to run until all the
298         * iterations have been exhausted. The fitData method will return true if
299         * there are at least as many inliers as datapoints required to estimate the
300         * model, and the model will be the one from the iteration that had the most
301         * inliers.
302         */
303        public static class BestFitStoppingCondition implements StoppingCondition {
304                int required;
305
306                @Override
307                public boolean init(List<?> data, EstimatableModel<?, ?> model) {
308                        required = model.numItemsToEstimate();
309                        return true;
310                }
311
312                @Override
313                public boolean shouldStopIterations(int numInliers) {
314                        return false; // just iterate until the end
315                }
316
317                @Override
318                public boolean finalFitCondition(int numInliers) {
319                        return numInliers > required; // accept the best result as a good
320                        // fit if there are enough inliers
321                }
322        }
323
324        protected M model;
325        protected ResidualCalculator<I, D, M> errorModel;
326        protected DistanceCheck dc;
327
328        protected int nIter;
329        protected boolean improveEstimate;
330        protected List<IndependentPair<I, D>> inliers;
331        protected List<IndependentPair<I, D>> outliers;
332        protected List<IndependentPair<I, D>> bestModelInliers;
333        protected List<IndependentPair<I, D>> bestModelOutliers;
334        protected StoppingCondition stoppingCondition;
335        protected List<? extends IndependentPair<I, D>> modelConstructionData;
336        protected CollectionSampler<IndependentPair<I, D>> sampler;
337
338        /**
339         * Create a RANSAC object with uniform random sampling for creating the
340         * subsets
341         *
342         * @param model
343         *            Model object with which to fit data
344         * @param errorModel
345         *            object to compute the error of the model
346         * @param errorThreshold
347         *            the threshold below which error is deemed acceptable for a fit
348         * @param nIterations
349         *            Maximum number of allowed iterations (L)
350         * @param stoppingCondition
351         *            the stopping condition
352         * @param impEst
353         *            True if we want to perform a final fitting of the model with
354         *            all inliers, false otherwise
355         */
356        public RANSAC(M model, ResidualCalculator<I, D, M> errorModel,
357                        double errorThreshold, int nIterations,
358                        StoppingCondition stoppingCondition, boolean impEst)
359        {
360                this(model, errorModel, new ThresholdDistanceCheck(errorThreshold), nIterations, stoppingCondition, impEst);
361        }
362
363        /**
364         * Create a RANSAC object with uniform random sampling for creating the
365         * subsets
366         *
367         * @param model
368         *            Model object with which to fit data
369         * @param errorModel
370         *            object to compute the error of the model
371         * @param dc
372         *            the distance check that tests whether a point with given error
373         *            from the error model should be considered an inlier
374         * @param nIterations
375         *            Maximum number of allowed iterations (L)
376         * @param stoppingCondition
377         *            the stopping condition
378         * @param impEst
379         *            True if we want to perform a final fitting of the model with
380         *            all inliers, false otherwise
381         */
382        public RANSAC(M model, ResidualCalculator<I, D, M> errorModel,
383                        DistanceCheck dc, int nIterations,
384                        StoppingCondition stoppingCondition, boolean impEst)
385        {
386                this(model, errorModel, dc, nIterations, stoppingCondition, impEst, new UniformSampler<IndependentPair<I, D>>());
387        }
388
389        /**
390         * Create a RANSAC object
391         *
392         * @param model
393         *            Model object with which to fit data
394         * @param errorModel
395         *            object to compute the error of the model
396         * @param errorThreshold
397         *            the threshold below which error is deemed acceptable for a fit
398         * @param nIterations
399         *            Maximum number of allowed iterations (L)
400         * @param stoppingCondition
401         *            the stopping condition
402         * @param impEst
403         *            True if we want to perform a final fitting of the model with
404         *            all inliers, false otherwise
405         * @param sampler
406         *            the sampling algorithm for selecting random subsets
407         */
408        public RANSAC(M model, ResidualCalculator<I, D, M> errorModel,
409                        double errorThreshold, int nIterations,
410                        StoppingCondition stoppingCondition, boolean impEst, CollectionSampler<IndependentPair<I, D>> sampler)
411        {
412                this(model, errorModel, new ThresholdDistanceCheck(errorThreshold), nIterations, stoppingCondition, impEst,
413                                sampler);
414        }
415
416        /**
417         * Create a RANSAC object
418         *
419         * @param model
420         *            Model object with which to fit data
421         * @param errorModel
422         *            object to compute the error of the model
423         * @param dc
424         *            the distance check that tests whether a point with given error
425         *            from the error model should be considered an inlier
426         * @param nIterations
427         *            Maximum number of allowed iterations (L)
428         * @param stoppingCondition
429         *            the stopping condition
430         * @param impEst
431         *            True if we want to perform a final fitting of the model with
432         *            all inliers, false otherwise
433         * @param sampler
434         *            the sampling algorithm for selecting random subsets
435         */
436        public RANSAC(M model, ResidualCalculator<I, D, M> errorModel,
437                        DistanceCheck dc, int nIterations,
438                        StoppingCondition stoppingCondition, boolean impEst, CollectionSampler<IndependentPair<I, D>> sampler)
439        {
440                this.stoppingCondition = stoppingCondition;
441                this.model = model;
442                this.errorModel = errorModel;
443                this.dc = dc;
444                nIter = nIterations;
445                improveEstimate = impEst;
446
447                inliers = new ArrayList<IndependentPair<I, D>>();
448                outliers = new ArrayList<IndependentPair<I, D>>();
449                this.sampler = sampler;
450        }
451
452        @Override
453        public boolean fitData(final List<? extends IndependentPair<I, D>> data)
454        {
455                int l;
456                final int M = model.numItemsToEstimate();
457
458                bestModelInliers = null;
459                bestModelOutliers = null;
460
461                if (data.size() < M || !stoppingCondition.init(data, model)) {
462                        return false; // there are not enough points to create a model, or
463                        // init failed
464                }
465
466                sampler.setCollection(data);
467
468                for (l = 0; l < nIter; l++) {
469                        // 1
470                        final List<? extends IndependentPair<I, D>> rnd = sampler.sample(M);
471                        this.setModelConstructionData(rnd);
472
473                        // 2
474                        if (!model.estimate(rnd))
475                                continue; // bad estimate
476
477                        errorModel.setModel(model);
478
479                        // 3
480                        int K = 0;
481                        inliers.clear();
482                        outliers.clear();
483                        for (final IndependentPair<I, D> dp : data) {
484                                if (dc.check(errorModel.computeResidual(dp)))
485                                {
486                                        K++;
487                                        inliers.add(dp);
488                                } else {
489                                        outliers.add(dp);
490                                }
491                        }
492
493                        if (bestModelInliers == null || inliers.size() >= bestModelInliers.size()) {
494                                // copy
495                                bestModelInliers = new ArrayList<IndependentPair<I, D>>(inliers);
496                                bestModelOutliers = new ArrayList<IndependentPair<I, D>>(outliers);
497                        }
498
499                        // 4
500                        if (stoppingCondition.shouldStopIterations(K)) {
501                                // generate "best" fit from all the iterations
502                                inliers = bestModelInliers;
503                                outliers = bestModelOutliers;
504
505                                if (improveEstimate) {
506                                        if (inliers.size() >= model.numItemsToEstimate())
507                                                if (!model.estimate(inliers))
508                                                        return false;
509                                }
510                                final boolean stopping = stoppingCondition.finalFitCondition(inliers.size());
511                                // System.err.format("done: %b\n",stopping);
512                                return stopping;
513                        }
514                        // 5
515                        // ...repeat...
516                }
517
518                // generate "best" fit from all the iterations
519                if (bestModelInliers == null) {
520                        bestModelInliers = new ArrayList<IndependentPair<I, D>>();
521                        bestModelOutliers = new ArrayList<IndependentPair<I, D>>();
522                }
523
524                inliers = bestModelInliers;
525                outliers = bestModelOutliers;
526
527                if (bestModelInliers.size() >= M)
528                        if (!model.estimate(bestModelInliers))
529                                return false;
530
531                // 6 - fail
532                return stoppingCondition.finalFitCondition(inliers.size());
533        }
534
535        @Override
536        public List<? extends IndependentPair<I, D>> getInliers()
537                        {
538                return inliers;
539                        }
540
541        @Override
542        public List<? extends IndependentPair<I, D>> getOutliers()
543                        {
544                return outliers;
545                        }
546
547        /**
548         * @return maximum number of allowed iterations
549         */
550        public int getMaxIterations() {
551                return nIter;
552        }
553
554        /**
555         * Set the maximum number of allowed iterations
556         *
557         * @param nIter
558         *            maximum number of allowed iterations
559         */
560        public void setMaxIterations(int nIter) {
561                this.nIter = nIter;
562        }
563
564        @Override
565        public M getModel() {
566                return model;
567        }
568
569        /**
570         * Set the underlying model being fitted
571         *
572         * @param model
573         *            the model
574         */
575        public void setModel(M model) {
576                this.model = model;
577        }
578
579        /**
580         * @return whether RANSAC should attempt to improve the model using all
581         *         inliers as data
582         */
583        public boolean isImproveEstimate() {
584                return improveEstimate;
585        }
586
587        /**
588         * Set whether RANSAC should attempt to improve the model using all inliers
589         * as data
590         *
591         * @param improveEstimate
592         *            should RANSAC attempt to improve the model using all inliers
593         *            as data
594         */
595        public void setImproveEstimate(boolean improveEstimate) {
596                this.improveEstimate = improveEstimate;
597        }
598
599        /**
600         * Set the data used to construct the model
601         *
602         * @param modelConstructionData
603         */
604        public void setModelConstructionData(List<? extends IndependentPair<I, D>> modelConstructionData) {
605                this.modelConstructionData = modelConstructionData;
606        }
607
608        /**
609         * @return The data used to construct the model.
610         */
611        public List<? extends IndependentPair<I, D>> getModelConstructionData() {
612                return modelConstructionData;
613        }
614
615        @Override
616        public int numItemsToEstimate() {
617                return model.numItemsToEstimate();
618        }
619}