/*
 * Decompiled with CFR 0.152.
 */
package beast.core;

import beast.core.Citation;
import beast.core.Description;
import beast.core.Distribution;
import beast.core.Input;
import beast.core.Logger;
import beast.core.Operator;
import beast.core.OperatorSchedule;
import beast.core.Runnable;
import beast.core.State;
import beast.core.StateNode;
import beast.core.StateNodeInitialiser;
import beast.core.util.CompoundDistribution;
import beast.core.util.Evaluator;
import beast.core.util.Log;
import beast.util.Randomizer;
import java.io.IOException;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import javax.xml.parsers.ParserConfigurationException;
import org.xml.sax.SAXException;

/**
 * Markov chain Monte Carlo runner.
 *
 * <p>{@link #initAndValidate()} adds the operators to the operator schedule, optionally removes the
 * distribution with id "likelihood" from the posterior (sampleFromPrior), and builds and validates
 * the {@link State}. {@link #run()} initialises (or restores) the start state, runs
 * {@link #doLoop()} over samples {@code -preBurnin .. chainLength} inclusive, and finally stores the
 * state and the operator schedule to file.
 */
@Description(value="MCMC chain. This is the main element that controls which posterior to calculate, how long to run the chain and all other properties, which operators to apply on the state space and where to log results.")
@Citation(value="Bouckaert RR, Heled J, Kuehnert D, Vaughan TG, Wu C-H, Xie D, Suchard MA,\n  Rambaut A, Drummond AJ (2014) BEAST 2: A software platform for Bayesian\n  evolutionary analysis. PLoS Computational Biology 10(4): e1003537", year=2014, firstAuthorSurname="bouckaert", DOI="10.1371/journal.pcbi.1003537")
public class MCMC
extends Runnable {
    public final Input<Integer> chainLengthInput = new Input<Integer>("chainLength", "Length of the MCMC chain i.e. number of samples taken in main loop", Input.Validate.REQUIRED);
    public final Input<State> startStateInput = new Input<State>("state", "elements of the state space");
    public final Input<List<StateNodeInitialiser>> initialisersInput = new Input<List<StateNodeInitialiser>>("init", "one or more state node initilisers used for determining the start state of the chain", new ArrayList<StateNodeInitialiser>());
    public final Input<Integer> storeEveryInput = new Input<Integer>("storeEvery", "store the state to disk every X number of samples so that we can resume computation later on if the process failed half-way.", -1);
    public final Input<Integer> burnInInput = new Input<Integer>("preBurnin", "Number of burn in samples taken before entering the main loop", 0);
    public final Input<Integer> numInitializationAttempts = new Input<Integer>("numInitializationAttempts", "Number of initialization attempts before failing (default=10)", 10);
    public final Input<Distribution> posteriorInput = new Input<Distribution>("distribution", "probability distribution to sample over (e.g. a posterior)", Input.Validate.REQUIRED);
    public final Input<List<Operator>> operatorsInput = new Input<List<Operator>>("operator", "operator for generating proposals in MCMC state space", new ArrayList<Operator>());
    public final Input<List<Logger>> loggersInput = new Input<List<Logger>>("logger", "loggers for reporting progress of MCMC chain", new ArrayList<Logger>(), Input.Validate.REQUIRED);
    public final Input<Boolean> sampleFromPriorInput = new Input<Boolean>("sampleFromPrior", "whether to ignore the likelihood when sampling (default false). The distribution with id 'likelihood' in the posterior input will be ignored when this flag is set.", false);
    public final Input<OperatorSchedule> operatorScheduleInput = new Input<OperatorSchedule>("operatorschedule", "specify operator selection and optimisation schedule", new OperatorSchedule());

    /** Schedule taken from operatorScheduleInput; every operator from operatorsInput is added to it. */
    protected OperatorSchedule operatorSchedule;
    /** State being sampled: startStateInput if given, otherwise built from the operators' state nodes. */
    protected State state;
    /** Once a check sample exceeds NR_OF_DEBUG_SAMPLES * 3, posterior mismatches are corrected instead of fatal. */
    protected final int NR_OF_DEBUG_SAMPLES = 2000;
    /** Sample interval at which state and operator schedule are stored to file (when positive). */
    protected int storeEvery;
    private static final boolean printDebugInfo = false;
    /** Log acceptance ratio of the most recent proposal that was evaluated; passed to Operator.optimize. */
    protected double logAlpha;
    /** Set from the "beast.debug" system property; enables posterior checks every third sample. */
    protected boolean debugFlag;
    /** Posterior of the currently accepted state. */
    protected double oldLogLikelihood;
    /** Posterior of the most recently evaluated proposal. */
    protected double newLogLikelihood;
    /** Number of pre-burnin samples; forced to 0 when resuming from file. */
    protected int burnIn;
    /** Index of the last sample of the main loop (inclusive). */
    protected int chainLength;
    /** Distribution being sampled; set from posteriorInput in run(). */
    protected Distribution posterior;
    /** Loggers from loggersInput, sorted in run() so that stdout loggers come last. */
    protected List<Logger> loggers;

    /**
     * Prints the model citations, registers operators with the schedule, applies the
     * sampleFromPrior option, builds the {@link State} if none was given, and validates that
     * every operator works on state nodes that are part of the state.
     *
     * @throws RuntimeException if sampleFromPrior cannot be honoured, an initialiser overlaps
     *     another (see note below), an operator has no state nodes, or an operator refers to a
     *     state node missing from the state
     */
    @Override
    public void initAndValidate() {
        Log.info.println("===============================================================================");
        Log.info.println("Citations for this model:");
        Log.info.println(this.getCitations());
        Log.info.println("===============================================================================");

        final List<Operator> operators = this.operatorsInput.get();
        this.operatorSchedule = this.operatorScheduleInput.get();
        for (Operator operator : operators) {
            this.operatorSchedule.addOperator(operator);
        }

        if (this.sampleFromPriorInput.get().booleanValue()) {
            this.removeLikelihoodFromPosterior();
        }

        // NOTE(review): run() only executes the initialisers when restoreFromFile is false, so this
        // duplicate-initialisation check looks inverted. It is kept as-is pending confirmation of when
        // Runnable sets restoreFromFile relative to initAndValidate().
        if (this.restoreFromFile) {
            this.checkInitialisersAreDisjoint();
        }

        // Insertion-ordered so that an auto-built State lists its nodes deterministically (operator order)
        // rather than in identity-hash order.
        final Set<StateNode> operatedStateNodes = new LinkedHashSet<StateNode>();
        for (Operator operator : operators) {
            operatedStateNodes.addAll(operator.listStateNodes());
        }

        if (this.startStateInput.get() != null) {
            this.state = this.startStateInput.get();
            if (this.storeEveryInput.get() > 0) {
                this.state.m_storeEvery.setValue(this.storeEveryInput.get(), this.state);
            }
        } else {
            // No explicit state: the state consists of exactly the nodes some operator works on.
            this.state = new State();
            for (StateNode stateNode : operatedStateNodes) {
                this.state.stateNodeInput.setValue(stateNode, this.state);
            }
            this.state.m_storeEvery.setValue(this.storeEveryInput.get(), this.state);
        }
        this.storeEvery = this.storeEveryInput.get() > 0
                ? this.storeEveryInput.get().intValue()
                : this.state.m_storeEvery.get().intValue();
        this.state.initialise();
        this.state.setPosterior(this.posteriorInput.get());

        final List<StateNode> stateNodes = this.state.stateNodeInput.get();
        for (Operator operator : operators) {
            final List<StateNode> operatorNodes = operator.listStateNodes();
            if (operatorNodes.isEmpty()) {
                throw new RuntimeException("Operator " + operator.getID() + " has no state nodes in the state. " + "Each operator should operate on at least one estimated state node in the state. " + "Remove the operator or add its statenode(s) to the state and/or set estimate='true'.");
            }
            for (StateNode stateNode : operatorNodes) {
                if (!stateNodes.contains(stateNode)) {
                    throw new RuntimeException("Operator " + operator.getID() + " has a statenode " + stateNode.getID() + " in its inputs that is missing from the state.");
                }
            }
        }
        if (operators.isEmpty()) {
            Log.warning.println("Warning: at least one operator required to run the MCMC properly, but none found.");
        }
        for (StateNode stateNode : stateNodes) {
            if (!operatedStateNodes.contains(stateNode)) {
                Log.warning.println("Warning: state contains a node " + stateNode.getID() + " for which there is no operator.");
            }
        }
    }

    /**
     * Removes the first component with id "likelihood" from the posterior, which must be a
     * {@link CompoundDistribution}. The removal mutates that distribution's pDistributions list.
     *
     * @throws RuntimeException if the posterior is not a CompoundDistribution or has no such component
     */
    private void removeLikelihoodFromPosterior() {
        final Distribution distribution = this.posteriorInput.get();
        if (!(distribution instanceof CompoundDistribution)) {
            throw new RuntimeException("Don't know how to sample from prior since posterior is not a compound distribution. Suggestion: set sampleFromPrior flag to false.");
        }
        final List<Distribution> components = ((CompoundDistribution) distribution).pDistributions.get();
        for (int i = 0; i < components.size(); ++i) {
            if ("likelihood".equals(components.get(i).getID())) {
                components.remove(i);
                return;
            }
        }
        throw new RuntimeException("Sample from prior flag is set, but distribution with id 'likelihood' is not an input to posterior.");
    }

    /**
     * Verifies that no two StateNodeInitialisers report the same state node.
     *
     * @throws RuntimeException naming the first state node found to be initialised twice
     */
    private void checkInitialisersAreDisjoint() {
        final Set<StateNode> initialised = new HashSet<StateNode>();
        for (StateNodeInitialiser initialiser : this.initialisersInput.get()) {
            final ArrayList<StateNode> nodes = new ArrayList<StateNode>(1);
            initialiser.getInitialisedStateNodes(nodes);
            for (StateNode stateNode : nodes) {
                if (initialised.contains(stateNode)) {
                    throw new RuntimeException("Trying to initialise stateNode (id=" + stateNode.getID() + ") more than once. " + "Remove an initialiser from MCMC to fix this.");
                }
            }
            initialised.addAll(nodes);
        }
    }

    /**
     * Passes sample number {@code n} to every logger, in list order.
     *
     * @param n sample number (negative during pre-burnin)
     */
    public void log(int n) {
        for (Logger logger : this.loggers) {
            logger.log(n);
        }
    }

    /** Closes every logger. Assumes run() has already populated the logger list. */
    public void close() {
        for (Logger logger : this.loggers) {
            logger.close();
        }
    }

    /**
     * Runs the chain: restores or initialises the start state, checks that its posterior is
     * finite, initialises the loggers, runs {@link #doLoop()}, reports operator rates and timing,
     * and stores the final state and operator schedule to file.
     *
     * @throws RuntimeException if no state with a finite posterior is found within
     *     numInitializationAttempts attempts
     */
    @Override
    public void run() throws IOException, SAXException, ParserConfigurationException {
        this.state.initAndValidate();
        this.state.setStateFileName(this.stateFileName);
        this.operatorSchedule.setStateFileName(this.stateFileName);
        this.burnIn = this.burnInInput.get();
        this.chainLength = this.chainLengthInput.get();
        this.state.setEverythingDirty(true);
        this.posterior = this.posteriorInput.get();

        int initialisationAttempts = 0;
        if (this.restoreFromFile) {
            this.state.restoreFromFile();
            this.operatorSchedule.restoreFromFile();
            this.burnIn = 0;
            this.oldLogLikelihood = this.state.robustlyCalcPosterior(this.posterior);
        } else {
            // Re-run all initialisers until the posterior is finite (NaN and +/-Infinity both count as
            // failures, matching the check below), making at least one and at most maxAttempts attempts.
            final int maxAttempts = this.numInitializationAttempts.get();
            do {
                for (StateNodeInitialiser initialiser : this.initialisersInput.get()) {
                    initialiser.initStateNodes();
                }
                this.oldLogLikelihood = this.state.robustlyCalcPosterior(this.posterior);
                ++initialisationAttempts;
            } while (!Double.isFinite(this.oldLogLikelihood) && initialisationAttempts < maxAttempts);
        }

        final long startTime = System.currentTimeMillis();
        this.state.storeCalculationNodes();
        this.logAlpha = 0.0;
        this.debugFlag = Boolean.parseBoolean(System.getProperty("beast.debug"));
        Log.info.println("Start likelihood: " + this.oldLogLikelihood + " " + (initialisationAttempts > 1 ? "after " + initialisationAttempts + " initialisation attempts" : ""));
        if (!Double.isFinite(this.oldLogLikelihood)) {
            this.reportLogLikelihoods(this.posterior, "");
            throw new RuntimeException("Could not find a proper state to initialise. Perhaps try another seed.");
        }

        this.loggers = this.loggersInput.get();
        // Stable sort placing stdout loggers last, so log(n) invokes them after all other loggers.
        Collections.sort(this.loggers, (first, second) -> Boolean.compare(first.isLoggingToStdout(), second.isLoggingToStdout()));
        this.warnIfNothingLogsToScreen();
        for (Logger logger : this.loggers) {
            logger.init();
        }

        this.doLoop();

        Log.info.println();
        this.operatorSchedule.showOperatorRates(System.out);
        Log.info.println();
        final long endTime = System.currentTimeMillis();
        Log.info.println("Total calculation time: " + (double)(endTime - startTime) / 1000.0 + " seconds");
        this.close();
        Log.warning.println("End likelihood: " + this.oldLogLikelihood);
        this.state.storeToFile(this.chainLength);
        this.operatorSchedule.storeToFile();
    }

    /** Warns when no logger writes to stdout, with extra hints when a logger with id "screenlog" exists. */
    private void warnIfNothingLogsToScreen() {
        boolean anyStdoutLogger = false;
        boolean hasScreenLog = false;
        for (Logger logger : this.loggers) {
            anyStdoutLogger |= logger.isLoggingToStdout();
            hasScreenLog |= "screenlog".equals(logger.getID());
        }
        if (anyStdoutLogger) {
            return;
        }
        Log.warning.println("WARNING: If nothing seems to be happening on screen this is because none of the loggers give feedback to screen.");
        if (hasScreenLog) {
            Log.warning.println("WARNING: This happens when a filename  is specified for the 'screenlog' logger.");
            Log.warning.println("WARNING: To get feedback to screen, leave the filename for screenlog blank.");
            Log.warning.println("WARNING: Otherwise, the screenlog is saved into the specified file.");
        }
    }

    /**
     * Main sampling loop over samples {@code -burnIn .. chainLength} inclusive.
     *
     * <p>On check samples (every 10000th, plus every 3rd while debugging), the cached posterior is
     * compared with a full recalculation. A mismatch in the first NR_OF_DEBUG_SAMPLES * 3 samples
     * stores state and exits. Later mismatches are corrected, up to 100 times, before exiting.
     * Non-check samples with {@code i >= 0} let the chosen operator tune itself. State is stored
     * every storeEvery samples and at the last sample.
     *
     * @throws RuntimeException if the posterior ever becomes positive infinity
     */
    protected void doLoop() throws IOException {
        int corrections = 0;
        final boolean isStochastic = this.posterior.isStochastic();
        if (this.burnIn > 0) {
            Log.warning.println("Please wait while BEAST takes " + this.burnIn + " pre-burnin samples");
        }
        for (int sample = -this.burnIn; sample <= this.chainLength; ++sample) {
            final Operator operator = this.propagateState(sample);
            if (this.debugFlag && sample % 3 == 0 || sample % 10000 == 0) {
                // Read the cached value before recalculating, since the robust recalculation updates the posterior.
                final double cachedLogP = isStochastic ? this.posterior.getNonStochasticLogP() : this.oldLogLikelihood;
                final double recalculatedLogP = isStochastic
                        ? this.state.robustlyCalcNonStochasticPosterior(this.posterior)
                        : this.state.robustlyCalcPosterior(this.posterior);
                final boolean mismatch = this.isTooDifferent(recalculatedLogP, cachedLogP);
                if (mismatch) {
                    this.reportLogLikelihoods(this.posterior, "");
                    Log.err.println("At sample " + sample + "\nLikelihood incorrectly calculated: " + cachedLogP + " != " + recalculatedLogP + "(" + (cachedLogP - recalculatedLogP) + ")" + " Operator: " + operator.getClass().getName());
                }
                if (sample > NR_OF_DEBUG_SAMPLES * 3) {
                    // Past the debug period: stop extra checks and repair occasional mismatches.
                    this.debugFlag = false;
                    if (mismatch) {
                        if (++corrections > 100) {
                            Log.err.println("Too many corrections. There is something seriously wrong that cannot be corrected");
                            this.state.storeToFile(sample);
                            this.operatorSchedule.storeToFile();
                            System.exit(1);
                        }
                        this.oldLogLikelihood = this.state.robustlyCalcPosterior(this.posterior);
                    }
                } else if (mismatch) {
                    // A mismatch during the debug period is fatal.
                    this.state.storeToFile(sample);
                    this.operatorSchedule.storeToFile();
                    System.exit(1);
                }
            } else if (sample >= 0) {
                operator.optimize(this.logAlpha);
            }
            this.callUserFunction(sample);
            if (this.storeEvery > 0 && (sample + 1) % this.storeEvery == 0 || sample == this.chainLength) {
                this.state.robustlyCalcNonStochasticPosterior(this.posterior);
                this.state.storeToFile(sample);
                this.operatorSchedule.storeToFile();
            }
            if (this.posterior.getCurrentLogP() == Double.POSITIVE_INFINITY) {
                throw new RuntimeException("Encountered a positive infinite posterior. This is a sign there may be numeric instability in the model.");
            }
        }
        if (corrections > 0) {
            Log.err.println("\n\nNB: " + corrections + " posterior calculation corrections were required. This analysis may not be valid!\n\n");
        }
    }

    /**
     * Performs one Metropolis-Hastings step. It selects an operator, proposes a move, accepts or
     * rejects it, and logs the sample. Accept/reject statistics are recorded only when
     * {@code sampleNr >= 0}.
     *
     * @param sampleNr current sample number (negative during pre-burnin)
     * @return the operator that was applied
     */
    protected Operator propagateState(final int sampleNr) {
        this.state.store(sampleNr);
        final Operator operator = this.operatorSchedule.selectOperator();
        final Distribution evaluatorDistribution = operator.getEvaluatorDistribution();
        Evaluator evaluator = null;
        if (evaluatorDistribution != null) {
            // Lets the operator score a tentative state, then rolls the state back to the stored one.
            evaluator = new Evaluator(){

                @Override
                public double evaluate() {
                    double logP = 0.0;
                    MCMC.this.state.storeCalculationNodes();
                    MCMC.this.state.checkCalculationNodesDirtiness();
                    try {
                        logP = evaluatorDistribution.calculateLogP();
                    }
                    catch (Exception exception) {
                        exception.printStackTrace();
                        System.exit(1);
                    }
                    MCMC.this.state.restore();
                    MCMC.this.state.store(sampleNr);
                    return logP;
                }
            };
        }
        final double logHastingsRatio = operator.proposal(evaluator);
        if (logHastingsRatio != Double.NEGATIVE_INFINITY) {
            if (operator.requiresStateInitialisation()) {
                this.state.storeCalculationNodes();
                this.state.checkCalculationNodesDirtiness();
            }
            this.newLogLikelihood = this.posterior.calculateLogP();
            this.logAlpha = this.newLogLikelihood - this.oldLogLikelihood + logHastingsRatio;
            if (this.logAlpha >= 0.0 || Randomizer.nextDouble() < Math.exp(this.logAlpha)) {
                this.oldLogLikelihood = this.newLogLikelihood;
                this.state.acceptCalculationNodes();
                if (sampleNr >= 0) {
                    operator.accept();
                }
            } else {
                if (sampleNr >= 0) {
                    // -1 marks a rejection caused by a zero-probability proposal, 0 an ordinary rejection.
                    operator.reject(this.newLogLikelihood == Double.NEGATIVE_INFINITY ? -1 : 0);
                }
                this.state.restore();
                this.state.restoreCalculationNodes();
            }
            this.state.setEverythingDirty(false);
        } else {
            // The operator itself declared the proposal invalid; the posterior is not evaluated.
            if (sampleNr >= 0) {
                operator.reject(-2);
            }
            this.state.restore();
            if (!operator.requiresStateInitialisation()) {
                this.state.setEverythingDirty(false);
                this.state.restoreCalculationNodes();
            }
        }
        this.log(sampleNr);
        return operator;
    }

    /**
     * Returns whether two log-posterior values differ by more than 1e-6.
     *
     * <p>Identical values (including matching infinities, whose difference is NaN) are not
     * different. Any comparison involving NaN counts as different, so a NaN recalculation is
     * never silently accepted.
     */
    private boolean isTooDifferent(double logP, double expectedLogP) {
        if (logP == expectedLogP) {
            return false;
        }
        return !(Math.abs(logP - expectedLogP) <= 1.0E-6);
    }

    /**
     * Prints {@code P(id) = logP (was storedLogP)} for the distribution and recursively for the
     * components of a CompoundDistribution, marking changed values with "  **".
     *
     * @param distribution distribution to report
     * @param indent prefix prepended to each line; grows by a tab per nesting level
     */
    protected void reportLogLikelihoods(Distribution distribution, String indent) {
        final double logP = distribution.logP;
        final double storedLogP = distribution.storedLogP;
        final String changedMarker = logP == storedLogP ? "" : "  **";
        Log.info.println(indent + "P(" + distribution.getID() + ") = " + logP + " (was " + storedLogP + ")" + changedMarker);
        if (distribution instanceof CompoundDistribution) {
            for (Distribution component : ((CompoundDistribution) distribution).pDistributions.get()) {
                this.reportLogLikelihoods(component, indent + "\t");
            }
        }
    }

    /**
     * Hook called once per sample from {@link #doLoop()}. Does nothing by default.
     *
     * @param n current sample number
     */
    protected void callUserFunction(int n) {
    }

    /** Delegates to {@link State#robustlyCalcPosterior(Distribution)} on this chain's state. */
    public double robustlyCalcPosterior(Distribution distribution) {
        return this.state.robustlyCalcPosterior(distribution);
    }

    /** Delegates to {@link State#robustlyCalcNonStochasticPosterior(Distribution)} on this chain's state. */
    public double robustlyCalcNonStochasticPosterior(Distribution distribution) {
        return this.state.robustlyCalcNonStochasticPosterior(distribution);
    }
}

