package edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer;

import edu.cmu.sphinx.frontend.FloatData;
import edu.cmu.sphinx.linguist.acoustic.HMM;
import edu.cmu.sphinx.linguist.acoustic.HMMState;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.GaussianMixture;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.GaussianWeights;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.HMMManager;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Loader;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.MixtureComponent;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Pool;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.Senone;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMM;
import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMMState;
import edu.cmu.sphinx.util.LogMath;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;

/* loaded from: input_file:edu/cmu/sphinx/linguist/acoustic/tiedstate/trainer/HMMPoolManager.class */
/**
 * Owns the acoustic-model parameter pools handed over by a {@link Loader} (means, variances,
 * mixture weights, transition matrices, senones). It also keeps a parallel set of {@link Buffer}s
 * in which training statistics are accumulated.
 *
 * <p>Intended cycle, inferred from the method set (confirm with the caller):
 * <ol>
 *   <li>call {@link #accumulate} once per score;</li>
 *   <li>{@link #normalize()} turns the accumulated statistics into estimates;</li>
 *   <li>{@link #update()} copies the estimates back into the model pools;</li>
 *   <li>{@link #resetBuffers()} starts a fresh pass.</li>
 * </ol>
 */
class HMMPoolManager {

    /** Senone-id sentinel that makes every accumulator visit all senones / all HMMs. */
    private static final int ALL_MODELS = -1;

    /** Log-domain zero; marks impossible transitions in the transition matrices. */
    private static final float LOG_ZERO = -Float.MAX_VALUE;

    private static final Logger logger =
            Logger.getLogger("edu.cmu.sphinx.linguist.acoustic.HMMPoolManager");

    private final HMMManager hmmManager;

    /**
     * Maps each parameter array (mean, variance, transition matrix) to its index in its pool.
     * Arrays do not override equals/hashCode, so this lookup is by object identity.
     */
    private final HashMap<Object, Integer> indexMap;

    private final Pool<float[]> meansPool;
    private final Pool<float[]> variancePool;
    private final Pool<float[][]> matrixPool;
    private final GaussianWeights mixtureWeights;
    private final Pool<Senone> senonePool;
    private final LogMath logMath;

    private Pool<Buffer> meansBufferPool;
    private Pool<Buffer> varianceBufferPool;
    private Pool<Buffer[]> matrixBufferPool;
    private Pool<Buffer> mixtureWeightsBufferPool;

    // Floors were previously never assigned, so they silently defaulted to 0.0f. For the two
    // log-domain floors that means log(1): every surviving probability was raised to 1 and the
    // rows degenerated to uniform. Log-zero disables flooring instead.
    // NOTE(review): these should probably come from the trainer configuration.
    private final float logMixtureWeightFloor = LOG_ZERO;
    private final float logTransitionProbabilityFloor = LOG_ZERO;
    // NOTE(review): a floor of 0 only clamps negative variances; a small positive floor is the
    // usual choice to keep Gaussians non-degenerate — confirm the intended value.
    private final float varianceFloor = 0.0f;

    private float logLikelihood;
    private float currentLogLikelihood;

    /**
     * Loads the model through {@code loader} and allocates empty accumulation buffers.
     *
     * @param loader the loader that supplies the parameter pools; {@code load()} is called here
     * @throws IOException if loading the model fails
     */
    public HMMPoolManager(Loader loader) throws IOException {
        loader.load();
        this.hmmManager = loader.getHMMManager();
        this.indexMap = new HashMap<>();
        this.meansPool = loader.getMeansPool();
        this.variancePool = loader.getVariancePool();
        this.mixtureWeights = loader.getMixtureWeights();
        this.matrixPool = loader.getTransitionMatrixPool();
        this.senonePool = loader.getSenonePool();
        this.logMath = LogMath.getLogMath();
        createBuffers();
        this.logLikelihood = 0.0f;
    }

    /** Discards all accumulated statistics and resets the running log likelihood. */
    public void resetBuffers() {
        createBuffers();
        logLikelihood = 0.0f;
    }

    /** (Re)creates one empty accumulation buffer per model parameter. */
    protected void createBuffers() {
        meansBufferPool = create1DPoolBuffer(meansPool, false);
        varianceBufferPool = create1DPoolBuffer(variancePool, false);
        matrixBufferPool = create2DPoolBuffer(matrixPool, true);
        mixtureWeightsBufferPool = createWeightsPoolBuffer(mixtureWeights);
    }

