/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.types.inference;

import java.time.Duration;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.TableSemantics;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.Signature;
import org.apache.flink.table.types.inference.StateTypeStrategy;
import org.apache.flink.table.types.inference.StaticArgument;
import org.apache.flink.table.types.inference.StaticArgumentTrait;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeStrategy;
import org.apache.flink.table.types.inference.utils.CastCallContext;
import org.apache.flink.table.types.inference.utils.UnknownCallContext;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;

@Internal
public final class TypeInferenceUtil {
    public static Result runTypeInference(TypeInference typeInference, CallContext callContext, @Nullable SurroundingInfo surroundingInfo) {
        try {
            return TypeInferenceUtil.runTypeInferenceInternal(typeInference, callContext, surroundingInfo);
        }
        catch (ValidationException e) {
            throw TypeInferenceUtil.createInvalidCallException(callContext, e);
        }
        catch (Throwable t) {
            throw TypeInferenceUtil.createUnexpectedException(callContext, t);
        }
    }

    public static CallContext castArguments(TypeInference typeInference, CallContext callContext, @Nullable DataType outputType) {
        return TypeInferenceUtil.castArguments(typeInference, callContext, outputType, true);
    }

    private static CallContext castArguments(TypeInference typeInference, CallContext callContext, @Nullable DataType outputType, boolean throwOnInferInputFailure) {
        List<DataType> actualTypes = callContext.getArgumentDataTypes();
        typeInference.getStaticArguments().ifPresent(staticArgs -> {
            if (actualTypes.size() != staticArgs.size()) {
                throw new ValidationException(String.format("Invalid number of arguments. %d arguments expected after argument expansion but %d passed.", staticArgs.size(), actualTypes.size()));
            }
        });
        CastCallContext castCallContext = TypeInferenceUtil.inferInputTypes(typeInference, callContext, outputType, throwOnInferInputFailure);
        List<DataType> expectedTypes = castCallContext.getArgumentDataTypes();
        for (int pos = 0; pos < actualTypes.size(); ++pos) {
            DataType expectedType = expectedTypes.get(pos);
            DataType actualType = actualTypes.get(pos);
            if (LogicalTypeCasts.supportsImplicitCast(actualType.getLogicalType(), expectedType.getLogicalType())) continue;
            if (!throwOnInferInputFailure) {
                return callContext;
            }
            throw new ValidationException(String.format("Invalid argument type at position %d. Data type %s expected but %s passed.", pos, expectedType, actualType));
        }
        return castCallContext;
    }

    public static DataType inferOutputType(CallContext callContext, TypeStrategy outputTypeStrategy) {
        Optional<DataType> potentialOutputType = outputTypeStrategy.inferType(callContext);
        if (potentialOutputType.isEmpty()) {
            throw new ValidationException("Could not infer an output type for the given arguments.");
        }
        DataType outputType = potentialOutputType.get();
        if (TypeInferenceUtil.isUnknown(outputType)) {
            throw new ValidationException("Could not infer an output type for the given arguments. Untyped NULL received.");
        }
        return outputType;
    }

