#include <qdns_dane.h>

#include <fmt.h>
#include <qdns.h>

#include <dns.h>
#include <errno.h>
#include <stdlib.h>
#include <string.h>

/* DNS RR type 52 (TLSA) as the 2 byte big endian value used in packets */
#define DNS_T_TLSA "\0\64"

enum {
	/* size of the certificate association data for a SHA2-256 hash */
	TLSA_DATA_LEN_SHA256 = 256 / 8,
	/* size of the certificate association data for a SHA2-512 hash */
	TLSA_DATA_LEN_SHA512 = 512 / 8,
	/* fixed TLSA RDATA prefix: certificate usage, selector, matching type (1 byte each) */
	TLSA_MIN_RECORD_LEN = 3
};

/**
 * Release an array of TLSA records.
 *
 * @param di the array to release (may be NULL if cnt is not positive)
 * @param cnt number of leading entries whose data member must be freed
 *
 * Every data buffer of the first cnt entries is freed, then the array itself.
 */
void daneinfo_free(struct daneinfo *di, int cnt)
{
	int idx = 0;

	while (idx < cnt) {
		free(di[idx].data);
		idx++;
	}

	free(di);
}

/**
 * Release a partially filled TLSA result array on an error path.
 *
 * @param out location of the result array; if NULL nothing is freed,
 *            otherwise *out is freed and reset to NULL
 * @param cnt number of entries of *out whose data member has been allocated
 * @return always -1, so callers can write "return free_tlsa_data(...);"
 *
 * errno is preserved: callers set it to describe the failure before calling
 * this, and free() is not guaranteed to leave errno untouched.
 */
static int
free_tlsa_data(struct daneinfo **out, const int cnt)
{
	if (out != NULL) {
		const int olderrno = errno;

		daneinfo_free(*out, cnt);
		*out = NULL;
		errno = olderrno;
	}

	return -1;
}

/**
 * Read a 16 bit big endian (network byte order) value from a byte buffer.
 *
 * Reading byte-wise avoids casting a char buffer to an unsigned short
 * pointer, which violates strict aliasing and may be a misaligned access.
 */
static uint16_t
tlsa_get_be16(const char *p)
{
	const unsigned char *u = (const unsigned char *)p;

	return (uint16_t)((u[0] << 8) | u[1]);
}

/**
 * Extract the TLSA records of class IN from a raw DNS answer packet.
 *
 * Derived from dns_txt_packet() of libowfat.
 *
 * @param out if not NULL, receives a newly allocated array of the records
 *            found, or NULL if none were found or an error occurred;
 *            release it with daneinfo_free()
 * @param buf the DNS answer packet
 * @param len length of buf in bytes
 * @return number of TLSA records found, -1 on error (errno is EINVAL for
 *         malformed packets, otherwise as left by the failing helper)
 */
