/*
 * Decompiled with CFR 0.152.
 */
package org.wso2.siddhi.extension.kf;

import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.wso2.siddhi.core.config.ExecutionPlanContext;
import org.wso2.siddhi.core.exception.ExecutionPlanRuntimeException;
import org.wso2.siddhi.core.executor.ExpressionExecutor;
import org.wso2.siddhi.core.executor.function.FunctionExecutor;
import org.wso2.siddhi.query.api.definition.Attribute;
import org.wso2.siddhi.query.api.exception.ExecutionPlanValidationException;

public class KalmanFilter
extends FunctionExecutor {
    private double transition;
    private double measurementNoiseSD;
    private double prevEstimatedValue;
    private double variance;
    private RealMatrix measurementMatrixH = null;
    private RealMatrix varianceMatrixP;
    private RealMatrix prevMeasuredMatrix;
    private long prevTimestamp;
    Attribute.Type returnType = Attribute.Type.DOUBLE;

    public void start() {
    }

    public void stop() {
    }

    public Object[] currentState() {
        return new Object[]{this.transition, this.measurementNoiseSD, this.prevEstimatedValue, this.variance, this.measurementMatrixH, this.varianceMatrixP, this.prevMeasuredMatrix, this.prevTimestamp};
    }

    public void restoreState(Object[] state) {
        this.transition = (Double)state[0];
        this.measurementNoiseSD = (Double)state[1];
        this.prevEstimatedValue = (Double)state[2];
        this.variance = (Double)state[3];
        this.measurementMatrixH = (RealMatrix)state[4];
        this.varianceMatrixP = (RealMatrix)state[5];
        this.prevMeasuredMatrix = (RealMatrix)state[6];
        this.prevTimestamp = (Long)state[7];
    }

    protected void init(ExpressionExecutor[] attributeExpressionExecutors, ExecutionPlanContext executionPlanContext) {
        if (attributeExpressionExecutors.length != 1 && attributeExpressionExecutors.length != 2 && attributeExpressionExecutors.length != 4) {
            throw new ExecutionPlanValidationException("Invalid no of arguments passed to kf:kalmanFilter() function, required 1, 2 or 4, but found " + attributeExpressionExecutors.length);
        }
        if (attributeExpressionExecutors[0].getReturnType() != Attribute.Type.DOUBLE) {
            throw new ExecutionPlanValidationException("Invalid parameter type found for the first argument of kf:kalmanFilter() function, required " + Attribute.Type.DOUBLE + ", but found " + attributeExpressionExecutors[0].getReturnType().toString());
        }
        if ((attributeExpressionExecutors.length == 2 || attributeExpressionExecutors.length == 4) && attributeExpressionExecutors[1].getReturnType() != Attribute.Type.DOUBLE) {
            throw new ExecutionPlanValidationException("Invalid parameter type found for the second argument of kf:kalmanFilter() function, required " + Attribute.Type.DOUBLE + ", but found " + attributeExpressionExecutors[1].getReturnType().toString());
        }
        if (attributeExpressionExecutors.length == 4) {
            if (attributeExpressionExecutors[2].getReturnType() != Attribute.Type.DOUBLE) {
                throw new ExecutionPlanValidationException("Invalid parameter type found for the third argument of kf:kalmanFilter() function, required " + Attribute.Type.DOUBLE + ", but found " + attributeExpressionExecutors[1].getReturnType().toString());
            }
            if (attributeExpressionExecutors[3].getReturnType() != Attribute.Type.LONG) {
                throw new ExecutionPlanValidationException("Invalid parameter type found for the fourth argument of kf:kalmanFilter() function, required " + Attribute.Type.LONG + ", but found " + attributeExpressionExecutors[1].getReturnType().toString());
            }
        }
    }

    protected Object execute(Object[] data) {
        long timestampDiff;
        if (data[0] == null) {
            throw new ExecutionPlanRuntimeException("Invalid input given to kf:kalmanFilter() function. First argument should be a double");
        }
        if (data[1] == null) {
            throw new ExecutionPlanRuntimeException("Invalid input given to kf:kalmanFilter() function. Second argument should be a double");
        }
        if (data.length == 2) {
            double measuredValue = (Double)data[0];
            if (this.prevEstimatedValue == 0.0) {
                this.transition = 1.0;
                this.variance = 1000.0;
                this.measurementNoiseSD = (Double)data[1];
                this.prevEstimatedValue = measuredValue;
            }
            this.prevEstimatedValue = this.transition * this.prevEstimatedValue;
            double kalmanGain = this.variance / (this.variance + this.measurementNoiseSD);
            this.prevEstimatedValue += kalmanGain * (measuredValue - this.prevEstimatedValue);
            this.variance = (1.0 - kalmanGain) * this.variance;
            return this.prevEstimatedValue;
        }
        if (data[2] == null) {
            throw new ExecutionPlanRuntimeException("Invalid input given to kf:kalmanFilter() function. Third argument should be a double");
        }
        if (data[3] == null) {
            throw new ExecutionPlanRuntimeException("Invalid input given to kf:kalmanFilter() function. Fourth argument should be a long");
        }
        double measuredXValue = (Double)data[0];
        double measuredChangingRate = (Double)data[1];
        double measurementNoiseSD = (Double)data[2];
        long timestamp = (Long)data[3];
        double[][] measuredValues = new double[][]{{measuredXValue}, {measuredChangingRate}};
        if (this.measurementMatrixH == null) {
            timestampDiff = 1L;
            double[][] varianceValues = new double[][]{{1000.0, 0.0}, {0.0, 1000.0}};
            double[][] measurementValues = new double[][]{{1.0, 0.0}, {0.0, 1.0}};
            this.measurementMatrixH = MatrixUtils.createRealMatrix((double[][])measurementValues);
            this.varianceMatrixP = MatrixUtils.createRealMatrix((double[][])varianceValues);
            this.prevMeasuredMatrix = MatrixUtils.createRealMatrix((double[][])measuredValues);
        } else {
            timestampDiff = timestamp - this.prevTimestamp;
        }
        double[][] Rvalues = new double[][]{{measurementNoiseSD, 0.0}, {0.0, measurementNoiseSD}};
        RealMatrix rMatrix = MatrixUtils.createRealMatrix((double[][])Rvalues);
        double[][] transitionValues = new double[][]{{1.0, timestampDiff}, {0.0, 1.0}};
        RealMatrix transitionMatrixA = MatrixUtils.createRealMatrix((double[][])transitionValues);
        RealMatrix measuredMatrixX = MatrixUtils.createRealMatrix((double[][])measuredValues);
        this.prevMeasuredMatrix = transitionMatrixA.multiply(this.prevMeasuredMatrix);
        this.varianceMatrixP = transitionMatrixA.multiply(this.varianceMatrixP).multiply(transitionMatrixA.transpose());
        RealMatrix S = this.measurementMatrixH.multiply(this.varianceMatrixP).multiply(this.measurementMatrixH.transpose()).add(rMatrix);
        RealMatrix S_1 = new LUDecomposition(S).getSolver().getInverse();
        RealMatrix kalmanGainMatrix = this.varianceMatrixP.multiply(this.measurementMatrixH.transpose()).multiply(S_1);
        this.prevMeasuredMatrix = this.prevMeasuredMatrix.add(kalmanGainMatrix.multiply(measuredMatrixX.subtract(this.measurementMatrixH.multiply(this.prevMeasuredMatrix))));
        this.varianceMatrixP = this.varianceMatrixP.subtract(kalmanGainMatrix.multiply(this.measurementMatrixH).multiply(this.varianceMatrixP));
        this.prevTimestamp = timestamp;
        return this.prevMeasuredMatrix.getRow(0)[0];
    }

    protected Object execute(Object data) {
        if (data == null) {
            throw new ExecutionPlanRuntimeException("Invalid input given to kf:kalmanFilter() function. Argument should be a double");
        }
        double measuredValue = (Double)data;
        if (this.transition == 0.0) {
            this.transition = 1.0;
            this.variance = 1000.0;
            this.measurementNoiseSD = 0.001;
            this.prevEstimatedValue = measuredValue;
        }
        this.prevEstimatedValue = this.transition * this.prevEstimatedValue;
        double kalmanGain = this.variance / (this.variance + this.measurementNoiseSD);
        this.prevEstimatedValue += kalmanGain * (measuredValue - this.prevEstimatedValue);
        this.variance = (1.0 - kalmanGain) * this.variance;
        return this.prevEstimatedValue;
    }

    public Attribute.Type getReturnType() {
        return this.returnType;
    }
}

