/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.factorAnalysis;

import dr.evomodel.treedatalikelihood.continuous.HashedMissingArray;
import dr.inference.distribution.DistributionLikelihood;
import dr.inference.distribution.NormalDistributionModel;
import dr.inference.distribution.NormalStatisticsProvider;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.inference.operators.factorAnalysis.FactorAnalysisOperatorAdaptor;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;

/**
 * Gibbs operator that redraws the factor-analysis loadings one trait (row) at a time.
 *
 * <p>For each trait the operator builds a normal conditional for that trait's loadings from the
 * current factors, the observed (non-missing) data, the column precision and a normal prior. It
 * then draws from it and writes the result back through the {@link FactorAnalysisOperatorAdaptor}.
 * Only the leading {@code columnDimProvider.getColumnDim(trait, numFactors)} loadings of a trait
 * are resampled (e.g. {@link ColumnDimProvider#UPPER_TRIANGULAR} uses
 * {@code min(trait + 1, numFactors)}).
 *
 * <p>Draws are done either sequentially (all traits in order, or one random trait when
 * {@code randomScan} is set) or, when multi-threading is requested, for all traits on a fixed-size
 * thread pool. In both modes the {@link ConstrainedSampler} is applied afterwards.
 *
 * <p>{@code pathParameter} scales the data terms of the precision and mean. It also interpolates
 * the prior precision between {@code prior} and a working prior; presumably this supports
 * path-sampling marginal-likelihood estimation (confirm against callers of
 * {@link #setPathParameter}).
 *
 * <p>NOTE(review): in multi-threaded mode {@code adaptor.setLoadingsForTraitQuietly} is called
 * concurrently for different traits; the adaptor is assumed to tolerate that.
 */
public class NewLoadingsGibbsOperator extends SimpleMCMCOperator implements GibbsOperator, Reportable {

    private static final boolean DEBUG = false;

    private final boolean useInnerProductCache;
    // Raw (unscaled) factor inner-product matrices keyed by a trait's missing-data pattern.
    // Cleared at the start of every doOperation() because the factors are redrawn there.
    private final Map<HashedMissingArray, DenseMatrix64F> precisionMatrixMap =
            new HashMap<HashedMissingArray, DenseMatrix64F>();
    // Scratch storage for the sequential path, indexed by ColumnDimProvider.getArrayIndex.
    private final ArrayList<double[][]> precisionArray;
    private final ArrayList<double[]> meanMidArray;
    private final ArrayList<double[]> meanArray;
    private final boolean randomScan;
    private double pathParameter = 1.0;
    private final NormalStatisticsProvider prior;
    private final double priorPrecisionWorking;
    private final FactorAnalysisOperatorAdaptor adaptor;
    private final ConstrainedSampler constrainedSampler;
    private final ColumnDimProvider columnDimProvider;
    // 1.0 where a (trait, taxon) entry is observed, 0.0 otherwise; only built when caching.
    private final double[][] observedIndicators;
    // One task per trait, each owning its own scratch arrays; only populated in multi-threaded mode.
    private final List<Callable<Double>> drawCallers = new ArrayList<Callable<Double>>();
    private final ExecutorService pool;

