/*
 * Decompiled with CFR 0.152.
 */
package beast.evolution.likelihood;

import beast.app.BeastMCMC;
import beast.app.beauti.Beauti;
import beast.core.BEASTInterface;
import beast.core.Description;
import beast.core.Input;
import beast.core.State;
import beast.core.util.Log;
import beast.evolution.alignment.Alignment;
import beast.evolution.alignment.FilteredAlignment;
import beast.evolution.likelihood.GenericTreeLikelihood;
import beast.evolution.likelihood.TreeLikelihood;
import beast.evolution.sitemodel.SiteModelInterface;
import beast.evolution.substitutionmodel.SubstitutionModel;
import beast.evolution.tree.TreeInterface;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadFactory;

/**
 * Tree likelihood that splits the alignment's sites into contiguous ranges. Each range is
 * evaluated by its own {@link TreeLikelihood} on a fixed-size thread pool, and the per-range
 * log-likelihoods are summed.
 *
 * <p>When only one thread is available, or the alignment is ascertained, a single
 * {@link TreeLikelihood} over the whole alignment is used and no pool is created.
 */
@Description(value="Calculates the likelihood of sequence data on a beast.tree given a site and substitution model using a variant of the 'peeling algorithm'. For details, see Felsenstein, Joseph (1981). Evolutionary trees from DNA sequences: a maximum likelihood approach. J Mol Evol 17 (6): 368-376.")
public class ThreadedTreeLikelihood
extends GenericTreeLikelihood {
    public final Input<Boolean> useAmbiguitiesInput = new Input<Boolean>("useAmbiguities", "flag to indicate leafs that sites containing ambiguous states should be handled instead of ignored (the default)", false);
    public final Input<Integer> maxNrOfThreadsInput = new Input<Integer>("threads", "maximum number of threads to use, if less than 1 the number of threads in BeastMCMC is used (default -1)", -1);
    public final Input<String> proportionsInput = new Input<String>("proportions", "specifies proportions of patterns used per thread as space delimited string. This is useful when using a mixture of BEAGLE devices that run at different speeds, e.g GPU and CPU. The string is duplicated if there are more threads than proportions specified. For example, '1 2' as well as '33 66' with 2 threads specifies that the first thread gets a third of the patterns and the second two thirds. With 3 threads, it is interpreted as '1 2 1' = 25%, 50%, 25% and with 7 threads it is '1 2 1 2 1 2 1' = 10% 20% 10% 20% 10% 20% 10%. If not specified, all threads get the same proportion of patterns.");
    public final Input<Scaling> scalingInput = new Input<Scaling>("scaling", "type of scaling to use, one of " + Arrays.toString((Object[])Scaling.values()) + ". If not specified, the -beagle_scaling flag is used.", Scaling._default, Scaling.values());
    /** Holds the per-range TreeLikelihoods; exposed through listInputs() outside BEAUti/JUnit runs. */
    private final Input<List<TreeLikelihood>> likelihoodsInput = new Input<List<TreeLikelihood>>("*", "", new ArrayList<TreeLikelihood>());
    /** One likelihood per thread; length equals the final threadCount. */
    private TreeLikelihood[] treelikelihood;
    /** Worker pool; only created when more than one thread is used. */
    private ExecutorService pool = null;
    /** One task per thread, in thread-index order. */
    private final List<Callable<Double>> likelihoodCallers = new ArrayList<Callable<Double>>();
    private int threadCount;
    /** Log-likelihood last computed by each worker, indexed by thread number. */
    private double[] logPByThread;
    /** Site boundaries: thread i covers 0-based sites [patternPoints[i], patternPoints[i + 1]). */
    private int[] patternPoints;

    /**
     * Returns the inherited inputs plus the hidden "*" input holding the per-thread likelihoods.
     * The "*" input is omitted inside BEAUti and when the "beast.is.junit.testing" system
     * property is set.
     */
    @Override
    public List<Input<?>> listInputs() {
        List<Input<?>> list = super.listInputs();
        if (!Beauti.isInBeauti() && System.getProperty("beast.is.junit.testing") == null) {
            list.add(this.likelihoodsInput);
        }
        return list;
    }

    /**
     * Decides the thread count and builds one TreeLikelihood per thread.
     *
     * <p>The thread count is {@code BeastMCMC.m_nThreads}, capped by "threads" when that is
     * positive. It is overridden by the "beast.instance.count" system property, clamped to at
     * least 1, and forced to 1 for ascertained alignments.
     *
     * @throws IllegalArgumentException if the tree's leaf count differs from the alignment's
     *     taxon count
     */
    @Override
    public void initAndValidate() {
        this.threadCount = BeastMCMC.m_nThreads;
        if (this.maxNrOfThreadsInput.get() > 0) {
            this.threadCount = Math.min(this.maxNrOfThreadsInput.get(), BeastMCMC.m_nThreads);
        }
        String instanceCount = System.getProperty("beast.instance.count");
        if (instanceCount != null && !instanceCount.trim().isEmpty()) {
            this.threadCount = Integer.parseInt(instanceCount.trim());
        }
        if (this.threadCount < 1) {
            // A zero or negative count would make every per-thread array unusable.
            this.threadCount = 1;
        }

        Alignment data = (Alignment) this.dataInput.get();
        if (data.getTaxonCount() != ((TreeInterface) this.treeInput.get()).getLeafNodeCount()) {
            throw new IllegalArgumentException("The number of nodes in the tree does not match the number of sequences");
        }
        if (data.isAscertained) {
            Log.warning.println("Note, can only use single thread per alignment because the alignment is ascertained");
            this.threadCount = 1;
        }

        // Size the per-thread arrays only once the final thread count is known, so that no
        // null slots are left behind for requiresRecalculation()/getPatternLogLikelihoods().
        this.logPByThread = new double[this.threadCount];
        this.treelikelihood = new TreeLikelihood[this.threadCount];

        // Drop state from any earlier initialisation so stale tasks are not run again.
        this.likelihoodCallers.clear();
        if (this.pool != null) {
            this.pool.shutdown();
            this.pool = null;
        }

        String scaling = String.valueOf(this.scalingInput.get());
        if (this.threadCount <= 1) {
            this.treelikelihood[0] = new TreeLikelihood();
            this.treelikelihood[0].setID(this.getID() + "0");
            this.treelikelihood[0].initByName("data", this.dataInput.get(), "tree", this.treeInput.get(), "siteModel", this.siteModelInput.get(), "branchRateModel", this.branchRateModelInput.get(), "useAmbiguities", this.useAmbiguitiesInput.get(), "scaling", scaling);
            this.treelikelihood[0].getOutputs().add(this);
            this.likelihoodsInput.get().add(this.treelikelihood[0]);
            return;
        }

        this.pool = newDaemonPool(this.threadCount);
        this.calcPatternPoints(data.getSiteCount());
        for (int i = 0; i < this.threadCount; ++i) {
            // FilteredAlignment ranges are 1-based and inclusive.
            String filterSpec = (this.patternPoints[i] + 1) + "-" + this.patternPoints[i + 1];
            if (data.isAscertained) {
                // Currently unreachable: ascertained data forces a single thread above.
                // NOTE(review): confirm excludefrom/excludeto use the same 1-based inclusive
                // convention as the filter syntax before enabling this path.
                filterSpec = data.excludefromInput.get() + "-" + data.excludetoInput.get() + "," + filterSpec;
            }
            this.treelikelihood[i] = new TreeLikelihood();
            this.treelikelihood[i].setID(this.getID() + i);
            this.treelikelihood[i].getOutputs().add(this);
            this.likelihoodsInput.get().add(this.treelikelihood[i]);

            FilteredAlignment filteredAlignment = new FilteredAlignment();
            // Constant-site weights are passed to the first range only.
            if (i == 0 && data instanceof FilteredAlignment && ((FilteredAlignment) data).constantSiteWeightsInput.get() != null) {
                filteredAlignment.initByName("data", data, "filter", filterSpec, "constantSiteWeights", ((FilteredAlignment) data).constantSiteWeightsInput.get());
            } else {
                filteredAlignment.initByName("data", data, "filter", filterSpec);
            }
            this.treelikelihood[i].initByName("data", filteredAlignment, "tree", this.treeInput.get(), "siteModel", this.duplicate((BEASTInterface) this.siteModelInput.get(), i), "branchRateModel", this.duplicate((BEASTInterface) this.branchRateModelInput.get(), i), "useAmbiguities", this.useAmbiguitiesInput.get(), "scaling", scaling);
            this.likelihoodCallers.add(new TreeLikelihoodCaller(this.treelikelihood[i], i));
        }
    }

    /** Creates a fixed-size pool of daemon threads, so an idle pool never keeps the JVM alive. */
    private static ExecutorService newDaemonPool(int size) {
        final ThreadFactory delegate = Executors.defaultThreadFactory();
        return Executors.newFixedThreadPool(size, new ThreadFactory() {
            @Override
            public Thread newThread(Runnable runnable) {
                Thread thread = delegate.newThread(runnable);
                thread.setDaemon(true);
                return thread;
            }
        });
    }

    /**
     * Builds a copy of {@code original} via its no-arg constructor, with ID
     * {@code originalId + "_" + n}.
     *
     * <p>Inputs are copied by reference, except that SubstitutionModel inputs are duplicated
     * recursively. For list inputs, only elements that are BEASTInterface objects are re-added.
     *
     * @return the initialised copy, or {@code null} when {@code original} is {@code null}
     * @throws RuntimeException if the class cannot be instantiated reflectively
     */
    private Object duplicate(BEASTInterface original, int n) {
        if (original == null) {
            return null;
        }
        BEASTInterface copy;
        try {
            copy = original.getClass().getDeclaredConstructor().newInstance();
        }
        catch (ReflectiveOperationException e) {
            throw new RuntimeException("Programmer error: every object in the model should have a default constructor that is publicly accessible: " + original.getClass().getName(), e);
        }
        copy.setID(original.getID() + "_" + n);
        for (Input<?> input : original.listInputs()) {
            Object value = input.get();
            if (value == null) {
                continue;
            }
            if (value instanceof List) {
                for (Object element : (List<?>) value) {
                    if (element instanceof BEASTInterface) {
                        copy.setInputValue(input.getName(), element);
                    }
                }
            } else if (value instanceof SubstitutionModel) {
                // Each thread gets its own substitution model instance.
                copy.setInputValue(input.getName(), this.duplicate((BEASTInterface) value, n));
            } else {
                copy.setInputValue(input.getName(), value);
            }
        }
        copy.initAndValidate();
        return copy;
    }

    /**
     * Fills {@link #patternPoints} with threadCount contiguous site ranges covering [0, n).
     *
     * <p>Without a (non-blank) "proportions" value the sites are split evenly, and the last
     * range takes the remainder. Otherwise the proportions are cycled to threadCount entries,
     * normalised, accumulated, and rounded to site boundaries.
     *
     * @param n number of sites in the alignment
     * @throws IllegalArgumentException if the proportions do not sum to a positive value
     */
    private void calcPatternPoints(int n) {
        this.patternPoints = new int[this.threadCount + 1];
        String proportions = this.proportionsInput.get();
        if (proportions == null || proportions.trim().isEmpty()) {
            int sitesPerThread = n / this.threadCount;
            for (int i = 1; i < this.threadCount; ++i) {
                this.patternPoints[i] = sitesPerThread * i;
            }
        } else {
            // Trim first: a leading blank would otherwise yield an empty first token.
            String[] tokens = proportions.trim().split("\\s+");
            double[] fraction = new double[this.threadCount];
            for (int i = 0; i < this.threadCount; ++i) {
                fraction[i] = Double.parseDouble(tokens[i % tokens.length]);
            }
            double total = 0.0;
            for (double f : fraction) {
                total += f;
            }
            if (!(total > 0.0)) {
                throw new IllegalArgumentException("proportions must sum to a positive value, got '" + proportions + "'");
            }
            for (int i = 0; i < this.threadCount; ++i) {
                fraction[i] = fraction[i] / total;
            }
            for (int i = 1; i < this.threadCount; ++i) {
                fraction[i] = fraction[i] + fraction[i - 1];
            }
            for (int i = 0; i < this.threadCount; ++i) {
                this.patternPoints[i + 1] = (int) (fraction[i] * (double) n + 0.5);
            }
        }
        // Pin the final boundary so floating-point rounding can never drop trailing sites.
        this.patternPoints[this.threadCount] = n;
    }

    @Override
    public void sample(State state, Random random) {
        throw new UnsupportedOperationException("Can't sample a fixed alignment!");
    }

    /** Computes, stores and returns the summed log-likelihood over all threads. */
    @Override
    public double calculateLogP() {
        this.logP = this.calculateLogPByBeagle();
        return this.logP;
    }

    /**
     * Runs every per-thread likelihood and sums the results. With one thread, the single
     * TreeLikelihood is evaluated on the calling thread.
     *
     * @throws RuntimeException wrapping the cause if a worker fails or the caller is interrupted
     */
    private double calculateLogPByBeagle() {
        if (this.threadCount <= 1) {
            this.logP = this.treelikelihood[0].calculateLogP();
            return this.logP;
        }
        int current = -1;
        try {
            List<Future<Double>> results = this.pool.invokeAll(this.likelihoodCallers);
            // get() rethrows any worker failure and guarantees the logPByThread writes are visible.
            for (current = 0; current < results.size(); ++current) {
                results.get(current).get();
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while computing the threaded tree likelihood", e);
        }
        catch (ExecutionException e) {
            throw new RuntimeException("Something went wrong in thread " + current, e.getCause());
        }
        double sum = 0.0;
        for (double d : this.logPByThread) {
            sum += d;
        }
        this.logP = sum;
        return this.logP;
    }

    /**
     * Concatenates each thread's pattern log-likelihoods, in thread order, into an array sized
     * by the full alignment's pattern count.
     *
     * <p>NOTE(review): with several threads, each range uses its own FilteredAlignment. The
     * concatenated length can then exceed the full pattern count (System.arraycopy would
     * throw), and positions need not match the full alignment's pattern order. Confirm what
     * callers expect before relying on this in multi-threaded runs.
     */
    public double[] getPatternLogLikelihoods() {
        double[] dArray = new double[((Alignment) this.dataInput.get()).getPatternCount()];
        int n = 0;
        for (TreeLikelihood treeLikelihood : this.treelikelihood) {
            double[] dArray2 = treeLikelihood.getPatternLogLikelihoods();
            System.arraycopy(dArray2, 0, dArray, n, dArray2.length);
            n += dArray2.length;
        }
        return dArray;
    }

    /**
     * Reports whether any per-thread likelihood needs recalculation. Every child is queried,
     * even after one has returned true (non-short-circuit |=), presumably so each child
     * updates its own dirty state.
     */
    @Override
    protected boolean requiresRecalculation() {
        boolean bl = false;
        for (TreeLikelihood treeLikelihood : this.treelikelihood) {
            bl |= treeLikelihood.requiresRecalculation();
        }
        return bl;
    }

    @Override
    public void store() {
        super.store();
    }

    @Override
    public void restore() {
        super.restore();
    }

    /** Returns the ID of the alignment this likelihood is computed over. */
    @Override
    public List<String> getArguments() {
        return Collections.singletonList(((Alignment) this.dataInput.get()).getID());
    }

    /** Delegates to the site model's conditions. */
    @Override
    public List<String> getConditions() {
        return ((SiteModelInterface.Base) this.siteModelInput.get()).getConditions();
    }

    /** Task that evaluates one range's likelihood and records it in logPByThread. */
    class TreeLikelihoodCaller
    implements Callable<Double> {
        private final TreeLikelihood likelihood;
        private final int threadNr;

        public TreeLikelihoodCaller(TreeLikelihood treeLikelihood, int n) {
            this.likelihood = treeLikelihood;
            this.threadNr = n;
        }

        /**
         * Computes this range's log-likelihood, stores it at logPByThread[threadNr] and
         * returns it. Exceptions propagate through the Future to calculateLogPByBeagle.
         */
        @Override
        public Double call() throws Exception {
            double result = this.likelihood.calculateLogP();
            ThreadedTreeLikelihood.this.logPByThread[this.threadNr] = result;
            return result;
        }
    }

    static enum Scaling {
        none,
        always,
        _default;

    }
}