    /**
     * Builds one buffer per vector in {@code pool} and records each vector's index in
     * {@link #indexMap}.
     *
     * @param logDomain forwarded to the Buffer constructor. Log-domain pools pass {@code true};
     *     presumably this selects log-domain storage — confirm against Buffer.
     */
    private Pool<Buffer> create1DPoolBuffer(Pool<float[]> pool, boolean logDomain) {
        Pool<Buffer> buffers = new Pool<>(pool.getName());
        for (int i = 0; i < pool.size(); i++) {
            float[] vector = pool.get(i);
            indexMap.put(vector, i);
            buffers.put(i, new Buffer(vector.length, logDomain, i));
        }
        return buffers;
    }

    /**
     * Builds one buffer per (stream, state) pair. Each buffer has one slot per Gaussian and is
     * stored at index {@code stream * statesNum + state}.
     */
    private Pool<Buffer> createWeightsPoolBuffer(GaussianWeights weights) {
        Pool<Buffer> buffers = new Pool<>(weights.getName());
        int statesNum = weights.getStatesNum();
        int streamsNum = weights.getStreamsNum();
        int gauPerState = weights.getGauPerState();
        for (int stream = 0; stream < streamsNum; stream++) {
            for (int state = 0; state < statesNum; state++) {
                int id = stream * statesNum + state;
                buffers.put(id, new Buffer(gauPerState, true, id));
            }
        }
        return buffers;
    }

    /**
     * Builds one buffer array per matrix in {@code pool}, with one buffer per matrix row. It also
     * records each matrix's index in {@link #indexMap}.
     */
    private Pool<Buffer[]> create2DPoolBuffer(Pool<float[][]> pool, boolean logDomain) {
        Pool<Buffer[]> buffers = new Pool<>(pool.getName());
        for (int i = 0; i < pool.size(); i++) {
            float[][] matrix = pool.get(i);
            indexMap.put(matrix, i);
            Buffer[] rows = new Buffer[matrix.length];
            for (int row = 0; row < matrix.length; row++) {
                rows[row] = new Buffer(matrix[row].length, logDomain, row);
            }
            buffers.put(i, rows);
        }
        return buffers;
    }

    /**
     * Returns the pool index that {@link #createBuffers()} recorded for a parameter array.
     *
     * @throws IllegalStateException if the array is not part of any managed pool
     */
    private int bufferIndexOf(Object parameterArray) {
        Integer index = indexMap.get(parameterArray);
        if (index == null) {
            throw new IllegalStateException("Parameter array is not part of any managed pool");
        }
        return index;
    }

    /**
     * Accumulates statistics for {@code score[index]} without transition statistics between
     * states. Equivalent to {@code accumulate(index, score, null)}.
     */
    public void accumulate(int index, TrainerScore[] score) {
        accumulate(index, score, null);
    }

    /**
     * Accumulates statistics for {@code score[index]}.
     *
     * <p>If the score has no state, it is only used when its senone id is the "all models"
     * sentinel. In that case means, variances, mixture weights and transitions of every model
     * are accumulated. For an emitting state, only mixture weights and transitions are
     * accumulated; non-emitting states contribute nothing.
     *
     * @param index position of the current score in {@code score}
     * @param score scores of the current frame
     * @param nextScore scores of the next frame; may be {@code null}, in which case
     *     state-to-state transitions are not accumulated
     * @throws IllegalStateException if an emitting state's senone is not in the senone pool
     */
    public void accumulate(int index, TrainerScore[] score, TrainerScore[] nextScore) {
        TrainerScore thisScore = score[index];
        // The per-frame log likelihood is taken as zero. The running total is the negated sum
        // of the scaling factors reported by score[0] on each call.
        currentLogLikelihood = 0.0f;
        logLikelihood -= score[0].getScalingFactor();

        SenoneHMMState state = (SenoneHMMState) thisScore.getState();
        if (state == null) {
            if (thisScore.getSenoneID() == ALL_MODELS) {
                accumulateMean(ALL_MODELS, thisScore);
                accumulateVariance(ALL_MODELS, thisScore);
                accumulateMixture(ALL_MODELS, thisScore);
                accumulateTransition(ALL_MODELS, index, score, nextScore);
            }
            return;
        }
        if (!state.isEmitting()) {
            return;
        }
        int senoneId = senonePool.indexOf(state.getSenone());
        // A miss must not fall through: -1 is also the ALL_MODELS sentinel and would spread
        // this frame's statistics across every senone.
        if (senoneId < 0) {
            throw new IllegalStateException("Senone of state " + state + " is not in the senone pool");
        }
        accumulateMixture(senoneId, thisScore);
        accumulateTransition(senoneId, index, score, nextScore);
    }

