/*
 * Decompiled with CFR 0.152.
 */
package li.cil.oc2.common.inet;

import java.nio.ByteBuffer;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.annotation.Nullable;
import li.cil.oc2.api.inet.TransportMessage;
import li.cil.oc2.api.inet.layer.SessionLayer;
import li.cil.oc2.api.inet.layer.TransportLayer;
import li.cil.oc2.api.inet.session.DatagramSession;
import li.cil.oc2.api.inet.session.EchoSession;
import li.cil.oc2.api.inet.session.Session;
import li.cil.oc2.api.inet.session.StreamSession;
import li.cil.oc2.common.Config;
import li.cil.oc2.common.inet.DatagramSessionBase;
import li.cil.oc2.common.inet.DatagramSessionDiscriminator;
import li.cil.oc2.common.inet.DatagramSessionImpl;
import li.cil.oc2.common.inet.EchoSessionDiscriminator;
import li.cil.oc2.common.inet.EchoSessionImpl;
import li.cil.oc2.common.inet.InetUtils;
import li.cil.oc2.common.inet.SessionActions;
import li.cil.oc2.common.inet.SessionBase;
import li.cil.oc2.common.inet.SessionDiscriminator;
import li.cil.oc2.common.inet.StreamSessionDiscriminator;
import li.cil.oc2.common.inet.StreamSessionImpl;
import net.minecraft.nbt.CompoundTag;
import net.minecraft.nbt.Tag;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public final class DefaultTransportLayer
implements TransportLayer {
    private static final Logger LOGGER = LogManager.getLogger();

    // ICMP message types and codes used by this layer.
    private static final byte ICMP_TYPE_ECHO_REPLY = 0;
    private static final byte ICMP_TYPE_ECHO_REQUEST = 8;
    private static final byte ICMP_TYPE_ECHO_UNREACHABLE = 3;
    private static final byte ICMP_CODE_ECHO_UNREACHABLE_PORT = 3;
    private static final byte ICMP_CODE_ECHO_UNREACHABLE_PROHIBITED = 13;

    /** Port handed to echo sessions when they are created. */
    private static final short PORT_ECHO = 7;

    // Header sizes in bytes, used to validate outgoing messages before parsing them.
    private static final int ICMP_HEADER_SIZE = 8;
    private static final int UDP_HEADER_SIZE = 8;
    private static final int MIN_TCP_HEADER_SIZE = 20;

    /**
     * Number of live sessions across all instances of this class; compared against
     * {@code Config.defaultSessionsNumberLimit} when a session is created.
     * NOTE(review): plain static int with no synchronization -- assumes all transport layers
     * are driven from a single thread; confirm against the callers.
     */
    private static int allSessionCount = 0;

    /** Next layer up; receives session data and supplies data to send back. */
    private final SessionLayer sessionLayer;

    /** Reusable callback object passed to {@code SessionLayer.receiveSession}. */
    private final SessionReceiver receiver = new SessionReceiver();

    /** Sessions ordered by their last update time, oldest first. */
    private final NavigableMap<Instant, SessionBase> expirationQueue = new TreeMap<>();

    /** Sessions owned by this layer, keyed by their discriminator. */
    private final Map<SessionDiscriminator<?>, SessionBase> sessions = new HashMap<>();

    /** Stream that asked for an acknowledgment; handled by the next receive call. */
    private StreamSessionImpl streamToAck = null;

    /** Pending ICMP reply produced by {@code reject}; sent by the next receive call. */
    private ICMPReply icmpReply = null;

    /** Stream that ended in REJECT or FINISH; answered and closed by the next receive call. */
    private StreamSessionImpl rejectedStream = null;

    /**
     * Creates a transport layer on top of the given session layer.
     *
     * @param sessionLayer the layer that consumes and produces session payloads
     */
    public DefaultTransportLayer(SessionLayer sessionLayer) {
        this.sessionLayer = sessionLayer;
    }

    /**
     * Removes every entry older than the configured session lifetime from {@code queue}
     * and passes its value to {@code action}, oldest first.
     *
     * <p>Each entry is detached from the queue before the callback runs and the head of the
     * queue is re-read on every step, so a callback that itself adds to or removes from the
     * queue cannot trigger a {@link java.util.ConcurrentModificationException}.
     *
     * @param queue  entries keyed by the time they were last refreshed
     * @param action invoked once for every expired value, after it has left the queue
     * @param <T>    type of the queued values
     */
    private <T> void processExpirationQueue(NavigableMap<Instant, T> queue, Consumer<T> action) {
        if (queue.isEmpty()) {
            return;
        }
        final Instant expireTime = Instant.now().minus(Config.defaultSessionLifetimeMs, ChronoUnit.MILLIS);
        Map.Entry<Instant, T> oldest;
        while ((oldest = queue.firstEntry()) != null && oldest.getKey().isBefore(expireTime)) {
            queue.pollFirstEntry();
            action.accept(oldest.getValue());
        }
    }

    /**
     * Finds the least recently updated stream session that still needs an acknowledgment
     * and has been idle for longer than the configured TCP retransmission timeout.
     *
     * @return the matching stream, or {@code null} when no queued session qualifies
     */
    @Nullable
    private StreamSessionImpl getNextStreamForRetransmission() {
        if (this.expirationQueue.isEmpty()) {
            return null;
        }
        final Instant cutoff = Instant.now().minus(Config.tcpRetransmissionTimeoutMs, ChronoUnit.MILLIS);
        // Exclusive head view: only sessions last updated strictly before the cutoff, oldest first.
        for (final SessionBase candidate : this.expirationQueue.headMap(cutoff, false).values()) {
            if (candidate instanceof StreamSessionImpl stream && stream.isNeedsAcknowledgment()) {
                return stream;
            }
        }
        return null;
    }

    /**
     * Expires every session whose last update is older than the configured session lifetime:
     * the session is dropped from the session table, marked expired and reported to the
     * session layer with a {@code null} payload.
     *
     * <p>The table entry is removed, and the global session counter decremented, only when
     * the table still maps the discriminator to this very session. A stale session that was
     * already closed therefore neither evicts a newer session that reuses the same
     * discriminator nor decrements the counter a second time.
     */
    private void processSessionExpirationQueue() {
        this.processExpirationQueue(this.expirationQueue, session -> {
            if (this.sessions.remove(session.getDiscriminator(), session)) {
                --allSessionCount;
            }
            LOGGER.trace("Expired session {}", session.getDiscriminator());
            session.expire();
            this.sessionLayer.sendSession((Session)session, null);
        });
    }

    /**
     * Refreshes the session's last update time and moves it to the matching position in
     * the expiration queue.
     *
     * <p>The old queue entry is removed only if it actually belongs to {@code session}; a
     * key-only removal would silently drop a different session that happens to share the
     * same timestamp, leaving that session without an expiration entry.
     *
     * <p>NOTE(review): the queue is keyed by timestamp alone, so two sessions refreshed at
     * exactly the same {@link Instant} still collide on {@code put}. The assertion below only
     * reports this when assertions are enabled.
     *
     * @param session the session that has just seen activity
     */
    private void updateSession(SessionBase session) {
        this.expirationQueue.remove(session.getLastUpdateTime(), session);
        session.update();
        final Instant newLastUpdateTime = session.getLastUpdateTime();
        final SessionBase previous = this.expirationQueue.put(newLastUpdateTime, session);
        assert (previous == null);
    }

    /**
     * Forgets a session: removes it from the session table and from the expiration queue.
     *
     * <p>Both removals are conditional on the stored value being this session, and the global
     * session counter is decremented only when the table entry was really removed. Closing a
     * session that is no longer registered (for example one that already expired while still
     * referenced as a pending acknowledgment) is therefore harmless and does not skew the
     * counter or evict an unrelated session.
     *
     * @param session the session to close
     */
    private void closeSession(SessionBase session) {
        LOGGER.trace("Close session {}", session.getDiscriminator());
        this.expirationQueue.remove(session.getLastUpdateTime(), session);
        if (this.sessions.remove(session.getDiscriminator(), session)) {
            --allSessionCount;
        }
    }

    /**
     * Writes the ICMP type, code and checksum at the buffer's current position and leaves
     * the position where it was on entry.
     *
     * @param buffer buffer positioned at the first byte of the ICMP message
     * @param type   ICMP type byte
     * @param code   ICMP code byte
     */
    private void prepareIcmpHeader(ByteBuffer buffer, byte type, byte code) {
        final int headerStart = buffer.position();
        // The checksum field is zeroed first so that it does not contribute to the sum.
        buffer.put(type).put(code).putShort((short) 0);
        buffer.position(headerStart);
        final short sum = InetUtils.rfc1071Checksum(buffer);
        buffer.putShort(headerStart + 2, sum);
        // Rewind again, since the checksum helper received the buffer itself.
        buffer.position(headerStart);
    }

    /**
     * Looks up the session registered for {@code discriminator}, creating and registering a
     * new one when none exists and neither session limit has been reached.
     *
     * @param discriminator key identifying the session
     * @param factory       builds a new session from the discriminator
     * @param <S>           session type
     * @param <D>           discriminator type
     * @return the existing or newly created session, or {@code null} when the per-card or
     *         the global session limit prevents creating one
     */
    @Nullable
    private <S extends SessionBase, D extends SessionDiscriminator<S>> S getOrCreateSession(D discriminator, Function<D, S> factory) {
        final SessionBase known = this.sessions.get(discriminator);
        if (known == null) {
            // The per-card limit is checked before the global one.
            if (this.sessions.size() >= Config.defaultSessionsNumberPerCardLimit) {
                LOGGER.warn("Session count per card limit has reached");
                return null;
            }
            if (allSessionCount >= Config.defaultSessionsNumberLimit) {
                LOGGER.warn("Session count limit has reached");
                return null;
            }
            ++allSessionCount;
            LOGGER.trace("New session: {}", discriminator);
            final S created = factory.apply(discriminator);
            this.sessions.put(discriminator, created);
            // Registers the fresh session in the expiration queue.
            this.updateSession(created);
            return created;
        }
        return (S)known;
    }

    /**
     * Queues an ICMP "unreachable / prohibited" reply addressed to the sender of a rejected
     * message. The reply is emitted by the next {@code receiveTransportMessage} call.
     *
     * <p>The declared byte constants are used instead of bare int literals: an int literal is
     * not implicitly narrowed to {@code byte} when passed as a constructor argument.
     *
     * @param payload      the offending message; {@code InetUtils.quickICMPBody} derives the
     *                     reply body from it
     * @param srcIpAddress address of the original sender, used as the reply's destination
     */
    private void reject(ByteBuffer payload, int srcIpAddress) {
        final byte[] data = InetUtils.quickICMPBody(payload);
        this.icmpReply = new ICMPReply(ICMP_TYPE_ECHO_UNREACHABLE, ICMP_CODE_ECHO_UNREACHABLE_PROHIBITED, 0, srcIpAddress, data);
    }

    /**
     * Applies the state transition that follows handing a datagram-style session to the
     * session layer.
     *
     * <ul>
     *   <li>NEW becomes ESTABLISHED.</li>
     *   <li>ESTABLISHED is left untouched.</li>
     *   <li>REJECT queues an ICMP rejection and then closes the session.</li>
     *   <li>FINISH closes the session.</li>
     * </ul>
     *
     * @param session      the session that was just passed to the session layer
     * @param payload      the message that was sent, used to build a rejection reply
     * @param srcIpAddress address of the sender, used as the rejection's destination
     * @throws IllegalStateException if the session is in any other state
     */
    private void sessionSendFinish(DatagramSessionBase session, ByteBuffer payload, int srcIpAddress) {
        final Session.States current = session.getState();
        if (current == Session.States.ESTABLISHED) {
            return;
        }
        if (current == Session.States.NEW) {
            session.setState(Session.States.ESTABLISHED);
            return;
        }
        if (current == Session.States.REJECT) {
            this.reject(payload, srcIpAddress);
            LOGGER.trace("Reject session {}", session.getDiscriminator());
        } else if (current != Session.States.FINISH) {
            throw new IllegalStateException(current.name());
        }
        // Rejected and finished sessions are both closed.
        this.closeSession(session);
    }

    /**
     * Lets a stream session write its next inbound TCP segment into the message buffer.
     *
     * <p>The ports are written swapped (the session's destination port first), then the
     * stream fills in the rest. On FORWARD the checksum is stored at offset 16 of the segment
     * and the message addresses are set to the swapped session addresses. On DROP or IGNORE
     * the buffer's position and limit are restored to their values on entry.
     *
     * @param message message whose data buffer receives the segment
     * @param stream  session producing the segment
     * @return the action reported by the stream
     * @throws IllegalStateException if the stream reports an unexpected action
     */
    private SessionActions prepareTCPSegment(TransportMessage message, StreamSessionImpl stream) {
        final ByteBuffer segment = message.getData();
        final StreamSessionDiscriminator id = stream.getDiscriminator();
        final int start = segment.position();
        final int end = segment.limit();
        segment.putShort(id.getDstPort());
        segment.putShort(id.getSrcPort());
        final SessionActions action = stream.receive(segment);
        switch (action) {
            case DROP, IGNORE -> {
                segment.position(start);
                segment.limit(end);
                return action;
            }
            case FORWARD -> {
                segment.position(start);
                final short sum = InetUtils.transportRfc1071Checksum(segment, id.getDstIpAddress(), id.getSrcIpAddress(), (byte)6);
                segment.putShort(start + 16, sum);
                segment.position(start);
                message.updateIpv4(id.getDstIpAddress(), id.getSrcIpAddress());
                LOGGER.trace("Prepared TCP packet to receive {}", (Object)stream.getHeader());
                return SessionActions.FORWARD;
            }
            default -> throw new IllegalStateException();
        }
    }

    /**
     * Produces the next inbound transport message, if any, into {@code message}.
     *
     * <p>Pending work is served in a fixed order: a rejected stream first, then a queued ICMP
     * reply, then a stream waiting for an acknowledgment, and finally whatever the session
     * layer wants to deliver. The loop repeats while the selected session produced nothing
     * to forward (for example because it finished or was dropped).
     *
     * @param message message whose data buffer and addresses are filled in
     * @return 6 when a TCP segment was prepared, 17 for a UDP datagram, 1 for an ICMP
     *         message, and 0 when the session layer supplied no session
     * @throws IllegalStateException if a session is in a state this method does not handle,
     *         or is neither an echo, datagram nor stream session
     */
    @Override
    public byte receiveTransportMessage(TransportMessage message) {
        this.processSessionExpirationQueue();
        while (true) {
            // 1. A stream that ended in REJECT/FINISH during send: emit its segment and close it.
            if (this.rejectedStream != null) {
                LOGGER.trace("Rejecting stream {}", (Object)this.rejectedStream.getDiscriminator());
                SessionActions success = this.prepareTCPSegment(message, this.rejectedStream);
                assert (success == SessionActions.FORWARD);
                this.closeSession(this.rejectedStream);
                this.rejectedStream = null;
                return 6;
            }
            // 2. An ICMP reply queued by reject().
            if (this.icmpReply != null) {
                message.updateIpv4(this.icmpReply.srcIpAddress, this.icmpReply.dstIpAddress);
                ByteBuffer data = message.getData();
                int position = data.position();
                // Four placeholder bytes; prepareIcmpHeader overwrites them with type, code and checksum.
                // NOTE(review): whether the payload already carries the remaining four header
                // bytes depends on InetUtils.quickICMPBody, which is not visible here.
                data.putInt(0);
                data.put(this.icmpReply.payload);
                // Limit the buffer to exactly what was written, then rewind to the message start.
                data.limit(data.position());
                data.position(position);
                this.prepareIcmpHeader(data, this.icmpReply.type, this.icmpReply.code);
                this.icmpReply = null;
                return 1;
            }
            // 3. A stream that requested an acknowledgment during send.
            if (this.streamToAck != null) {
                StreamSessionImpl stream = this.streamToAck;
                this.streamToAck = null;
                this.updateSession(stream);
                switch (this.prepareTCPSegment(message, stream)) {
                    case FORWARD: {
                        if (stream.isClosed()) {
                            this.closeSession(stream);
                        }
                        return 6;
                    }
                    case DROP: {
                        // No break needed: this is the last case, control continues below.
                        this.closeSession(stream);
                    }
                }
            }
            // 4. Ask the session layer; it reports its session through the receiver callback.
            this.receiver.prepare(message.getData());
            this.sessionLayer.receiveSession(this.receiver);
            SessionBase session = this.receiver.session;
            if (session == null) {
                // The session layer did not call back with a session.
                return 0;
            }
            this.updateSession(session);
            if (session instanceof EchoSession) {
                EchoSessionImpl echoSession = (EchoSessionImpl)session;
                switch (session.getState()) {
                    case FINISH: {
                        this.closeSession(session);
                        break;
                    }
                    case ESTABLISHED: {
                        EchoSessionDiscriminator discriminator = echoSession.getDiscriminator();
                        ByteBuffer buffer = this.receiver.getBuffer();
                        int position = buffer.position();
                        // Identity and sequence number occupy bytes 4..7 of the message.
                        buffer.putShort(position + 4, discriminator.getIdentity());
                        buffer.putShort(position + 6, (short)echoSession.getSequenceNumber());
                        // Type 0, code 0: echo reply.
                        this.prepareIcmpHeader(buffer, (byte)0, (byte)0);
                        message.updateIpv4(discriminator.getDstIpAddress(), discriminator.getSrcIpAddress());
                        return 1;
                    }
                    default: {
                        throw new IllegalStateException();
                    }
                }
                // Reached only after FINISH: look for other work.
                continue;
            }
            if (session instanceof DatagramSession) {
                DatagramSessionImpl datagramSession = (DatagramSessionImpl)session;
                switch (session.getState()) {
                    case FINISH: {
                        this.closeSession(session);
                        break;
                    }
                    case ESTABLISHED: {
                        DatagramSessionDiscriminator discriminator = datagramSession.getDiscriminator();
                        ByteBuffer buffer = this.receiver.getBuffer();
                        int position = buffer.position();
                        // Header fields: ports (swapped relative to the session), length, checksum.
                        buffer.putShort(position, discriminator.getDstPort());
                        buffer.putShort(position + 2, discriminator.getSrcPort());
                        buffer.putShort(position + 4, (short)buffer.remaining());
                        // The checksum field is zeroed before the checksum is computed.
                        buffer.putShort(position + 6, (short)0);
                        short checksum = InetUtils.transportRfc1071Checksum(buffer, discriminator.getDstIpAddress(), discriminator.getSrcIpAddress(), (byte)17);
                        buffer.putShort(position + 6, checksum);
                        buffer.position(position);
                        message.updateIpv4(discriminator.getDstIpAddress(), discriminator.getSrcIpAddress());
                        return 17;
                    }
                    default: {
                        throw new IllegalStateException();
                    }
                }
                // Reached only after FINISH: look for other work.
                continue;
            }
            // Anything that is not a stream session at this point is unsupported.
            if (!(session instanceof StreamSession)) break;
            StreamSessionImpl streamSession = (StreamSessionImpl)session;
            switch (this.prepareTCPSegment(message, streamSession)) {
                case FORWARD: {
                    if (streamSession.isClosed()) {
                        this.closeSession(streamSession);
                    }
                    return 6;
                }
                case DROP: {
                    // Closed without output; the loop then looks for other work.
                    this.closeSession(streamSession);
                }
            }
        }
        throw new IllegalStateException();
    }

    /**
     * Wraps the session layer's saved state, when it has any, in a compound tag under the
     * key {@code "Session"}.
     *
     * @return the wrapped state, or an empty optional when the session layer saved nothing
     */
    @Override
    public Optional<Tag> onSave() {
        return this.sessionLayer.onSave().map(sessionState -> {
            final CompoundTag wrapper = new CompoundTag();
            wrapper.m_128365_("Session", sessionState);
            return wrapper;
        });
    }

    /**
     * Expires and closes every session owned by this layer, then stops the session layer.
     *
     * <p>The sessions are iterated from a snapshot: {@code closeSession} removes entries from
     * the session table, and removing from a {@link HashMap} while iterating its live
     * {@code values()} view throws {@link java.util.ConcurrentModificationException} as soon
     * as more than one session is present.
     */
    @Override
    public void onStop() {
        final SessionBase[] snapshot = this.sessions.values().toArray(new SessionBase[0]);
        for (final SessionBase session : snapshot) {
            session.expire();
            this.sessionLayer.sendSession(session, null);
            this.closeSession(session);
        }
        this.sessionLayer.onStop();
    }

    /**
     * Handles an outgoing transport message: parses the protocol header, finds or creates
     * the matching session and hands the payload to the session layer.
     *
     * <p>Messages that are too short for their header, ICMP messages other than echo requests
     * with code 0, and unknown protocols are silently dropped. When no session can be
     * created because a session limit was reached, an ICMP rejection is queued instead.
     *
     * @param protocol protocol number of the message: 1 (ICMP), 17 (UDP) or 6 (TCP)
     * @param message  the message to send; its data buffer is positioned at the transport header
     */
    @Override
    public void sendTransportMessage(byte protocol, TransportMessage message) {
        this.processSessionExpirationQueue();
        final int srcIpAddress = message.getSrcIpv4Address();
        final int dstIpAddress = message.getDstIpv4Address();
        final ByteBuffer data = message.getData();
        switch (protocol) {
            case 1: {
                this.sendIcmpMessage(message, data, srcIpAddress, dstIpAddress);
                break;
            }
            case 17: {
                this.sendUdpDatagram(data, srcIpAddress, dstIpAddress);
                break;
            }
            case 6: {
                this.sendTcpSegment(data, srcIpAddress, dstIpAddress);
                break;
            }
            default: {
                // Other protocols are not handled by this layer.
                break;
            }
        }
    }

    /** Forwards an ICMP echo request (type 8, code 0) to its echo session; drops anything else. */
    private void sendIcmpMessage(TransportMessage message, ByteBuffer data, int srcIpAddress, int dstIpAddress) {
        if (data.remaining() < ICMP_HEADER_SIZE) {
            return;
        }
        final byte type = data.get();
        final byte code = data.get();
        // The checksum is skipped, not verified.
        data.getShort();
        if (type != ICMP_TYPE_ECHO_REQUEST || code != 0) {
            return;
        }
        final short identity = data.getShort();
        final short sequence = data.getShort();
        final EchoSessionDiscriminator discriminator = new EchoSessionDiscriminator(srcIpAddress, dstIpAddress, identity);
        final EchoSessionImpl session = this.getOrCreateSession(discriminator, it -> new EchoSessionImpl(dstIpAddress, PORT_ECHO, it));
        if (session == null) {
            this.reject(data, srcIpAddress);
            return;
        }
        session.setSequenceNumber(sequence);
        session.setTtl(message.getTtl());
        this.sessionLayer.sendSession(session, data);
        this.sessionSendFinish(session, data, srcIpAddress);
    }

    /** Forwards the payload of a UDP datagram to its datagram session. */
    private void sendUdpDatagram(ByteBuffer data, int srcIpAddress, int dstIpAddress) {
        if (data.remaining() < UDP_HEADER_SIZE) {
            return;
        }
        final short srcPort = data.getShort();
        final short dstPort = data.getShort();
        final int datagramLength = Short.toUnsignedInt(data.getShort());
        // The checksum is skipped, not verified.
        data.getShort();
        // The length field counts the header, so anything below the header size is malformed
        // (and would move the limit in front of the payload); a length beyond the received
        // bytes means the datagram is truncated.
        if (datagramLength < UDP_HEADER_SIZE || data.remaining() + UDP_HEADER_SIZE < datagramLength) {
            return;
        }
        // Cut off anything behind the datagram's declared end.
        data.limit(data.position() + datagramLength - UDP_HEADER_SIZE);
        final DatagramSessionDiscriminator discriminator = new DatagramSessionDiscriminator(srcIpAddress, srcPort, dstIpAddress, dstPort);
        final DatagramSessionImpl session = this.getOrCreateSession(discriminator, it -> new DatagramSessionImpl(dstIpAddress, dstPort, it));
        if (session == null) {
            this.reject(data, srcIpAddress);
            return;
        }
        this.sessionLayer.sendSession(session, data);
        this.sessionSendFinish(session, data, srcIpAddress);
    }

    /** Feeds a TCP segment to its stream session and records follow-up work for the next receive call. */
    private void sendTcpSegment(ByteBuffer data, int srcIpAddress, int dstIpAddress) {
        if (data.remaining() < MIN_TCP_HEADER_SIZE) {
            return;
        }
        final short srcPort = data.getShort();
        final short dstPort = data.getShort();
        final StreamSessionDiscriminator discriminator = new StreamSessionDiscriminator(srcIpAddress, srcPort, dstIpAddress, dstPort);
        final StreamSessionImpl session = this.getOrCreateSession(discriminator, it -> new StreamSessionImpl(dstIpAddress, dstPort, it));
        if (session == null) {
            this.reject(data, srcIpAddress);
            return;
        }
        LOGGER.trace("GOT TCP");
        switch (session.send(data)) {
            case FORWARD: {
                switch (session.getState()) {
                    case NEW:
                    case FINISH: {
                        this.sessionLayer.sendSession(session, null);
                        break;
                    }
                    case ESTABLISHED: {
                        this.sessionLayer.sendSession(session, session.getSendBuffer());
                        break;
                    }
                    default: {
                        // Other states are not passed to the session layer.
                        break;
                    }
                }
                // Re-read the state: the session layer may have changed it.
                final Session.States state = session.getState();
                if (state == Session.States.REJECT || state == Session.States.FINISH) {
                    this.rejectedStream = session;
                }
                if (session.isNeedsAcknowledgment()) {
                    this.streamToAck = session;
                }
                break;
            }
            case DROP: {
                this.closeSession(session);
                break;
            }
            default: {
                // Any other action requires no work here.
                break;
            }
        }
    }

    /**
     * Callback handed to the session layer. It records which session the session layer
     * selected and tells it which buffer to write that session's payload into.
     */
    private static final class SessionReceiver implements SessionLayer.Receiver {
        /** Session reported by the last {@link #receive} call; {@code null} if none was reported. */
        private SessionBase session = null;
        /** Message buffer registered by {@link #prepare}. */
        private ByteBuffer buffer = null;
        /** Position of {@link #buffer} at the time of {@link #prepare}. */
        private int position = 0;
        /** Limit of {@link #buffer} at the time of {@link #prepare}. */
        private int limit = 0;

        private SessionReceiver() {
        }

        /** Registers the buffer for the next delivery and forgets the previous session. */
        private void prepare(ByteBuffer buffer) {
            this.session = null;
            this.buffer = buffer;
            this.position = buffer.position();
            this.limit = buffer.limit();
        }

        /** Returns the registered buffer, rewound to the position it had when registered. */
        private ByteBuffer getBuffer() {
            this.buffer.position(this.position);
            return this.buffer;
        }

        /**
         * Records the session and returns the buffer its payload should be written into.
         *
         * @param session the session selected by the session layer
         * @return {@code null} for sessions in state NEW, REJECT or FINISH; for established
         *         echo and datagram sessions the message buffer with eight zeroed header bytes
         *         already written; for established stream sessions the stream's receive buffer
         * @throws IllegalArgumentException if an established session is of an unknown kind
         * @throws IllegalStateException    if the session is in an unknown state
         */
        @Override
        @Nullable
        public ByteBuffer receive(Session session) {
            // Restore the limit before the position: setting a position beyond a limit that
            // was shrunk in the meantime would throw IllegalArgumentException.
            this.buffer.limit(this.limit);
            this.buffer.position(this.position);
            this.session = (SessionBase)session;
            return switch (session.getState()) {
                case NEW, REJECT, FINISH -> null;
                case ESTABLISHED -> {
                    if (session instanceof EchoSession || session instanceof DatagramSession) {
                        // Reserve the eight header bytes; the payload goes right behind them.
                        this.buffer.putLong(0L);
                        yield this.buffer;
                    }
                    if (session instanceof StreamSession) {
                        final StreamSessionImpl stream = (StreamSessionImpl)session;
                        yield stream.getReceiveBuffer();
                    }
                    throw new IllegalArgumentException("session");
                }
                default -> throw new IllegalStateException();
            };
        }
    }

    /**
     * An ICMP message waiting to be sent; created by {@code reject} and consumed by
     * {@code receiveTransportMessage}.
     *
     * @param type         ICMP type byte
     * @param code         ICMP code byte
     * @param srcIpAddress source address to put on the reply
     * @param dstIpAddress destination address to put on the reply
     * @param payload      bytes written after the first four bytes of the ICMP message;
     *                     the array is stored as given, not copied
     */
    private record ICMPReply(byte type, byte code, int srcIpAddress, int dstIpAddress, byte[] payload) {
    }
}

