/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.streaming.runtime.io.recovery;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.function.Predicate;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.runtime.plugable.DeserializationDelegate;
import org.apache.flink.shaded.guava30.com.google.common.collect.Maps;
import org.apache.flink.streaming.runtime.io.AbstractStreamTaskNetworkInput;
import org.apache.flink.streaming.runtime.io.DataInputStatus;
import org.apache.flink.streaming.runtime.io.RecoverableStreamTaskInput;
import org.apache.flink.streaming.runtime.io.StreamTaskInput;
import org.apache.flink.streaming.runtime.io.StreamTaskNetworkInput;
import org.apache.flink.streaming.runtime.io.checkpointing.CheckpointedInputGate;
import org.apache.flink.streaming.runtime.io.recovery.DemultiplexingRecordDeserializer;
import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.watermarkstatus.StatusWatermarkValve;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link RecoverableStreamTaskInput} that reads recovered in-flight data while the described
 * rescaling of channel state is still being replayed.
 *
 * <p>Every input channel of the gate gets its own {@link DemultiplexingRecordDeserializer}, built
 * from the {@link InflightDataRescalingDescriptor}. {@link SubtaskConnectionDescriptor} events that
 * arrive on a channel are forwarded to that channel's deserializer via {@code select(...)} instead
 * of being processed as regular events. Checkpoints are always declined while this input is
 * active. {@link #finishRecovery()} closes this input and hands over to a plain {@link
 * StreamTaskNetworkInput}.
 *
 * @param <T> the type of the records read from the network
 */