    /** Adds gamma-weighted feature vectors to the mean buffers of {@code senone}'s components. */
    private void accumulateMean(int senone, TrainerScore score) {
        if (senone == ALL_MODELS) {
            for (int i = 0; i < senonePool.size(); i++) {
                accumulateMean(i, score);
            }
            return;
        }
        MixtureComponent[] components =
                ((GaussianMixture) senonePool.get(senone)).getMixtureComponents();
        float[] feature = ((FloatData) score.getData()).getValues();
        float[] componentGamma = score.getComponentGamma();
        for (int c = 0; c < components.length; c++) {
            int meanIndex = bufferIndexOf(components[c].getMean());
            assert meanIndex >= 0;
            assert meanIndex == senone;
            Buffer buffer = meansBufferPool.get(meanIndex);
            double weight = logMath.logToLinear(componentGamma[c] - currentLogLikelihood);
            double[] weighted = new double[feature.length];
            for (int d = 0; d < feature.length; d++) {
                weighted[d] = feature[d] * weight;
            }
            buffer.accumulate(weighted, weight);
        }
    }

    /**
     * Adds gamma-weighted squared deviations from the current means to the variance buffers of
     * {@code senone}'s components.
     */
    private void accumulateVariance(int senone, TrainerScore score) {
        if (senone == ALL_MODELS) {
            for (int i = 0; i < senonePool.size(); i++) {
                accumulateVariance(i, score);
            }
            return;
        }
        MixtureComponent[] components =
                ((GaussianMixture) senonePool.get(senone)).getMixtureComponents();
        float[] feature = ((FloatData) score.getData()).getValues();
        float[] componentGamma = score.getComponentGamma();
        for (int c = 0; c < components.length; c++) {
            float[] mean = components[c].getMean();
            Buffer buffer = varianceBufferPool.get(bufferIndexOf(components[c].getVariance()));
            double weight = logMath.logToLinear(componentGamma[c] - currentLogLikelihood);
            double[] weighted = new double[feature.length];
            for (int d = 0; d < feature.length; d++) {
                double diff = feature[d] - mean[d];
                weighted[d] = diff * diff * weight;
            }
            buffer.accumulate(weighted, weight);
        }
    }

    /**
     * Log-accumulates per-component gammas into the mixture-weight buffer at index
     * {@code senone}.
     * NOTE(review): the weight buffers are indexed {@code stream * statesNum + state}, so this
     * only addresses stream 0 (assuming senone indices are below statesNum).
     */
    private void accumulateMixture(int senone, TrainerScore score) {
        if (senone == ALL_MODELS) {
            for (int i = 0; i < senonePool.size(); i++) {
                accumulateMixture(i, score);
            }
            return;
        }
        Buffer buffer = mixtureWeightsBufferPool.get(senone);
        float[] componentGamma = score.getComponentGamma();
        int gaussians = mixtureWeights.getGauPerState();
        for (int c = 0; c < gaussians; c++) {
            buffer.logAccumulate(componentGamma[c] - currentLogLikelihood, c, logMath);
        }
    }