    /**
     * Creates the operator.
     *
     * @param adaptor            access to factors, loadings and data
     * @param prior              normal prior on each loading (indexed factor * numTraits + trait)
     * @param weight             operator weight
     * @param randomScan         in sequential mode, redraw one random trait instead of all traits
     * @param workingPrior       optional working prior; its distribution must be a
     *                           {@link NormalDistributionModel}. When {@code null}, the precision of
     *                           {@code prior} entry 0 is used as the working prior precision.
     * @param multiThreaded      draw all traits in parallel on a thread pool
     * @param numThreads         size of that thread pool
     * @param constrainedSampler constraint applied to the loadings after each update
     * @param columnDimProvider  how many leading loadings of each trait are resampled
     * @param cacheProvider      whether to cache factor inner products per missing-data pattern
     * @throws IllegalArgumentException if caching is combined with more than one thread
     */
    public NewLoadingsGibbsOperator(FactorAnalysisOperatorAdaptor adaptor, NormalStatisticsProvider prior,
                                    double weight, boolean randomScan, DistributionLikelihood workingPrior,
                                    boolean multiThreaded, int numThreads, ConstrainedSampler constrainedSampler,
                                    ColumnDimProvider columnDimProvider, CacheProvider cacheProvider) {
        this.setWeight(weight);
        this.adaptor = adaptor;
        this.prior = prior;
        this.precisionArray = new ArrayList<double[][]>();
        this.meanMidArray = new ArrayList<double[]>();
        this.meanArray = new ArrayList<double[]>();
        this.randomScan = randomScan;
        this.constrainedSampler = constrainedSampler;
        this.columnDimProvider = columnDimProvider;

        if (workingPrior == null) {
            this.priorPrecisionWorking = getPriorPrecision(prior, 0);
        } else {
            // Throws ClassCastException if the working prior is not a NormalDistributionModel.
            NormalDistributionModel workingNormal = (NormalDistributionModel) workingPrior.getDistribution();
            double workingSd = workingNormal.getStdev();
            this.priorPrecisionWorking = 1.0 / (workingSd * workingSd);
        }

        this.useInnerProductCache = cacheProvider.useCache();
        // The cache is a plain HashMap shared by every draw, so it must not be used concurrently.
        // Checked before any pool is created.
        if (useInnerProductCache && multiThreaded && numThreads > 1) {
            throw new IllegalArgumentException("Cannot currently parallelize cached precisions");
        }

        final int numFactors = adaptor.getNumberOfFactors();
        if (multiThreaded) {
            for (int trait = 0; trait < adaptor.getNumberOfTraits(); ++trait) {
                int dim = columnDimProvider.getColumnDim(trait, numFactors);
                drawCallers.add(new DrawCaller(trait, new double[dim][dim], new double[dim], new double[dim]));
            }
            this.pool = Executors.newFixedThreadPool(numThreads);
        } else {
            this.pool = null;
            columnDimProvider.allocateStorage(precisionArray, meanMidArray, meanArray, numFactors);
        }

        this.observedIndicators = useInnerProductCache ? setupObservedIndicators() : null;
    }

    /** Returns {@code 1 / sd^2} for entry {@code index} of the given normal statistics provider. */
    private double getPriorPrecision(NormalStatisticsProvider provider, int index) {
        double sd = provider.getNormalSD(index);
        return 1.0 / (sd * sd);
    }

    /** Builds a numTraits x numTaxa matrix holding 1.0 for observed entries and 0.0 for missing ones. */
    private double[][] setupObservedIndicators() {
        final int numTraits = adaptor.getNumberOfTraits();
        final int numTaxa = adaptor.getNumberOfTaxa();
        double[][] indicators = new double[numTraits][numTaxa];
        for (int trait = 0; trait < numTraits; ++trait) {
            for (int taxon = 0; taxon < numTaxa; ++taxon) {
                if (adaptor.isNotMissing(trait, taxon)) {
                    indicators[trait][taxon] = 1.0;
                }
            }
        }
        return indicators;
    }

