package dns

import (
	"context"
	go_errors "errors"
	"github.com/xtls/xray-core/common"
	"github.com/xtls/xray-core/common/errors"
	"github.com/xtls/xray-core/common/net"
	"github.com/xtls/xray-core/common/signal/pubsub"
	"github.com/xtls/xray-core/common/task"
	dns_feature "github.com/xtls/xray-core/features/dns"
	"golang.org/x/net/dns/dnsmessage"
	"sync"
	"time"
)

// CacheController holds cached DNS answers keyed by domain and the pub/sub
// service used to announce newly stored answers. The embedded RWMutex guards
// the ips map and the records it points to.
type CacheController struct {
	sync.RWMutex
	// ips maps a domain name to its cached A/AAAA records.
	ips          map[string]*record
	// pub publishes on per-domain topics (domain+"4" / domain+"6") when an
	// answer of the corresponding family is stored.
	pub          *pubsub.Service
	// cacheCleanup is the periodic task that runs CacheCleanup.
	cacheCleanup *task.Periodic
	// name is passed as the leading argument of this controller's log calls.
	name         string
	// disableCache, when true, stops updateIP from also publishing on the
	// other address family's topic based on what is already cached.
	disableCache bool
}

// NewCacheController builds a CacheController with an empty record map, a
// fresh pub/sub service and a cleanup task configured to run CacheCleanup at
// one-minute intervals. The cleanup task is only configured here, not started.
//
// name is used as the prefix of the controller's log messages; disableCache is
// stored as-is on the controller.
func NewCacheController(name string, disableCache bool) *CacheController {
	ctrl := new(CacheController)
	ctrl.name = name
	ctrl.disableCache = disableCache
	ctrl.ips = make(map[string]*record)
	ctrl.pub = pubsub.NewService()

	// The task refers back to the controller's own method, so it can only be
	// wired up once the controller value exists.
	ctrl.cacheCleanup = &task.Periodic{
		Execute:  ctrl.CacheCleanup,
		Interval: time.Minute,
	}
	return ctrl
}

// CacheCleanup drops expired A and AAAA entries from the cache and removes
// every domain that is left without any entry.
//
// It returns a non-nil error when the cache is already empty on entry
// (presumably so the periodic runner stops rescheduling it — confirm against
// task.Periodic), and nil otherwise.
func (c *CacheController) CacheCleanup() error {
	// Take the reference time before waiting on the lock.
	now := time.Now()

	c.Lock()
	defer c.Unlock()

	if len(c.ips) == 0 {
		return errors.New("nothing to do. stopping...")
	}

	for domain, rec := range c.ips {
		// rec is a pointer, so clearing its fields updates the cached entry
		// in place; no write-back into the map is needed.
		if a := rec.A; a != nil && a.Expire.Before(now) {
			rec.A = nil
		}
		if aaaa := rec.AAAA; aaaa != nil && aaaa.Expire.Before(now) {
			rec.AAAA = nil
		}

		if rec.A != nil || rec.AAAA != nil {
			continue
		}
		errors.LogDebug(context.Background(), c.name, "cache cleanup ", domain)
		delete(c.ips, domain)
	}

	// Once everything has expired, swap in a fresh map instead of keeping the
	// emptied one around.
	if len(c.ips) == 0 {
		c.ips = make(map[string]*record)
	}

	return nil
}

// updateIP stores ipRec as the answer for req.domain and notifies subscribers.
//
// For an A or AAAA request the corresponding slot of the domain's record is
// replaced, the answer is logged, and the request's own topic (domain+"4" for
// A, domain+"6" for AAAA) is published. Unless disableCache is set, the other
// family's topic is published too when that family's cached entry does not
// report errRecordNotFound. Any other request type only logs and stores the
// record. After the lock is released the cleanup task is started and the
// result is handed to common.Must.
func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) {
	elapsed := time.Since(req.start)

	c.Lock()
	entry, ok := c.ips[req.domain]
	if !ok {
		entry = &record{}
	}

	// Pick the topic for the answered family and remember the sibling
	// family's entry and topic for the follow-up notification.
	var (
		ownTopic, siblingTopic string
		sibling                *IPRecord
	)
	switch req.reqType {
	case dnsmessage.TypeA:
		entry.A = ipRec
		ownTopic, siblingTopic, sibling = "4", "6", entry.AAAA
	case dnsmessage.TypeAAAA:
		entry.AAAA = ipRec
		ownTopic, siblingTopic, sibling = "6", "4", entry.A
	}

	errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
	c.ips[req.domain] = entry

	if ownTopic != "" {
		c.pub.Publish(req.domain+ownTopic, nil)
		if !c.disableCache {
			if _, _, err := sibling.getIPs(); !go_errors.Is(err, errRecordNotFound) {
				c.pub.Publish(req.domain+siblingTopic, nil)
			}
		}
	}

	// Release the lock before starting the cleanup task: CacheCleanup takes
	// the same lock.
	c.Unlock()
	common.Must(c.cacheCleanup.Start())
}

