package quic

import (
	"context"
	"time"

	"github.com/quic-go/quic-go"
	"github.com/xtls/xray-core/common/buf"
	"github.com/xtls/xray-core/common/errors"
	"github.com/xtls/xray-core/common/mux"
	"github.com/xtls/xray-core/common/net"
	"github.com/xtls/xray-core/common/serial"
	"github.com/xtls/xray-core/common/signal/done"
)

var MaxIncomingStreams = 16
var currentStream = 0

type interConn struct {
	ctx         context.Context
	quicConn    quic.Connection // small udp packet can be sent with Datagram directly 
	streams     []quic.Stream	// other packets can be sent via steam, it offer mux, reliability, fragmentation and ordering 
	readChannel chan readResult
	reader      buf.MultiBufferContainer
	done        *done.Instance
	local       net.Addr
	remote      net.Addr
}

type readResult struct {
	buffer []byte
	err    error 
}

func NewConnInitReader(ctx context.Context, quicConn quic.Connection, done *done.Instance, remote net.Addr) *interConn {
	c := &interConn{
		ctx:         ctx,
		quicConn:    quicConn,
		readChannel: make(chan readResult),
		reader:      buf.MultiBufferContainer{},
		done:        done,
		local:       quicConn.LocalAddr(),
		remote:      remote,
	}
	go func() {
		for {
			received, e := c.quicConn.ReceiveDatagram(c.ctx)
			errors.LogInfo(c.ctx, "Read ReceiveDatagram ", len(received))
			c.readChannel <- readResult{buffer: received, err: e}
		}
	}()
	go c.acceptStreams()
	return c
}

func (c *interConn) acceptStreams() {
	for {
		stream, err := c.quicConn.AcceptStream(context.Background())
		errors.LogInfo(c.ctx, "Read AcceptStream ", err)
		if err != nil {
			errors.LogInfoInner(context.Background(), err, "failed to accept stream")
			select {
			case <-c.quicConn.Context().Done():
				return
			case <-c.done.Wait():
				if err := c.quicConn.CloseWithError(0, ""); err != nil {
					errors.LogInfoInner(context.Background(), err, "failed to close connection")
				}
				return
			default:
				time.Sleep(time.Second)
				continue
			}
		}
		go c.readMuxCoolPacket(stream)
		c.streams = append(c.streams, stream)
	}
}

func (c *interConn) readMuxCoolPacket(stream quic.Stream) {
	for {
		received := make([]byte, buf.Size)
		i, e := stream.Read(received)
		if e != nil {
			errors.LogErrorInner(c.ctx, e, "Error read stream, drop this buffer ", i)
			c.readChannel <- readResult{buffer: nil, err: e}
			continue;
		}
		errors.LogInfo(c.ctx, "Read stream ", i)

		buffer := buf.New()
		buffer.Write(received[:i])
		muxCoolReader := &buf.MultiBufferContainer{}
		muxCoolReader.MultiBuffer = append(muxCoolReader.MultiBuffer, buffer)
		var meta mux.FrameMetadata
		err := meta.Unmarshal(muxCoolReader)
		if err != nil {
			errors.LogInfo(c.ctx, "Not a Mux Cool packet beginning, copy directly ", i)
			buf.ReleaseMulti(muxCoolReader.MultiBuffer)
			c.readChannel <- readResult{buffer: received[:i], err: e}
			continue;
		}
		if !meta.Option.Has(mux.OptionData) {
			errors.LogInfo(c.ctx, "No option data, copy directly ", i)
			buf.ReleaseMulti(muxCoolReader.MultiBuffer)
			c.readChannel <- readResult{buffer: received[:i], err: e}
			continue;
		}
		size, err := serial.ReadUint16(muxCoolReader)
		remaining := uint16(muxCoolReader.MultiBuffer.Len())
		errors.LogInfo(c.ctx, "Read stream ", i, " option size ", size, " remaining size ", remaining)
		if err != nil || size <= remaining || size > remaining + 1500 {
			errors.LogInfo(c.ctx, "do not wait for second part of UDP packet ", i)
			buf.ReleaseMulti(muxCoolReader.MultiBuffer)
			c.readChannel <- readResult{buffer: received[:i], err: e}
			continue;
		}

		i2, e := stream.Read(received[i:])
		if e != nil {
			errors.LogErrorInner(c.ctx, e, "Error read stream, drop this buffer ", i2)
			buf.ReleaseMulti(muxCoolReader.MultiBuffer)
			c.readChannel <- readResult{buffer: nil, err: e}
			continue;
		}
		errors.LogInfo(c.ctx, "Read stream i2 size ", i2)
		buf.ReleaseMulti(muxCoolReader.MultiBuffer)
		c.readChannel <- readResult{buffer: received[:(i + i2)], err: e}
	}
}