    /**
     * Fills {@code answer} (newRowDimension x newRowDimension) with the conditional precision of
     * the leading loadings of trait {@code row}.
     *
     * <p>Entry (i, j) is {@code innerProduct(i, j) * columnPrecision * pathParameter}. The diagonal
     * additionally gets {@link #getAdjustedPriorPrecision}. The inner product sums
     * {@code factor_i * factor_j} over the taxa observed for this trait.
     */
    private void getPrecisionOfTruncated(FactorAnalysisOperatorAdaptor adaptor, int newRowDimension, int row,
                                         double[][] answer) {
        DenseMatrix64F cached = null;
        HashedMissingArray observedKey = null;
        if (useInnerProductCache) {
            observedKey = new HashedMissingArray(observedIndicators[row]);
            cached = precisionMatrixMap.get(observedKey);
        }

        if (cached != null && cached.getNumRows() >= newRowDimension && cached.getNumCols() >= newRowDimension) {
            // Entry (i, j) of the inner product depends only on factors i and j and the missing-data
            // pattern, so the leading block of a (possibly larger) cached matrix is valid here.
            // Copy with the cached matrix's own row stride.
            final double[] data = cached.getData();
            final int stride = cached.getNumCols();
            for (int i = 0; i < newRowDimension; ++i) {
                System.arraycopy(data, i * stride, answer[i], 0, newRowDimension);
            }
        } else {
            computeFactorInnerProducts(adaptor, newRowDimension, row, answer);
            if (useInnerProductCache) {
                // Stored before scaling; replaces any smaller matrix cached for the same pattern.
                precisionMatrixMap.put(observedKey, new DenseMatrix64F(answer));
            }
        }

        final double columnPrecision = this.adaptor.getColumnPrecision(row);
        final int numTraits = adaptor.getNumberOfTraits();
        for (int i = 0; i < newRowDimension; ++i) {
            for (int j = i; j < newRowDimension; ++j) {
                final double scaled = answer[i][j] * columnPrecision;
                if (i == j) {
                    answer[i][i] = scaled * pathParameter + getAdjustedPriorPrecision(numTraits * i + row);
                } else {
                    answer[i][j] = scaled * pathParameter;
                    answer[j][i] = answer[i][j];
                }
            }
        }
    }

    /** Writes the symmetric matrix of factor inner products over taxa observed for trait {@code row}. */
    private static void computeFactorInnerProducts(FactorAnalysisOperatorAdaptor adaptor, int dim, int row,
                                                   double[][] answer) {
        final int numTaxa = adaptor.getNumberOfTaxa();
        for (int a = 0; a < dim; ++a) {
            for (int b = a; b < dim; ++b) {
                double sum = 0.0;
                for (int taxon = 0; taxon < numTaxa; ++taxon) {
                    if (adaptor.isNotMissing(row, taxon)) {
                        sum += adaptor.getFactorValue(a, taxon) * adaptor.getFactorValue(b, taxon);
                    }
                }
                answer[a][b] = sum;
                answer[b][a] = sum;
            }
        }
    }

    /**
     * Computes {@code midMean[i] = columnPrecision * sum_obs(factor_i * data) + priorMean * priorPrecision},
     * then {@code mean = variance * midMean}.
     */
    private void getTruncatedMean(int newRowDimension, int row, double[][] variance, double[] midMean, double[] mean) {
        final int numTaxa = adaptor.getNumberOfTaxa();
        final int numTraits = adaptor.getNumberOfTraits();
        final double columnPrecision = adaptor.getColumnPrecision(row);
        for (int i = 0; i < newRowDimension; ++i) {
            double sum = 0.0;
            for (int taxon = 0; taxon < numTaxa; ++taxon) {
                if (adaptor.isNotMissing(row, taxon)) {
                    sum += adaptor.getFactorValue(i, taxon) * adaptor.getDataValue(row, taxon);
                }
            }
            final int index = numTraits * i + row;
            midMean[i] = sum * columnPrecision + prior.getNormalMean(index) * getPriorPrecision(prior, index);
        }
        for (int i = 0; i < newRowDimension; ++i) {
            double sum = 0.0;
            for (int k = 0; k < newRowDimension; ++k) {
                sum += variance[i][k] * midMean[k];
            }
            mean[i] = sum;
        }
    }

    /** Fills {@code precision} for {@code trait} using that trait's column dimension. */
    private void getPrecision(int trait, double[][] precision) {
        int dim = columnDimProvider.getColumnDim(trait, adaptor.getNumberOfFactors());
        getPrecisionOfTruncated(adaptor, dim, trait, precision);
    }