// findIPsForDomain looks up the cached addresses for domain, restricted to the
// address families enabled in option.
//
// Returns the IPs, a TTL and an error:
//   - No family enabled, or no record cached for domain: nil, 0,
//     errRecordNotFound. (The no-family case used to index an empty slice and
//     panic; NOTE(review): confirm callers are fine treating it as a miss.)
//   - Exactly one family enabled: that family's getIPs result, unchanged.
//   - Both families enabled: if either family reports errRecordNotFound, that
//     family's result is returned as-is. Otherwise the IPs of both families
//     are concatenated (A first) and the TTL is the smallest of
//     dns_feature.DefaultTTL and the two per-family TTLs. If neither family
//     yields an IP, the first error is returned when errors.Is matches it
//     against the second, and the combination of both otherwise.
func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
	if !option.IPv4Enable && !option.IPv6Enable {
		return nil, 0, errRecordNotFound
	}

	// Snapshot both per-family pointers while holding the read lock:
	// updateIP and CacheCleanup rewrite these fields under the write lock, so
	// reading them after RUnlock would be a data race.
	c.RLock()
	rec, found := c.ips[domain]
	if !found {
		c.RUnlock()
		return nil, 0, errRecordNotFound
	}
	recA, recAAAA := rec.A, rec.AAAA
	c.RUnlock()

	mergeReq := option.IPv4Enable && option.IPv6Enable

	var (
		errs   []error
		allIPs []net.IP
		rTTL   uint32 = dns_feature.DefaultTTL
	)
	// collect folds one family's answer into the merged result.
	collect := func(ips []net.IP, ttl uint32, err error) {
		if ttl < rTTL {
			rTTL = ttl
		}
		if len(ips) > 0 {
			allIPs = append(allIPs, ips...)
		} else {
			errs = append(errs, err)
		}
	}

	if option.IPv4Enable {
		ips, ttl, err := recA.getIPs()
		if !mergeReq || go_errors.Is(err, errRecordNotFound) {
			return ips, ttl, err
		}
		collect(ips, ttl, err)
	}

	if option.IPv6Enable {
		ips, ttl, err := recAAAA.getIPs()
		if !mergeReq || go_errors.Is(err, errRecordNotFound) {
			return ips, ttl, err
		}
		collect(ips, ttl, err)
	}

	if len(allIPs) > 0 {
		return allIPs, rTTL, nil
	}

	// Reaching this point means both families were consulted and neither
	// produced an address, so errs holds exactly two entries.
	if go_errors.Is(errs[0], errs[1]) {
		return nil, rTTL, errs[0]
	}
	return nil, rTTL, errors.Combine(errs...)
}

// registerSubscribers subscribes to the notification topics for domain, one
// per address family enabled in option: domain+"4" for IPv4 and domain+"6"
// for IPv6. The subscriber of a family that is not enabled is returned as nil.
func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
	// Each family has its own topic, so the two subscriptions are independent.
	if option.IPv4Enable {
		sub4 = c.pub.Subscribe(domain + "4")
	}
	if option.IPv6Enable {
		sub6 = c.pub.Subscribe(domain + "6")
	}
	return sub4, sub6
}

// closeSubscribers closes each of the given subscribers, IPv4 first, skipping
// any that is nil.
func closeSubscribers(sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
	for _, sub := range [...]*pubsub.Subscriber{sub4, sub6} {
		if sub == nil {
			continue
		}
		sub.Close()
	}
}