func (c *interConn) Read(b []byte) (int, error) {	
	if c.reader.MultiBuffer.Len() > 0 {
		return c.reader.Read(b)
	}
	received := <- c.readChannel
	if received.err != nil {
		return 0, received.err
	}
	buffer := buf.New()
	buffer.Write(received.buffer)
	c.reader.MultiBuffer = append(c.reader.MultiBuffer, buffer)
	errors.LogInfo(c.ctx, "Read copy ", len(received.buffer))
	return c.reader.Read(b)
}

func (c *interConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
	mb = buf.Compact(mb)
	mb, err := buf.WriteMultiBuffer(c, mb)
	buf.ReleaseMulti(mb)
	return err
}

func (c *interConn) Write(b []byte) (int, error) {
	if len(b) > 1240 { // TODO: why quic-go increase internal MTU causing packet loss?
		if len(c.streams) < MaxIncomingStreams {
			stream, err := c.quicConn.OpenStream()
			errors.LogInfo(c.ctx, "Write OpenStream ", err)
			if err == nil {
				c.streams = append(c.streams, stream)
			} else {
				errors.LogInfoInner(c.ctx, err, "failed to openStream: ")
			}
		}
		currentStream++;
		if currentStream > len(c.streams) - 1 {
			currentStream = 0;
		}
		errors.LogInfo(c.ctx, "Write stream ", len(b), currentStream, len(c.streams))
		return c.streams[currentStream].Write(b)
	}
	var err = c.quicConn.SendDatagram(b)
	errors.LogInfo(c.ctx, "Write SendDatagram ", len(b), err)
	if _, ok := err.(*quic.DatagramTooLargeError); ok {
		if len(c.streams) < MaxIncomingStreams {
			stream, err := c.quicConn.OpenStream()
			errors.LogInfo(c.ctx, "Write OpenStream ", err)
			if err == nil {
				c.streams = append(c.streams, stream)
			} else {
				errors.LogInfoInner(c.ctx, err, "failed to openStream: ")
			}
		}
		currentStream++;
		if currentStream > len(c.streams) - 1 {
			currentStream = 0;
		}
		errors.LogInfo(c.ctx, "Write stream ", len(b), currentStream, len(c.streams))
		return c.streams[currentStream].Write(b)
	}
	if err != nil {
		return 0, err
	}
	return len(b), nil
}

func (c *interConn) Close() error {
	var err error
	for _, s := range c.streams {
		e := s.Close()
		if e != nil {
			err = e
		}
	}
	return err
}

func (c *interConn) LocalAddr() net.Addr {
	return c.local
}

func (c *interConn) RemoteAddr() net.Addr {
	return c.remote
}

func (c *interConn) SetDeadline(t time.Time) error {
	var err error
	for _, s := range c.streams {
		e := s.SetDeadline(t)
		if e != nil {
			err = e
		}
	}
	return err
}

func (c *interConn) SetReadDeadline(t time.Time) error {
	var err error
	for _, s := range c.streams {
		e := s.SetReadDeadline(t)
		if e != nil {
			err = e
		}
	}
	return err
}

func (c *interConn) SetWriteDeadline(t time.Time) error {
	var err error
	for _, s := range c.streams {
		e := s.SetWriteDeadline(t)
		if e != nil {
			err = e
		}
	}
	return err
}