/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.onnxruntime.zoo.tabular.softmax_regression;

import ai.djl.Model;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.onnxruntime.zoo.tabular.softmax_regression.IrisFlower;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Translator factory for the Iris classification model.
 *
 * <p>The factory advertises a single supported type pair, {@link IrisFlower} as input and
 * {@link Classifications} as output, and hands out translators that convert between those
 * types and the {@link NDList}s exchanged with the model.
 */
public class IrisClassificationTranslatorFactory implements TranslatorFactory {

    /**
     * Returns the input/output type pairs this factory can create translators for.
     *
     * @return an immutable singleton set holding the ({@code IrisFlower}, {@code Classifications})
     *     pair
     */
    public Set<Pair<Type, Type>> getSupportedTypes() {
        // Typed explicitly: a raw Pair would force an unchecked conversion to Pair<Type, Type>.
        Pair<Type, Type> supported = new Pair<>(IrisFlower.class, Classifications.class);
        return Collections.singleton(supported);
    }

    /**
     * Creates a translator for the requested input/output types.
     *
     * @param input the requested input class
     * @param output the requested output class
     * @param model the model the translator is created for (not used by this factory)
     * @param arguments user-supplied arguments (not used by this factory)
     * @param <I> the input type
     * @param <O> the output type
     * @return a new translator instance
     * @throws IllegalArgumentException if {@code isSupported(input, output)} rejects the pair
     */
    public <I, O> Translator<I, O> newInstance(
            Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
        if (!isSupported(input, output)) {
            throw new IllegalArgumentException("Unsupported input/output types.");
        }
        // The compiler cannot relate I/O to IrisFlower/Classifications, so the cast is unchecked.
        // It is reached only after isSupported(...) accepted the pair; presumably that check
        // matches against getSupportedTypes() -- confirm in TranslatorFactory.
        @SuppressWarnings("unchecked")
        Translator<I, O> translator = (Translator<I, O>) new IrisTranslator();
        return translator;
    }

    /** Translator converting an {@link IrisFlower} to model input and model output to classes. */
    private static final class IrisTranslator
            implements NoBatchifyTranslator<IrisFlower, Classifications> {

        /**
         * Class names paired with the probabilities in {@link #processOutput}. The order is
         * assumed to follow the model's class indices -- TODO confirm against the model.
         */
        private static final List<String> SYNSET =
                Collections.unmodifiableList(Arrays.asList("setosa", "versicolor", "virginica"));

        /**
         * Packs the four flower measurements into a single {@code (1, 4)} float array.
         *
         * @param ctx the context supplying the {@code NDManager} that creates the array
         * @param input the flower to classify
         * @return a list containing one array with sepal length, sepal width, petal length and
         *     petal width, in that order
         */
        public NDList processInput(TranslatorContext ctx, IrisFlower input) {
            float[] features = {
                input.getSepalLength(),
                input.getSepalWidth(),
                input.getPetalLength(),
                input.getPetalWidth()
            };
            NDArray array = ctx.getNDManager().create(features, new Shape(1L, 4L));
            return new NDList(array);
        }

        /**
         * Converts the model output into per-class probabilities.
         *
         * <p>The values are read from the element at index 1 of {@code list}.
         * NOTE(review): presumably index 0 carries the predicted label and index 1 the
         * probabilities -- confirm against the model's output signature.
         *
         * @param ctx the translator context (not used)
         * @param list the raw model output
         * @return the classifications built from the class names and the widened float values
         */
        public Classifications processOutput(TranslatorContext ctx, NDList list) {
            float[] scores = list.get(1).toFloatArray();
            List<Double> probabilities = new ArrayList<>(scores.length);
            for (float score : scores) {
                probabilities.add((double) score);
            }
            return new Classifications(SYNSET, probabilities);
        }
    }
}