    /** Fills {@code mean} for {@code trait} and scales it by the path parameter. */
    private void getMean(int trait, double[][] variance, double[] midMean, double[] mean) {
        int dim = columnDimProvider.getColumnDim(trait, adaptor.getNumberOfFactors());
        getTruncatedMean(dim, trait, variance, midMean, mean);
        for (int i = 0; i < mean.length; ++i) {
            mean[i] *= pathParameter;
        }
    }

    /**
     * Draws new leading loadings for {@code trait} into the adaptor, using the given scratch arrays.
     *
     * @throws IllegalStateException if the Cholesky decomposition of the conditional variance fails
     */
    private void drawI(int trait, double[][] precision, double[] midMean, double[] mean) {
        getPrecision(trait, precision);
        double[][] variance = new SymmetricMatrix(precision).inverse().toComponents();
        double[][] cholesky;
        try {
            cholesky = new CholeskyDecomposition(variance).getL();
        } catch (IllegalDimension e) {
            // Previously this was printed and the null factor caused an NPE one line later.
            throw new IllegalStateException(
                    "Cholesky decomposition of the loadings variance failed for trait " + trait, e);
        }
        getMean(trait, variance, midMean, mean);
        double[] draw = MultivariateNormalDistribution.nextMultivariateNormalCholesky(mean, cholesky);
        adaptor.setLoadingsForTraitQuietly(trait, draw);
        if (DEBUG) {
            System.err.println("draw: " + new Vector(draw));
        }
    }

    /** Sequential-path draw for {@code trait} using the shared scratch storage. */
    private void drawI(int trait) {
        int index = columnDimProvider.getArrayIndex(trait, adaptor.getNumberOfFactors());
        drawI(trait, precisionArray.get(index), meanMidArray.get(index), meanArray.get(index));
    }

    @Override
    public String getOperatorName() {
        return "newLoadingsGibbsOperator";
    }

    /**
     * Redraws the factors, then the loadings, then applies the constraint and notifies listeners.
     *
     * @return always 0.0 (Gibbs move)
     */
    @Override
    public double doOperation() {
        if (DEBUG) {
            System.err.println("Start doOp");
        }
        adaptor.drawFactors();
        final int numTraits = adaptor.getNumberOfTraits();
        if (useInnerProductCache) {
            precisionMatrixMap.clear();
        }

        if (pool != null) {
            if (DEBUG) {
                System.err.println("!= poll");
            }
            // All traits are drawn in parallel; randomScan does not apply in this mode.
            drawAllInParallel();
        } else {
            if (DEBUG) {
                System.err.println("inner");
            }
            if (randomScan) {
                drawI(MathUtils.nextInt(numTraits));
            } else {
                for (int trait = 0; trait < numTraits; ++trait) {
                    drawI(trait);
                }
            }
        }
        // Applied in both modes; the parallel path used to skip it.
        constrainedSampler.applyConstraint(adaptor);
        adaptor.fireLoadingsChanged();

        if (DEBUG) {
            for (double[] mean : meanArray) {
                System.err.println(new Vector(mean));
            }
            for (double[] midMean : meanMidArray) {
                System.err.println(new Vector(midMean));
            }
            for (double[][] precision : precisionArray) {
                System.err.println(new Matrix(precision));
            }
            System.err.println("End doOp");
        }
        return 0.0;
    }