    /**
     * Accumulates the transitions leaving the state of {@code score[index]}.
     *
     * <p>Each possible transition to state {@code to} contributes
     * alpha(current) + beta(next) + logP(transition) + score(next), minus the current log
     * likelihood. The "next" score is the one at {@code index + (to - currentState)} in
     * {@code nextScore}.
     */
    private void accumulateStateTransition(int index, TrainerScore[] score, TrainerScore[] nextScore) {
        HMMState state = score[index].getState();
        if (state == null) {
            return;
        }
        int stateId = state.getState();
        SenoneHMM hmm = (SenoneHMM) state.getHMM();
        float[][] matrix = hmm.getTransitionMatrix();
        Buffer[] buffers = matrixBufferPool.get(bufferIndexOf(matrix));
        float[] row = matrix[stateId];
        float alpha = score[index].getAlpha();
        for (int to = 0; to < row.length; to++) {
            if (row[to] == LOG_ZERO) {
                continue;
            }
            TrainerScore next = nextScore[index + (to - stateId)];
            assert next.getState() == null || next.getState().getHMM() == hmm;
            float logProb = alpha + next.getBeta() + row[to] + next.getScore() - currentLogLikelihood;
            buffers[stateId].logAccumulate(logProb, to, logMath);
        }
    }

    /** Adds {@code logScore} to every possible transition leaving {@code stateId} of {@code hmm}. */
    private void accumulateStateTransition(int stateId, SenoneHMM hmm, float logScore) {
        float[][] matrix = hmm.getTransitionMatrix();
        float[] row = matrix[stateId];
        Buffer[] buffers = matrixBufferPool.get(bufferIndexOf(matrix));
        for (int to = 0; to < row.length; to++) {
            if (row[to] != LOG_ZERO) {
                buffers[stateId].logAccumulate(logScore, to, logMath);
            }
        }
    }

    /**
     * Accumulates transition statistics.
     *
     * <p>For the "all models" sentinel, every state of every HMM is credited with
     * {@code score[index]}'s score. Otherwise, state-to-state transitions are accumulated only
     * when {@code nextScore} is available.
     */
    private void accumulateTransition(int senone, int index, TrainerScore[] score, TrainerScore[] nextScore) {
        if (senone != ALL_MODELS) {
            if (nextScore != null) {
                accumulateStateTransition(index, score, nextScore);
            }
            return;
        }
        float logScore = score[index].getScore();
        Iterator<HMM> hmms = hmmManager.iterator();
        while (hmms.hasNext()) {
            HMM hmm = hmms.next();
            for (int state = 0; state < hmm.getOrder(); state++) {
                accumulateStateTransition(state, (SenoneHMM) hmm, logScore);
            }
        }
    }

    /** Currently a no-op; the log likelihood is accumulated inside {@link #accumulate}. */
    public void updateLogLikelihood() {
    }

    /**
     * Normalizes every used accumulation buffer.
     *
     * @return the log likelihood accumulated since the last reset
     */
    public float normalize() {
        normalizePool(meansBufferPool);
        normalizePool(varianceBufferPool);
        logNormalizePool(mixtureWeightsBufferPool);
        logNormalize2DPool(matrixBufferPool, matrixPool);
        return logLikelihood;
    }

    /** Linear-domain normalization of each used buffer in {@code pool}. */
    private void normalizePool(Pool<Buffer> pool) {
        assert pool != null;
        for (int i = 0; i < pool.size(); i++) {
            Buffer buffer = pool.get(i);
            if (buffer.wasUsed()) {
                buffer.normalize();
            }
        }
    }

    /** Log-domain normalization of each used buffer in {@code pool}. */
    private void logNormalizePool(Pool<Buffer> pool) {
        assert pool != null;
        for (int i = 0; i < pool.size(); i++) {
            Buffer buffer = pool.get(i);
            if (buffer.wasUsed()) {
                buffer.logNormalize();
            }
        }
    }

    /**
     * Log-normalizes each used row buffer of each matrix. The matching row of the model matrix in
     * {@code matrices} is passed to {@code logNormalizeNonZero}.
     */
    private void logNormalize2DPool(Pool<Buffer[]> pool, Pool<float[][]> matrices) {
        assert pool != null;
        for (int i = 0; i < pool.size(); i++) {
            Buffer[] rows = pool.get(i);
            float[][] matrix = matrices.get(i);
            for (int row = 0; row < rows.length; row++) {
                if (rows[row].wasUsed()) {
                    rows[row].logNormalizeNonZero(matrix[row]);
                }
            }
        }
    }

    /**
     * Copies the normalized estimates back into the model pools.
     *
     * <p>Order matters: variances read the means written by {@link #updateMeans()}, and the
     * mixture components are recomputed after both.
     */
    public void update() {
        updateMeans();
        updateVariances();
        recomputeMixtureComponents();
        updateMixtureWeights();
        updateTransitionMatrices();
    }

