/*
 * Decompiled with CFR 0.152.
 */
package org.jetbrains.completion.full.line.local.generation.search;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TuplesKt;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.SourceDebugExtension;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.completion.full.line.local.generation.LapTimer;
import org.jetbrains.completion.full.line.local.generation.UtilsKt;
import org.jetbrains.completion.full.line.local.generation.generation.SearchState;
import org.jetbrains.completion.full.line.local.generation.search.Search;

@Metadata(mv={2, 2, 0}, k=1, xi=48, d1={"\u00008\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0011\n\u0002\u0010\u0013\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0006\n\u0002\b\u0003\u0018\u00002\u00020\u0001B\u0017\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003\u00a2\u0006\u0004\b\u0005\u0010\u0006J+\u0010\u0007\u001a\u00020\b2\f\u0010\t\u001a\b\u0012\u0004\u0012\u00020\u000b0\n2\u0006\u0010\f\u001a\u00020\r2\u0006\u0010\u000e\u001a\u00020\u000fH\u0016\u00a2\u0006\u0002\u0010\u0010J!\u0010\u0011\u001a\b\u0012\u0004\u0012\u00020\u00120\n2\f\u0010\u0013\u001a\b\u0012\u0004\u0012\u00020\u000b0\nH\u0002\u00a2\u0006\u0002\u0010\u0014R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004\u00a2\u0006\u0002\n\u0000R\u000e\u0010\u0004\u001a\u00020\u0003X\u0082\u0004\u00a2\u0006\u0002\n\u0000\u00a8\u0006\u0015"}, d2={"Lorg/jetbrains/completion/full/line/local/generation/search/FullLineBeamSearch;", "Lorg/jetbrains/completion/full/line/local/generation/search/Search;", "searchSize", "", "vocabSize", "<init>", "(II)V", "step", "Lorg/jetbrains/completion/full/line/local/generation/search/Search$StepResult;", "stepLogProbs", "", "", "searchState", "Lorg/jetbrains/completion/full/line/local/generation/generation/SearchState;", "timer", "Lorg/jetbrains/completion/full/line/local/generation/LapTimer;", "([[DLorg/jetbrains/completion/full/line/local/generation/generation/SearchState;Lorg/jetbrains/completion/full/line/local/generation/LapTimer;)Lorg/jetbrains/completion/full/line/local/generation/search/Search$StepResult;", "entropy", "", "logProbs", "([[D)[Ljava/lang/Double;", "intellij.fullLine.local"})
@SourceDebugExtension(value={"SMAP\nFullLineBeamSearch.kt\nKotlin\n*S Kotlin\n*F\n+ 1 FullLineBeamSearch.kt\norg/jetbrains/completion/full/line/local/generation/search/FullLineBeamSearch\n+ 2 fake.kt\nkotlin/jvm/internal/FakeKt\n+ 3 _Arrays.kt\nkotlin/collections/ArraysKt___ArraysKt\n+ 4 ArraysJVM.kt\nkotlin/collections/ArraysKt__ArraysJVMKt\n*L\n1#1,58:1\n1#2:59\n12851#3,3:60\n11228#3:63\n11563#3,3:64\n11228#3:71\n11563#3,3:72\n37#4:67\n36#4,3:68\n37#4:75\n36#4,3:76\n*S KotlinDebug\n*F\n+ 1 FullLineBeamSearch.kt\norg/jetbrains/completion/full/line/local/generation/search/FullLineBeamSearch\n*L\n23#1:60,3\n41#1:63\n41#1:64,3\n55#1:71\n55#1:72,3\n41#1:67\n41#1:68,3\n55#1:75\n55#1:76,3\n*E\n"})
public final class FullLineBeamSearch
implements Search {
    private final int searchSize;
    private final int vocabSize;

    public FullLineBeamSearch(int searchSize, int vocabSize) {
        this.searchSize = searchSize;
        this.vocabSize = vocabSize;
    }

    /*
     * WARNING - void declaration
     */
    @Override
    @NotNull
    public Search.StepResult step(@NotNull double[][] stepLogProbs, @NotNull SearchState searchState, @NotNull LapTimer timer) {
        Pair pair;
        int n;
        Intrinsics.checkNotNullParameter((Object)stepLogProbs, (String)"stepLogProbs");
        Intrinsics.checkNotNullParameter((Object)searchState, (String)"searchState");
        Intrinsics.checkNotNullParameter((Object)timer, (String)"timer");
        Object[] objectArray = (Object[])stepLogProbs;
        int n2 = 0;
        for (Object object : objectArray) {
            void it;
            double[] dArray = (double[])object;
            int n3 = n2;
            n = 0;
            int n4 = ((void)it).length;
            n2 = n3 + n4;
        }
        int logProbsLinearSize = n2;
        double[] newLogProbs = new double[logProbsLinearSize];
        int offset = 0;
        int n5 = ((Object[])stepLogProbs).length;
        for (int i = 0; i < n5; ++i) {
            double[] probs = stepLogProbs[i];
            double score = ((Number)searchState.getHypothesesScores().get(i)).doubleValue();
            for (double value22 : probs) {
                newLogProbs[offset] = value22 + score;
                ++offset;
            }
        }
        timer.lap("BeamSearch step: add cur logProbs to past");
        double[] $this$count$iv = newLogProbs;
        boolean $i$f$count = false;
        int count$iv = 0;
        int n6 = $this$count$iv.length;
        for (n = 0; n < n6; ++n) {
            double element$iv;
            double it = element$iv = $this$count$iv[n];
            boolean bl = false;
            if (!(it > Double.NEGATIVE_INFINITY)) continue;
            ++count$iv;
        }
        int notMaskedCount = count$iv;
        int newNumSamples = Math.min(notMaskedCount, Math.min(this.searchSize, logProbsLinearSize));
        timer.lap("BeamSearch step: count numSamples");
        int[] samples = UtilsKt.topk1d(newLogProbs, newNumSamples);
        timer.lap("BeamSearch step: topk1d");
        double[] sampleScores = UtilsKt.sliceArray(newLogProbs, samples);
        timer.lap("BeamSearch step: slice scores");
        n6 = 0;
        int n7 = samples.length;
        int[] value22 = new int[n7];
        while (n6 < n7) {
            int it = n6++;
            value22[it] = Math.floorDiv(samples[it], this.vocabSize);
        }
        int[] stepMask = value22;
        n6 = 0;
        n7 = samples.length;
        value22 = new int[n7];
        while (n6 < n7) {
            int it = n6++;
            value22[it] = Math.floorMod(samples[it], this.vocabSize);
        }
        samples = value22;
        n7 = 0;
        int value22 = stepMask.length;
        double[] it = new double[value22];
        while (n7 < value22) {
            int n8 = n7++;
            it[n8] = stepLogProbs[stepMask[n8]][samples[n8]];
        }
        double[] probabilities = it;
        if (notMaskedCount != newLogProbs.length) {
            void $this$toTypedArray$iv2;
            void $this$mapTo$iv$iv;
            Object $this$map$iv = (Object[])stepLogProbs;
            boolean $i$f$map = false;
            Object[] bl = $this$map$iv;
            Collection destination$iv$iv = new ArrayList(((Object[])$this$map$iv).length);
            boolean $i$f$mapTo = false;
            for (void item$iv$iv : $this$mapTo$iv$iv) {
                void it2;
                double[] dArray = (double[])item$iv$iv;
                Collection collection = destination$iv$iv;
                boolean bl2 = false;
                void v0 = it2;
                double[] dArray2 = Arrays.copyOf((double[])v0, ((void)v0).length);
                Intrinsics.checkNotNullExpressionValue((Object)dArray2, (String)"copyOf(...)");
                collection.add(dArray2);
            }
            $this$map$iv = (List)destination$iv$iv;
            int $i$f$toTypedArray = 0;
            void thisCollection$iv = $this$toTypedArray$iv2;
            double[][] stepLogProbsCopy = (double[][])thisCollection$iv.toArray((T[])new double[0][]);
            UtilsKt.logSoftmax(stepLogProbsCopy, true);
            $this$toTypedArray$iv2 = 0;
            $i$f$toTypedArray = stepMask.length;
            double[] dArray = new double[$i$f$toTypedArray];
            while ($this$toTypedArray$iv2 < $i$f$toTypedArray) {
                int n9 = $this$toTypedArray$iv2++;
                dArray[n9] = stepLogProbsCopy[stepMask[n9]][samples[n9]];
            }
            double[] dArray3 = dArray;
            $this$toTypedArray$iv2 = 0;
            $i$f$toTypedArray = stepMask.length;
            dArray = new double[$i$f$toTypedArray];
            var25_21 = dArray3;
            while ($this$toTypedArray$iv2 < $i$f$toTypedArray) {
                int n10 = $this$toTypedArray$iv2++;
                dArray[n10] = this.entropy(stepLogProbsCopy)[stepMask[n10]];
            }
            pair = TuplesKt.to((Object)var25_21, (Object)dArray);
        } else {
            int stepLogProbsCopy = 0;
            $this$toTypedArray$iv2 = stepMask.length;
            double[] $i$f$toTypedArray = new double[$this$toTypedArray$iv2];
            var25_21 = probabilities;
            while (stepLogProbsCopy < $this$toTypedArray$iv2) {
                int n11 = stepLogProbsCopy++;
                $i$f$toTypedArray[n11] = this.entropy(stepLogProbs)[stepMask[n11]];
            }
            pair = TuplesKt.to((Object)var25_21, (Object)$i$f$toTypedArray);
        }
        Pair pair2 = pair;
        double[] normalizedProbabilities = (double[])pair2.component1();
        double[] entropies = (double[])pair2.component2();
        int n12 = 0;
        int n13 = stepMask.length;
        double[] dArray = new double[n13];
        while (n12 < n13) {
            int n14 = n12++;
            dArray[n14] = normalizedProbabilities[n14] + ((Number)searchState.getHypothesesNormalizedScores().get(stepMask[n14])).doubleValue();
        }
        double[] normalizedScores = dArray;
        timer.lap("BeamSearch step: divmod samples");
        return new Search.StepResult(stepMask, samples, probabilities, sampleScores, normalizedProbabilities, normalizedScores, entropies);
    }

    /*
     * WARNING - void declaration
     */
    private final Double[] entropy(double[][] logProbs) {
        void $this$mapTo$iv$iv;
        Object[] $this$map$iv = (Object[])logProbs;
        boolean $i$f$map = false;
        Object[] objectArray = $this$map$iv;
        Collection destination$iv$iv = new ArrayList($this$map$iv.length);
        boolean $i$f$mapTo = false;
        for (void item$iv$iv : $this$mapTo$iv$iv) {
            void it;
            double[] dArray = (double[])item$iv$iv;
            Collection collection = destination$iv$iv;
            boolean bl = false;
            void var12_12 = it;
            double d = 0.0;
            int n = ((void)var12_12).length;
            for (int i = 0; i < n; ++i) {
                void it2;
                void var17_16;
                void var19_17 = var17_16 = var12_12[i];
                double d2 = d;
                boolean bl2 = false;
                double d3 = !(it2 == Double.NEGATIVE_INFINITY) ? -it2 * UtilsKt.fastExp((double)it2) : 0.0;
                d = d2 + d3;
            }
            collection.add(d);
        }
        Collection $this$toTypedArray$iv = (List)destination$iv$iv;
        boolean $i$f$toTypedArray = false;
        Collection thisCollection$iv = $this$toTypedArray$iv;
        return thisCollection$iv.toArray(new Double[0]);
    }
}