@Internal
public final class RescalingStreamTaskNetworkInput<T>
        extends AbstractStreamTaskNetworkInput<T, DemultiplexingRecordDeserializer<T>>
        implements RecoverableStreamTaskInput<T> {

    private static final Logger LOG = LoggerFactory.getLogger(RescalingStreamTaskNetworkInput.class);

    /** Kept so that {@link #finishRecovery()} can build the follow-up {@link StreamTaskNetworkInput}. */
    private final IOManager ioManager;

    /**
     * Creates the rescaling input and one demultiplexing deserializer per channel of the gate.
     *
     * @param checkpointedInputGate the gate to read buffers and events from
     * @param inputSerializer serializer of the record type {@code T}
     * @param ioManager provides the spilling directories for the per-channel deserializers
     * @param statusWatermarkValve valve handed to the base class
     * @param inputIndex index of this input, used for logging and handed to the base class
     * @param inflightDataRescalingDescriptor describes how recovered channel state is remapped
     * @param gatePartitioners yields the partitioner used by the upstream of a given gate index
     * @param taskInfo supplies the subtask index, parallelism and max parallelism for filtering
     */
    public RescalingStreamTaskNetworkInput(
            CheckpointedInputGate checkpointedInputGate,
            TypeSerializer<T> inputSerializer,
            IOManager ioManager,
            StatusWatermarkValve statusWatermarkValve,
            int inputIndex,
            InflightDataRescalingDescriptor inflightDataRescalingDescriptor,
            Function<Integer, StreamPartitioner<?>> gatePartitioners,
            TaskInfo taskInfo) {
        super(
                checkpointedInputGate,
                inputSerializer,
                statusWatermarkValve,
                inputIndex,
                getRecordDeserializers(
                        checkpointedInputGate,
                        inputSerializer,
                        ioManager,
                        inflightDataRescalingDescriptor,
                        gatePartitioners,
                        taskInfo));
        this.ioManager = ioManager;
        LOG.info(
                "Created demultiplexer for input {} from {}",
                inputIndex,
                inflightDataRescalingDescriptor);
    }

    /** Builds one {@link DemultiplexingRecordDeserializer} for every channel of the gate. */
    private static <T>
            Map<InputChannelInfo, DemultiplexingRecordDeserializer<T>> getRecordDeserializers(
                    CheckpointedInputGate checkpointedInputGate,
                    TypeSerializer<T> inputSerializer,
                    IOManager ioManager,
                    InflightDataRescalingDescriptor rescalingDescriptor,
                    Function<Integer, StreamPartitioner<?>> gatePartitioners,
                    TaskInfo taskInfo) {
        RecordFilterFactory<T> recordFilterFactory =
                new RecordFilterFactory<>(
                        taskInfo.getIndexOfThisSubtask(),
                        inputSerializer,
                        taskInfo.getNumberOfParallelSubtasks(),
                        gatePartitioners,
                        taskInfo.getMaxNumberOfParallelSubtasks());
        DeserializerFactory deserializerFactory = new DeserializerFactory(ioManager);
        Map<InputChannelInfo, DemultiplexingRecordDeserializer<T>> deserializers =
                Maps.newHashMapWithExpectedSize(checkpointedInputGate.getChannelInfos().size());
        for (InputChannelInfo channelInfo : checkpointedInputGate.getChannelInfos()) {
            deserializers.put(
                    channelInfo,
                    DemultiplexingRecordDeserializer.create(
                            channelInfo,
                            rescalingDescriptor,
                            deserializerFactory,
                            recordFilterFactory));
        }
        return deserializers;
    }

    /**
     * Ends recovery: verifies no deserializer still holds partial data, closes this input and
     * returns a regular {@link StreamTaskNetworkInput} over the same gate.
     *
     * @throws IllegalStateException if any channel's deserializer still has partial data
     * @throws IOException if closing this input fails
     */
    @Override
    public StreamTaskInput<T> finishRecovery() throws IOException {
        Preconditions.checkState(
                recordDeserializers.values().stream()
                        .noneMatch(DemultiplexingRecordDeserializer::hasPartialData),
                "Not all data has been fully consumed");
        close();
        return new StreamTaskNetworkInput<>(
                checkpointedInputGate, inputSerializer, ioManager, statusWatermarkValve, inputIndex);
    }

    /**
     * Returns the channel's deserializer.
     *
     * @throws IllegalStateException if the deserializer has no mappings, i.e. the channel is not
     *     expected to deliver data during recovery
     */
    @Override
    protected DemultiplexingRecordDeserializer<T> getActiveSerializer(InputChannelInfo channelInfo) {
        DemultiplexingRecordDeserializer<T> deserializer = super.getActiveSerializer(channelInfo);
        if (!deserializer.hasMappings()) {
            throw new IllegalStateException(
                    "Channel " + channelInfo + " should not receive data during recovery.");
        }
        return deserializer;
    }

    /**
     * Forwards {@link SubtaskConnectionDescriptor} events to the channel's deserializer. All other
     * events are handled by the base class.
     */
    @Override
    protected DataInputStatus processEvent(BufferOrEvent bufferOrEvent) {
        AbstractEvent event = bufferOrEvent.getEvent();
        if (event instanceof SubtaskConnectionDescriptor) {
            getActiveSerializer(bufferOrEvent.getChannelInfo())
                    .select((SubtaskConnectionDescriptor) event);
            return DataInputStatus.MORE_AVAILABLE;
        }
        return super.processEvent(bufferOrEvent);
    }

    /** Always declines the checkpoint; snapshots are not supported while recovering. */
    @Override
    public CompletableFuture<Void> prepareSnapshot(
            ChannelStateWriter channelStateWriter, long checkpointId) throws CheckpointException {
        throw new CheckpointException(CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY);
    }

    /**
     * Creates spilling record deserializers whose byte budgets are split evenly across {@code
     * totalChannels}.
     */
    static class DeserializerFactory
            implements Function<Integer, RecordDeserializer<DeserializationDelegate<StreamElement>>> {

        // 5 MiB and 2 MiB (0x500000 / 0x200000). Presumably these mirror the spilling threshold and
        // file buffer size defaults of SpillingAdaptiveSpanningRecordDeserializer — confirm if
        // those defaults change.
        private static final int THRESHOLD_FOR_SPILLING_BYTES = 5 * 1024 * 1024;
        private static final int FILE_BUFFER_SIZE_BYTES = 2 * 1024 * 1024;

        private final IOManager ioManager;

        public DeserializerFactory(IOManager ioManager) {
            this.ioManager = ioManager;
        }

        /**
         * @param totalChannels number of virtual channels sharing the budget; must be positive
         *     (integer division by zero otherwise)
         */
        @Override
        public RecordDeserializer<DeserializationDelegate<StreamElement>> apply(
                Integer totalChannels) {
            return new SpillingAdaptiveSpanningRecordDeserializer<>(
                    ioManager.getSpillingDirectoriesPaths(),
                    THRESHOLD_FOR_SPILLING_BYTES / totalChannels,
                    FILE_BUFFER_SIZE_BYTES / totalChannels);
        }
    }

    /**
     * Creates per-channel {@link RecordFilter}s. A configured partitioner is built once per gate
     * index, cached, and each filter receives its own {@code copy()} of it.
     */
    static class RecordFilterFactory<T>
            implements Function<InputChannelInfo, Predicate<StreamRecord<T>>> {
        private final Map<Integer, StreamPartitioner<T>> partitionerCache = new HashMap<>(1);
        private final Function<Integer, StreamPartitioner<?>> gatePartitioners;
        private final TypeSerializer<T> inputSerializer;
        private final int numberOfChannels;
        private final int subtaskIndex;
        private final int maxParallelism;

        public RecordFilterFactory(
                int subtaskIndex,
                TypeSerializer<T> inputSerializer,
                int numberOfChannels,
                Function<Integer, StreamPartitioner<?>> gatePartitioners,
                int maxParallelism) {
            this.gatePartitioners = gatePartitioners;
            this.inputSerializer = inputSerializer;
            this.numberOfChannels = numberOfChannels;
            this.subtaskIndex = subtaskIndex;
            this.maxParallelism = maxParallelism;
        }

        @Override
        public Predicate<StreamRecord<T>> apply(InputChannelInfo channelInfo) {
            StreamPartitioner<T> partitioner =
                    partitionerCache.computeIfAbsent(channelInfo.getGateIdx(), this::createPartitioner);
            // Each filter gets its own copy so that filters do not share partitioner state.
            return new RecordFilter<>(partitioner.copy(), inputSerializer, subtaskIndex);
        }

        /** Fetches the gate's partitioner and sets it up for this task's parallelism. */
        @SuppressWarnings("unchecked")
        private StreamPartitioner<T> createPartitioner(Integer index) {
            // gatePartitioners is type-erased (StreamPartitioner<?>); the partitioner of this
            // input's gate is assumed to partition records of this input's type T.
            StreamPartitioner<T> partitioner = (StreamPartitioner<T>) gatePartitioners.apply(index);
            partitioner.setup(numberOfChannels);
            if (partitioner instanceof ConfigurableStreamPartitioner) {
                ((ConfigurableStreamPartitioner) partitioner).configure(maxParallelism);
            }
            return partitioner;
        }
    }
}