    /** Copies {@code src} into {@code dst}; both must have the same length. */
    private void copyVector(float[] src, float[] dst) {
        assert src.length == dst.length;
        System.arraycopy(src, 0, dst, 0, src.length);
    }

    /** Overwrites each mean with its buffer's estimate and logs means whose buffer was unused. */
    private void updateMeans() {
        assert meansPool.size() == meansBufferPool.size();
        for (int i = 0; i < meansPool.size(); i++) {
            float[] mean = meansPool.get(i);
            Buffer buffer = meansBufferPool.get(i);
            if (buffer.wasUsed()) {
                copyVector(buffer.getValues(), mean);
            } else {
                logger.info("Senone " + i + " not used.");
            }
        }
    }

    /**
     * Overwrites each used variance with (buffer estimate - mean^2), clamped at
     * {@link #varianceFloor}.
     *
     * <p>NOTE(review): accumulateVariance collects squared deviations from the mean that was
     * current during accumulation, but by the time this runs updateMeans has already replaced
     * that mean. Subtracting the squared new mean from those deviations looks inconsistent —
     * confirm the intended estimator.
     */
    private void updateVariances() {
        assert variancePool.size() == varianceBufferPool.size();
        for (int i = 0; i < variancePool.size(); i++) {
            float[] mean = meansPool.get(i);
            float[] variance = variancePool.get(i);
            Buffer buffer = varianceBufferPool.get(i);
            if (!buffer.wasUsed()) {
                continue;
            }
            float[] values = buffer.getValues();
            assert mean.length == values.length;
            for (int d = 0; d < mean.length; d++) {
                values[d] -= mean[d] * mean[d];
                if (values[d] < varianceFloor) {
                    values[d] = varianceFloor;
                }
            }
            copyVector(values, variance);
        }
    }

    /** Calls precomputeDistance on every mixture component of every senone after the update. */
    private void recomputeMixtureComponents() {
        for (int i = 0; i < senonePool.size(); i++) {
            for (MixtureComponent component : ((GaussianMixture) senonePool.get(i)).getMixtureComponents()) {
                component.precomputeDistance();
            }
        }
    }

    /**
     * Writes every used mixture-weight buffer back into {@link #mixtureWeights}. If
     * {@code logFloor} reports a change, the buffer is renormalized first.
     */
    private void updateMixtureWeights() {
        int statesNum = mixtureWeights.getStatesNum();
        int streamsNum = mixtureWeights.getStreamsNum();
        assert statesNum * streamsNum == mixtureWeightsBufferPool.size();
        for (int stream = 0; stream < streamsNum; stream++) {
            for (int state = 0; state < statesNum; state++) {
                Buffer buffer = mixtureWeightsBufferPool.get(stream * statesNum + state);
                if (!buffer.wasUsed()) {
                    continue;
                }
                if (buffer.logFloor(logMixtureWeightFloor)) {
                    buffer.logNormalizeToSum(logMath);
                }
                mixtureWeights.put(state, stream, buffer.getValues());
            }
        }
    }

    /**
     * Processes every used row buffer of the transition matrices:
     * <ol>
     *   <li>raise non-zero log probabilities below {@link #logTransitionProbabilityFloor} to the
     *   floor;</li>
     *   <li>renormalize the row;</li>
     *   <li>copy it into the model matrix.</li>
     * </ol>
     */
    private void updateTransitionMatrices() {
        assert matrixPool.size() == matrixBufferPool.size();
        for (int i = 0; i < matrixPool.size(); i++) {
            float[][] matrix = matrixPool.get(i);
            Buffer[] rows = matrixBufferPool.get(i);
            for (int row = 0; row < matrix.length; row++) {
                Buffer buffer = rows[row];
                if (!buffer.wasUsed()) {
                    continue;
                }
                for (int col = 0; col < matrix[row].length; col++) {
                    float value = buffer.getValue(col);
                    if (value == LOG_ZERO) {
                        continue;
                    }
                    assert matrix[row][col] != LOG_ZERO;
                    if (value < logTransitionProbabilityFloor) {
                        buffer.setValue(col, logTransitionProbabilityFloor);
                    }
                }
                buffer.logNormalizeToSum(logMath);
                copyVector(buffer.getValues(), matrix[row]);
            }
        }
    }
}