static int
dns_tlsa_packet(struct daneinfo **out, const char *buf, unsigned int len)
{
	const unsigned int headerlen = 12;	/* fixed DNS message header */
	int ret = 0;

	/* make sure the cleanup at the end never frees a caller-owned value */
	if (out != NULL)
		*out = NULL;

	if (len < headerlen) {
		errno = EINVAL;
		return -1;
	}

	/* ANCOUNT lives at offset 6 of the header */
	uint16_t numanswers = tlsa_get_be16(buf + 6);
	unsigned int pos = dns_packet_skipname(buf, len, headerlen);
	if (!pos)
		return -1;
	/* skip QTYPE and QCLASS of the question */
	pos += 4;

	if ((out != NULL) && (numanswers > 0)) {
		*out = calloc(numanswers, sizeof(**out));
		if (*out == NULL)
			return -1;
	}

	while (numanswers--) {
		pos = dns_packet_skipname(buf, len, pos);
		if (!pos)
			return free_tlsa_data(out, ret);
		/* fixed RR part: TYPE(2) CLASS(2) TTL(4) RDLENGTH(2) */
		if (len < pos + 10) {
			errno = EINVAL;
			return free_tlsa_data(out, ret);
		}
		const char *rr = buf + pos;
		const uint16_t datalen = tlsa_get_be16(rr + 8);
		pos += 10;

		if ((memcmp(rr, DNS_T_TLSA, 2) == 0) && (memcmp(rr + 2, DNS_C_IN, 2) == 0)) {
			/* read the RDATA as unsigned bytes so values > 127 do not sign-extend */
			const unsigned char *rdata = (const unsigned char *)buf + pos;
			unsigned int minlen;
			unsigned int maxlen;

			/* the 3 fixed fields plus at least one byte of association data */
			if (datalen <= TLSA_MIN_RECORD_LEN) {
				errno = EINVAL;
				return free_tlsa_data(out, ret);
			}

			if (pos + datalen > len) {
				errno = EINVAL;
				return free_tlsa_data(out, ret);
			}

			/* rdata[2] is the matching type: hashes have a fixed length */
			switch (rdata[2]) {
			case TLSA_MT_SHA2_256:
				minlen = TLSA_DATA_LEN_SHA256;
				maxlen = TLSA_DATA_LEN_SHA256;
				break;
			case TLSA_MT_SHA2_512:
				minlen = TLSA_DATA_LEN_SHA512;
				maxlen = TLSA_DATA_LEN_SHA512;
				break;
			case TLSA_MT_Full:
			default:
				minlen = 1;
				maxlen = datalen;
				break;
			}

			if ((datalen < minlen + TLSA_MIN_RECORD_LEN) || (datalen > maxlen + TLSA_MIN_RECORD_LEN)) {
				errno = EINVAL;
				return free_tlsa_data(out, ret);
			}

			if (out != NULL) {
				struct daneinfo *res = *out + ret;

				res->cert_usage = rdata[0];
				res->selector = rdata[1];
				res->matching_type = rdata[2];
				res->datalen = datalen - TLSA_MIN_RECORD_LEN;
				res->data = malloc(res->datalen);

				if (res->data == NULL)
					return free_tlsa_data(out, ret);

				memcpy(res->data, rdata + TLSA_MIN_RECORD_LEN, res->datalen);
			}
			ret++;
		}
		pos += datalen;
	}

	if (out != NULL) {
		/* there may have been more answer records than interesting ones,
		 * shrink the array to the needed size */
		if (ret > 0) {
			struct daneinfo *tmp = realloc(*out, sizeof(**out) * ret);
			if (tmp != NULL)
				*out = tmp;
		} else {
			free(*out);
			*out = NULL;
		}
	}

	return ret;
}

/**
 * Look up the TLSA records of a TCP service.
 *
 * Queries the name "_<port>._tcp.<host>" for records of type TLSA.
 *
 * @param host host name of the service
 * @param port TCP port of the service
 * @param out if not NULL, receives the array of TLSA records found, or NULL
 *            if none were found or an error occurred; release with daneinfo_free()
 * @return number of TLSA records found, -1 on error (errno describes the error)
 */
int
dnstlsa(const char *host, const unsigned short port, struct daneinfo **out)
{
	/* "_65535" is the longest possible port prefix */
	char hostbuf[strlen("_65535._tcp.") + strlen(host) + 1];
	char *q = NULL;
	int r;

	hostbuf[0] = '_';
	ultostr(port, hostbuf + 1);
	strcat(hostbuf, "._tcp.");
	strcat(hostbuf, host);

	if (out != NULL)
		*out = NULL;

	if (!dns_domain_fromdot(&q, hostbuf, strlen(hostbuf)))
		return -1;

	if (dns_resolve(q, DNS_T_TLSA) == -1)
		r = -1;
	else
		r = dns_tlsa_packet(out, dns_resolve_tx.packet, dns_resolve_tx.packetlen);

	/* release the query name and the transmit state on every path, the
	 * previous code leaked both whenever resolving or parsing failed;
	 * the cleanup must not clobber the errno of a failed lookup */
	const int olderrno = errno;
	dns_transmit_free(&dns_resolve_tx);
	dns_domain_free(&q);
	errno = olderrno;

	return r;
}