    public static LinkedHashMap<String, StateInfo> inferStateInfos(CallContext callContext, LinkedHashMap<String, StateTypeStrategy> stateTypeStrategies) {
        return stateTypeStrategies.entrySet().stream().map(e -> Map.entry((String)e.getKey(), TypeInferenceUtil.inferStateInfo(callContext, (String)e.getKey(), (StateTypeStrategy)e.getValue()))).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (x, y) -> y, LinkedHashMap::new));
    }

    public static String generateSignature(TypeInference typeInference, String name, FunctionDefinition definition) {
        List staticArguments = typeInference.getStaticArguments().orElse(null);
        if (staticArguments != null) {
            return TypeInferenceUtil.formatStaticArguments(name, staticArguments);
        }
        return typeInference.getInputTypeStrategy().getExpectedSignatures(definition).stream().map(s -> TypeInferenceUtil.formatSignature(name, s)).collect(Collectors.joining("\n"));
    }

    public static ValidationException createInvalidInputException(TypeInference typeInference, CallContext callContext, ValidationException cause) {
        return new ValidationException(String.format("Invalid input arguments. Expected signatures are:\n%s", TypeInferenceUtil.generateSignature(typeInference, callContext.getName(), callContext.getFunctionDefinition())), cause);
    }

    public static ValidationException createInvalidCallException(CallContext callContext, ValidationException cause) {
        return new ValidationException(String.format("Invalid function call:\n%s(%s)", callContext.getName(), callContext.getArgumentDataTypes().stream().map(DataType::toString).collect(Collectors.joining(", "))), cause);
    }

    public static TableException createUnexpectedException(CallContext callContext, Throwable cause) {
        return new TableException(String.format("Unexpected error in type inference logic of function '%s'. This is a bug.", callContext.getName()), cause);
    }

    public static boolean validateArgumentCount(ArgumentCount argumentCount, int actualCount, boolean throwOnFailure) {
        int minCount = argumentCount.getMinCount().orElse(0);
        if (actualCount < minCount) {
            if (throwOnFailure) {
                throw new ValidationException(String.format("Invalid number of arguments. At least %d arguments expected but %d passed.", minCount, actualCount));
            }
            return false;
        }
        int maxCount = argumentCount.getMaxCount().orElse(Integer.MAX_VALUE);
        if (actualCount > maxCount) {
            if (throwOnFailure) {
                throw new ValidationException(String.format("Invalid number of arguments. At most %d arguments expected but %d passed.", maxCount, actualCount));
            }
            return false;
        }
        if (!argumentCount.isValidCount(actualCount)) {
            if (throwOnFailure) {
                throw new ValidationException(String.format("Invalid number of arguments. %d arguments passed.", actualCount));
            }
            return false;
        }
        return true;
    }

    private static Result runTypeInferenceInternal(TypeInference typeInference, CallContext callContext, @Nullable SurroundingInfo surroundingInfo) {
        CallContext adaptedCallContext;
        DataType outputType;
        try {
            TypeInferenceUtil.validateArgumentCount(typeInference.getInputTypeStrategy().getArgumentCount(), callContext.getArgumentDataTypes().size(), true);
        }
        catch (ValidationException e) {
            throw TypeInferenceUtil.createInvalidInputException(typeInference, callContext, e);
        }
        try {
            outputType = surroundingInfo != null ? (DataType)surroundingInfo.inferOutputType(callContext.getDataTypeFactory()).orElse(null) : null;
            adaptedCallContext = TypeInferenceUtil.castArguments(typeInference, callContext, outputType);
        }
        catch (ValidationException e) {
            throw TypeInferenceUtil.createInvalidInputException(typeInference, callContext, e);
        }
        outputType = TypeInferenceUtil.inferOutputType(adaptedCallContext, typeInference.getOutputTypeStrategy());
        LinkedHashMap<String, StateInfo> stateInfos = TypeInferenceUtil.inferStateInfos(adaptedCallContext, typeInference.getStateTypeStrategies());
        return new Result(adaptedCallContext.getArgumentDataTypes(), stateInfos, outputType);
    }

    private static String formatStaticArguments(String name, List<StaticArgument> staticArguments) {
        String arguments = staticArguments.stream().map(StaticArgument::toString).collect(Collectors.joining(", "));
        return String.format("%s(%s)", name, arguments);
    }

    private static String formatSignature(String name, Signature s) {
        String arguments = s.getArguments().stream().map(TypeInferenceUtil::formatArgument).collect(Collectors.joining(", "));
        return String.format("%s(%s)", name, arguments);
    }

    private static String formatArgument(Signature.Argument arg) {
        StringBuilder stringBuilder = new StringBuilder();
        arg.getName().ifPresent(n -> stringBuilder.append((String)n).append(" "));
        stringBuilder.append(arg.getType());
        return stringBuilder.toString();
    }

    private static CastCallContext inferInputTypes(TypeInference typeInference, CallContext callContext, @Nullable DataType outputType, boolean throwOnFailure) {
        List inferredDataTypes;
        CastCallContext castCallContext = new CastCallContext(callContext, outputType);
        List staticArgs = typeInference.getStaticArguments().orElse(null);
        if (staticArgs != null) {
            List<DataType> fromStaticArgs = IntStream.range(0, staticArgs.size()).mapToObj(pos -> {
                StaticArgument expectedArg = (StaticArgument)staticArgs.get(pos);
                if (expectedArg.is(StaticArgumentTrait.TABLE)) {
                    TableSemantics semantics = callContext.getTableSemantics(pos).orElse(null);
                    if (semantics == null) {
                        if (throwOnFailure) {
                            throw new ValidationException(String.format("Invalid argument value. Argument '%s' expects a table to be passed.", expectedArg.getName()));
                        }
                        return null;
                    }
                    return semantics.dataType();
                }
                return expectedArg.getDataType().orElse(null);
            }).collect(Collectors.toList());
            if (fromStaticArgs.stream().allMatch(Objects::nonNull)) {
                castCallContext.setExpectedArguments(fromStaticArgs);
            } else if (throwOnFailure) {
                throw new ValidationException("Invalid input arguments.");
            }
        }
        if ((inferredDataTypes = (List)typeInference.getInputTypeStrategy().inferInputTypes(castCallContext, throwOnFailure).orElse(null)) != null) {
            castCallContext.setExpectedArguments(inferredDataTypes);
        } else if (throwOnFailure) {
            throw new ValidationException("Invalid input arguments.");
        }
        return castCallContext;
    }

    private static StateInfo inferStateInfo(CallContext callContext, String name, StateTypeStrategy stateTypeStrategy) {
        DataType stateType = stateTypeStrategy.inferType(callContext).orElse(null);
        if (stateType == null || TypeInferenceUtil.isUnknown(stateType)) {
            String errorMessage = name.equals("acc") ? "Could not infer an accumulator type for the given arguments." : String.format("Could not infer a data type for state entry '%s'.", name);
            throw new ValidationException(errorMessage);
        }
        Duration ttl = stateTypeStrategy.getTimeToLive(callContext).orElse(null);
        return new StateInfo(stateType, ttl);
    }

    private static boolean isUnknown(DataType dataType) {
        return dataType.getLogicalType().is(LogicalTypeRoot.NULL);
    }

    private TypeInferenceUtil() {
    }

    @Internal
    public static final class StateInfo {
        private final DataType dataType;
        @Nullable
        private final Duration timeToLive;

        private StateInfo(DataType dataType, @Nullable Duration timeToLive) {
            this.dataType = dataType;
            this.timeToLive = timeToLive;
        }

        public DataType getDataType() {
            return this.dataType;
        }

        public Optional<Duration> getTimeToLive() {
            return Optional.ofNullable(this.timeToLive);
        }
    }

    @Internal
    public static final class Result {
        private final List<DataType> expectedArgumentTypes;
        private final LinkedHashMap<String, StateInfo> stateInfos;
        private final DataType outputDataType;

        public Result(List<DataType> expectedArgumentTypes, LinkedHashMap<String, StateInfo> stateInfos, DataType outputDataType) {
            this.expectedArgumentTypes = expectedArgumentTypes;
            this.stateInfos = stateInfos;
            this.outputDataType = outputDataType;
        }

        public List<DataType> getExpectedArgumentTypes() {
            return this.expectedArgumentTypes;
        }

        public LinkedHashMap<String, StateInfo> getStateInfos() {
            return this.stateInfos;
        }

        public DataType getOutputDataType() {
            return this.outputDataType;
        }
    }

    @Internal
    public static interface SurroundingInfo {
        public static SurroundingInfo of(String name, FunctionDefinition functionDefinition, TypeInference typeInference, int argumentCount, int innerCallPosition, boolean isGroupedAggregation) {
            return typeFactory -> {
                boolean isValidCount = TypeInferenceUtil.validateArgumentCount(typeInference.getInputTypeStrategy().getArgumentCount(), argumentCount, false);
                if (!isValidCount) {
                    return Optional.empty();
                }
                UnknownCallContext callContext = new UnknownCallContext(typeFactory, name, functionDefinition, argumentCount, isGroupedAggregation);
                CallContext adaptedContext = TypeInferenceUtil.castArguments(typeInference, callContext, null, false);
                return typeInference.getInputTypeStrategy().inferInputTypes(adaptedContext, false).map(dataTypes -> (DataType)dataTypes.get(innerCallPosition));
            };
        }

        public static SurroundingInfo of(DataType dataType) {
            return typeFactory -> Optional.of(dataType);
        }

        public Optional<DataType> inferOutputType(DataTypeFactory var1);
    }
}