    /**
     * Runs every {@link DrawCaller} on the pool and waits for them, surfacing any failure.
     *
     * @throws IllegalStateException if interrupted (the interrupt flag is restored) or a draw fails
     *                               with a checked exception; unchecked failures are rethrown as-is
     */
    private void drawAllInParallel() {
        try {
            List<Future<Double>> results = pool.invokeAll(drawCallers);
            for (Future<Double> result : results) {
                result.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while drawing loadings in parallel", e);
        } catch (ExecutionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new IllegalStateException("Parallel loadings draw failed", cause);
        }
    }

    @Override
    public void setPathParameter(double beta) {
        this.pathParameter = beta;
    }

    /** Prior precision of entry {@code index} interpolated toward the working prior precision. */
    private double getAdjustedPriorPrecision(int index) {
        return getPriorPrecision(prior, index) * pathParameter + (1.0 - pathParameter) * priorPrecisionWorking;
    }

    /**
     * Runs the operator 1,000,000 times and reports the empirical mean and covariance of the
     * loadings. The original loadings are restored afterwards; the factors are not.
     */
    @Override
    public String getReport() {
        final int repeats = 1000000;
        final int numFactors = adaptor.getNumberOfFactors();
        final int numTraits = adaptor.getNumberOfTraits();
        final int dim = numTraits * numFactors;

        double[] mean = new double[dim];
        double[][] covariance = new double[dim][dim];
        double[] originalLoadings = new double[dim];
        for (int k = 0; k < dim; ++k) {
            originalLoadings[k] = adaptor.getLoadingsValue(k);
        }

        for (int r = 0; r < repeats; ++r) {
            doOperation();
            for (int a = 0; a < dim; ++a) {
                final double valueA = adaptor.getLoadingsValue(a);
                mean[a] += valueA;
                for (int b = a; b < dim; ++b) {
                    covariance[a][b] += valueA * adaptor.getLoadingsValue(b);
                }
            }
            adaptor.fireLoadingsChanged();
        }
        restoreLoadings(originalLoadings);
        adaptor.fireLoadingsChanged();

        for (int a = 0; a < dim; ++a) {
            mean[a] /= repeats;
        }
        for (int a = 0; a < dim; ++a) {
            for (int b = a; b < dim; ++b) {
                covariance[a][b] = covariance[a][b] / repeats - mean[a] * mean[b];
                covariance[b][a] = covariance[a][b];
            }
        }

        StringBuilder sb = new StringBuilder();
        sb.append(getOperatorName() + "Report:\n");
        sb.append("Loadings mean:\n");
        sb.append(new Vector(mean));
        sb.append("\n\n");
        sb.append("Loadings covariance:\n");
        sb.append(new Matrix(covariance));
        sb.append("\n\n");
        return sb.toString();
    }

    /** Writes a flat loadings vector (index = factor * numTraits + trait) back into the adaptor. */
    private void restoreLoadings(double[] loadings) {
        final int numTraits = adaptor.getNumberOfTraits();
        final int numFactors = adaptor.getNumberOfFactors();
        for (int trait = 0; trait < numTraits; ++trait) {
            // Fresh array per trait so the adaptor never sees a buffer that is later overwritten.
            double[] row = new double[numFactors];
            for (int factor = 0; factor < numFactors; ++factor) {
                row[factor] = loadings[factor * numTraits + trait];
            }
            adaptor.setLoadingsForTraitQuietly(trait, row);
        }
    }

    /** Chooses how many leading loadings of each trait are resampled and how scratch storage is shared. */
    public static enum ColumnDimProvider {
        NONE("none") {
            @Override
            int getColumnDim(int trait, int numFactors) {
                return numFactors;
            }

            @Override
            int getArrayIndex(int trait, int numFactors) {
                return 0;
            }

            @Override
            void allocateStorage(ArrayList<double[][]> precisions, ArrayList<double[]> midMeans,
                                 ArrayList<double[]> means, int numFactors) {
                precisions.add(new double[numFactors][numFactors]);
                midMeans.add(new double[numFactors]);
                means.add(new double[numFactors]);
            }
        },
        UPPER_TRIANGULAR("upperTriangular") {
            @Override
            int getColumnDim(int trait, int numFactors) {
                return Math.min(trait + 1, numFactors);
            }

            @Override
            int getArrayIndex(int trait, int numFactors) {
                return Math.min(trait, numFactors - 1);
            }

            @Override
            void allocateStorage(ArrayList<double[][]> precisions, ArrayList<double[]> midMeans,
                                 ArrayList<double[]> means, int numFactors) {
                for (int dim = 1; dim <= numFactors; ++dim) {
                    precisions.add(new double[dim][dim]);
                    midMeans.add(new double[dim]);
                    means.add(new double[dim]);
                }
            }
        },
        HYBRID("hybrid") {
            @Override
            int getColumnDim(int trait, int numFactors) {
                return trait == 0 ? 1 : numFactors;
            }

            @Override
            int getArrayIndex(int trait, int numFactors) {
                return trait == 0 ? 0 : 1;
            }

            @Override
            void allocateStorage(ArrayList<double[][]> precisions, ArrayList<double[]> midMeans,
                                 ArrayList<double[]> means, int numFactors) {
                precisions.add(new double[1][1]);
                midMeans.add(new double[1]);
                means.add(new double[1]);
                precisions.add(new double[numFactors][numFactors]);
                midMeans.add(new double[numFactors]);
                means.add(new double[numFactors]);
            }
        };

        private final String name;

        /** Number of leading loadings resampled for {@code trait}. */
        abstract int getColumnDim(int trait, int numFactors);

        /** Index of the scratch arrays (from {@link #allocateStorage}) used for {@code trait}. */
        abstract int getArrayIndex(int trait, int numFactors);

        /** Appends scratch arrays sized so that {@link #getArrayIndex} selects matching dimensions. */
        abstract void allocateStorage(ArrayList<double[][]> precisions, ArrayList<double[]> midMeans,
                                      ArrayList<double[]> means, int numFactors);

        private ColumnDimProvider(String name) {
            this.name = name;
        }

        public String getName() {
            return name;
        }

        /**
         * Case-insensitive, locale-independent lookup by name.
         *
         * @throws IllegalArgumentException if no provider matches
         */
        public static ColumnDimProvider parse(String string) {
            for (ColumnDimProvider provider : ColumnDimProvider.values()) {
                if (provider.getName().equalsIgnoreCase(string)) {
                    return provider;
                }
            }
            throw new IllegalArgumentException("Unknown dimension provider type");
        }
    }

    /** Post-draw constraint applied to the loadings. */
    public static enum ConstrainedSampler {
        NONE("none") {
            @Override
            void applyConstraint(FactorAnalysisOperatorAdaptor adaptor) {
            }
        },
        REFLECTION("reflection") {
            @Override
            void applyConstraint(FactorAnalysisOperatorAdaptor adaptor) {
                for (int factor = 0; factor < adaptor.getNumberOfFactors(); ++factor) {
                    adaptor.reflectLoadingsForFactor(factor);
                }
            }
        };

        private final String name;

        private ConstrainedSampler(String name) {
            this.name = name;
        }

        public String getName() {
            return name;
        }

        /**
         * Case-insensitive, locale-independent lookup by name.
         *
         * @throws IllegalArgumentException if no sampler matches
         */
        public static ConstrainedSampler parse(String string) {
            for (ConstrainedSampler sampler : ConstrainedSampler.values()) {
                if (sampler.getName().equalsIgnoreCase(string)) {
                    return sampler;
                }
            }
            throw new IllegalArgumentException("Unknown sampler type");
        }

        abstract void applyConstraint(FactorAnalysisOperatorAdaptor adaptor);
    }

    /** Pool task that draws one trait's loadings using its own scratch arrays. */
    class DrawCaller implements Callable<Double> {
        final int i;
        final double[][] precision;
        final double[] midMean;
        final double[] mean;

        DrawCaller(int trait, double[][] precision, double[] midMean, double[] mean) {
            this.i = trait;
            this.precision = precision;
            this.midMean = midMean;
            this.mean = mean;
        }

        @Override
        public Double call() throws Exception {
            NewLoadingsGibbsOperator.this.drawI(i, precision, midMean, mean);
            return null;
        }
    }

    /** Whether to cache factor inner products per missing-data pattern within one operation. */
    public static enum CacheProvider {
        USE_CACHE {
            @Override
            boolean useCache() {
                return true;
            }
        },
        NO_CACHE {
            @Override
            boolean useCache() {
                return false;
            }
        };

        abstract boolean useCache();
    }
}

