/*	$OpenBSD: iobuf.c,v 1.8 2015/12/05 21:27:42 mmcc Exp $	*/
/*      
 * Copyright (c) 2012 Eric Faurot <eric@openbsd.org>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/uio.h>

#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#ifdef IO_SSL
#include <openssl/err.h>
#include <openssl/ssl.h>
#endif

#include "iobuf.h"

/* Maximum read-buffer size used by iobuf_init() when the caller passes max == 0. */
#define IOBUF_MAX	65536
/* Smallest payload size of an output-queue chunk allocated by ioqbuf_alloc(). */
#define IOBUFQ_MIN	4096

/* Internal helpers; note they are not static, so they have external linkage. */
struct ioqbuf	*ioqbuf_alloc(struct iobuf *, size_t);
void		 iobuf_drain(struct iobuf *, size_t);

/*
 * Initialize an iobuf whose read buffer starts at "size" bytes and may
 * later grow (see iobuf_extend()) up to "max" bytes.  A zero max selects
 * IOBUF_MAX; a zero size selects the effective max.  The structure is
 * zeroed before anything else, so iobuf_clear() is safe even on failure.
 * Returns 0 on success, -1 if size exceeds max or the allocation fails.
 */
int
iobuf_init(struct iobuf *io, size_t size, size_t max)
{
	size_t	limit, initial;

	memset(io, 0, sizeof(*io));

	limit = (max != 0) ? max : IOBUF_MAX;
	initial = (size != 0) ? size : limit;
	if (initial > limit)
		return (-1);

	io->buf = malloc(initial);
	if (io->buf == NULL)
		return (-1);

	io->size = initial;
	io->max = limit;
	return (0);
}

/*
 * Release the read buffer and every chunk of the output queue, then
 * reset the whole structure to zero.
 */
void
iobuf_clear(struct iobuf *io)
{
	struct ioqbuf	*cur, *nxt;

	free(io->buf);

	for (cur = io->outq; cur != NULL; cur = nxt) {
		nxt = cur->next;
		free(cur);
	}

	memset(io, 0, sizeof(*io));
}

/*
 * Discard up to n bytes from the front of the output queue, freeing
 * each chunk that becomes fully consumed.  io->queued is reduced by the
 * number of bytes actually removed, and the tail pointer is reset once
 * the queue is empty.
 */
void
iobuf_drain(struct iobuf *io, size_t n)
{
	struct ioqbuf	*head;
	size_t		 remaining, avail;

	remaining = n;
	while (remaining > 0 && (head = io->outq) != NULL) {
		avail = head->wpos - head->rpos;
		if (avail > remaining) {
			/* Partial consumption of the head chunk: keep it. */
			head->rpos += remaining;
			remaining = 0;
			break;
		}
		remaining -= avail;
		io->outq = head->next;
		free(head);
	}

	io->queued -= n - remaining;
	if (io->outq == NULL)
		io->outqlast = NULL;
}

/*
 * Grow the read buffer by n bytes.  The buffer is left untouched if the
 * new size would exceed io->max or if reallocation fails.
 * Returns 0 on success, -1 on failure.
 */
int
iobuf_extend(struct iobuf *io, size_t n)
{
	char	*grown;

	/* Bound-check by subtraction so size + n is never formed unchecked. */
	if (n > io->max || n > io->max - io->size)
		return (-1);

	if ((grown = realloc(io->buf, io->size + n)) == NULL)
		return (-1);

	io->buf = grown;
	io->size += n;
	return (0);
}

/* Bytes that can still be appended after wpos without moving or growing the buffer. */
size_t
iobuf_left(struct iobuf *io)
{
	size_t	end = io->wpos;

	return (io->size - end);
}

/*
 * Free bytes in the read buffer once the pending data is counted,
 * i.e. what iobuf_left() would report after iobuf_normalize().
 */
size_t
iobuf_space(struct iobuf *io)
{
	size_t	pending = io->wpos - io->rpos;

	return (io->size - pending);
}

/* Number of unread bytes currently held in the read buffer. */
size_t
iobuf_len(struct iobuf *io)
{
	size_t	start = io->rpos, end = io->wpos;

	return (end - start);
}

/* Address of the first unread byte in the read buffer. */
char *
iobuf_data(struct iobuf *io)
{
	return (&io->buf[io->rpos]);
}

/*
 * Consume n bytes from the front of the read buffer.  Consuming all of
 * the pending data (or more) rewinds both offsets to zero.
 */
void
iobuf_drop(struct iobuf *io, size_t n)
{
	if (n < io->wpos - io->rpos)
		io->rpos += n;
	else
		io->rpos = io->wpos = 0;
}

/*
 * Extract the next '\n'-terminated line from the read buffer.
 *
 * The returned pointer refers directly into the iobuf storage.  The line
 * is NUL-terminated in place (a trailing "\r" is stripped as well), and
 * its bytes, including the newline, are already dropped from the iobuf,
 * so the caller need not do it.  The text stays usable only until the
 * buffer is rewritten, e.g. by iobuf_normalize() or iobuf_extend().
 *
 * If rlen is not NULL it receives the line length, excluding the
 * terminator.  Returns NULL when no complete line is buffered.
 */
char *
iobuf_getline(struct iobuf *iobuf, size_t *rlen)
{
	char	*line, *nl;
	size_t	 avail, end;

	avail = iobuf_len(iobuf);
	if (avail == 0)
		return (NULL);

	line = iobuf_data(iobuf);
	if ((nl = memchr(line, '\n', avail)) == NULL)
		return (NULL);

	end = (size_t)(nl - line);
	iobuf_drop(iobuf, end + 1);
	if (end > 0 && line[end - 1] == '\r')
		end--;
	line[end] = '\0';
	if (rlen != NULL)
		*rlen = end;
	return (line);
}

/*
 * Slide the pending read data down to offset 0 so that iobuf_left()
 * reports all free space.  Pointers previously obtained from
 * iobuf_data() or iobuf_getline() no longer refer to the same bytes.
 */
void
iobuf_normalize(struct iobuf *io)
{
	size_t	pending;

	if (io->rpos != 0) {
		pending = io->wpos - io->rpos;
		if (pending != 0)
			memmove(io->buf, io->buf + io->rpos, pending);
		io->wpos = pending;
		io->rpos = 0;
	}
}

/*
 * Read from fd into the free space after wpos and account for the new
 * bytes.  Returns the byte count read, IOBUF_CLOSED when read(2)
 * returns 0, IOBUF_WANT_READ on EAGAIN/EINTR, and IOBUF_ERROR otherwise
 * (errno as left by read(2)).
 * NOTE(review): if iobuf_left() is 0, read(2) may also return 0, which
 * is then reported as IOBUF_CLOSED; callers should normalize/extend first.
 */
ssize_t
iobuf_read(struct iobuf *io, int fd)
{
	ssize_t	got;

	got = read(fd, io->buf + io->wpos, iobuf_left(io));
	switch (got) {
	case -1:
		/* XXX is this really what we want? */
		if (errno == EAGAIN || errno == EINTR)
			return (IOBUF_WANT_READ);
		return (IOBUF_ERROR);
	case 0:
		return (IOBUF_CLOSED);
	default:
		io->wpos += got;
		return (got);
	}
}

/*
 * Allocate an output-queue chunk able to hold at least len bytes (never
 * less than IOBUFQ_MIN) and append it to the tail of io's queue.  The
 * chunk header and its data area live in a single allocation.
 * Returns the new, empty chunk, or NULL if len is too large to allocate
 * (errno = ENOMEM) or malloc fails.
 */
struct ioqbuf *
ioqbuf_alloc(struct iobuf *io, size_t len)
{
	struct ioqbuf   *q;

	if (len < IOBUFQ_MIN)
		len = IOBUFQ_MIN;

	/*
	 * Reject sizes for which header + payload would wrap around: a
	 * wrapped sum would yield a tiny allocation that callers overrun.
	 */
	if (len > (size_t)-1 - sizeof(*q)) {
		errno = ENOMEM;
		return (NULL);
	}

	if ((q = malloc(sizeof(*q) + len)) == NULL)
		return (NULL);

	q->rpos = 0;
	q->wpos = 0;
	q->size = len;
	q->next = NULL;
	/* The payload immediately follows the header in the same block. */
	q->buf = (char *)(q) + sizeof(*q);

	if (io->outqlast == NULL)
		io->outq = q;
	else
		io->outqlast->next = q;
	io->outqlast = q;

	return (q);
}

/* Total number of bytes currently waiting in the output queue. */
size_t
iobuf_queued(struct iobuf *io)
{
	size_t	pending = io->queued;

	return (pending);
}

/*
 * Reserve len contiguous bytes at the end of the output queue and count
 * them in io->queued; the caller must fill the returned area.  Space is
 * taken from the last chunk when it has room, otherwise a new chunk is
 * appended.
 * Returns the reserved area, or NULL when len is 0 or allocation fails.
 */
void *
iobuf_reserve(struct iobuf *io, size_t len)
{
	struct ioqbuf	*q;
	void		*r;

	if (len == 0)
		return (NULL);

	/*
	 * An exact fit is acceptable: the reserved area may end precisely at
	 * q->size.  (Using "<=" here would waste the tail and allocate a
	 * fresh chunk for no reason.)
	 */
	q = io->outqlast;
	if (q == NULL || q->size - q->wpos < len) {
		if ((q = ioqbuf_alloc(io, len)) == NULL)
			return (NULL);
	}

	r = q->buf + q->wpos;
	q->wpos += len;
	io->queued += len;

	return (r);
}

/*
 * Append a copy of len bytes from data to the output queue.
 * Returns len on success (0 when len is 0), or -1 on failure.
 */
int
iobuf_queue(struct iobuf *io, const void *data, size_t len)
{
	void	*buf;

	if (len == 0)
		return (0);

	/*
	 * The byte count is returned as an int.  Refuse lengths it cannot
	 * represent instead of queueing the data and then handing back a
	 * wrapped (possibly negative, i.e. error-looking) value.
	 */
	if (len > INT_MAX) {
		errno = EOVERFLOW;
		return (-1);
	}

	if ((buf = iobuf_reserve(io, len)) == NULL)
		return (-1);

	memmove(buf, data, len);

	return ((int)len);
}

/*
 * Queue the concatenation of the iovcnt buffers described by iov as one
 * contiguous area of the output queue.
 * Returns 0 on success (including when there is nothing to queue), or
 * -1 if the total length overflows size_t or allocation fails.
 */
int
iobuf_queuev(struct iobuf *io, const struct iovec *iov, int iovcnt)
{
	int	 i;
	size_t	 len = 0;
	char	*buf;

	for (i = 0; i < iovcnt; i++) {
		/*
		 * A wrapped total would reserve too small an area and the
		 * copy loop below would overflow it.
		 */
		if (iov[i].iov_len > (size_t)-1 - len) {
			errno = EOVERFLOW;
			return (-1);
		}
		len += iov[i].iov_len;
	}

	/*
	 * iobuf_reserve() refuses 0 bytes; an empty vector is a successful
	 * no-op, just as iobuf_queue() treats len == 0.
	 */
	if (len == 0)
		return (0);

	if ((buf = iobuf_reserve(io, len)) == NULL)
		return (-1);

	for (i = 0; i < iovcnt; i++) {
		if (iov[i].iov_len == 0)
			continue;
		memmove(buf, iov[i].iov_base, iov[i].iov_len);
		buf += iov[i].iov_len;
	}

	return (0);
}

/*
 * printf-style front end to iobuf_vfqueue(): format the arguments and
 * queue the resulting text.  Returns what iobuf_vfqueue() returns.
 */
int
iobuf_fqueue(struct iobuf *io, const char *fmt, ...)
{
	va_list	args;
	int	queued;

	va_start(args, fmt);
	queued = iobuf_vfqueue(io, fmt, args);
	va_end(args);
	return (queued);
}

/*
 * Format fmt/ap with vasprintf() and queue the resulting text (without
 * its NUL terminator).  The temporary string is always freed.
 * Returns the iobuf_queue() result, or -1 if formatting fails.
 */
int
iobuf_vfqueue(struct iobuf *io, const char *fmt, va_list ap)
{
	char	*text;
	int	 n, rc;

	if ((n = vasprintf(&text, fmt, ap)) == -1)
		return (-1);

	rc = iobuf_queue(io, text, n);
	free(text);
	return (rc);
}

ssize_t
iobuf_write(struct iobuf *io, int fd)
{
	struct iovec	 iov[IOV_MAX];
	struct ioqbuf	*q;
	int		 i;
	ssize_t		 n;

	i = 0;
	for (q = io->outq; q ; q = q->next) {
		if (i >= IOV_MAX)
			break;
		iov[i].iov_base = q->buf + q->rpos;
		iov[i].iov_len = q->wpos - q->rpos;
		i++;
	}

	n = writev(fd, iov, i);
	if (n == -1) {
		if (errno == EAGAIN || errno == EINTR)
			return (IOBUF_WANT_WRITE);
		if (errno == EPIPE)
			return (IOBUF_CLOSED);
		return (IOBUF_ERROR);
	}

	iobuf_drain(io, n);

	return (n);
}

/*
 * Keep calling iobuf_write() until the output queue is empty.
 * Returns 0 once everything is written, or the first negative IOBUF_*
 * code from iobuf_write() (e.g. IOBUF_WANT_WRITE on a non-blocking fd).
 */
int
iobuf_flush(struct iobuf *io, int fd)
{
	ssize_t	r;

	for (;;) {
		if (io->queued == 0)
			return (0);
		r = iobuf_write(io, fd);
		if (r < 0)
			return (r);
	}
}

#ifdef IO_SSL

/*
 * TLS counterpart of iobuf_flush(): call iobuf_write_ssl() until the
 * output queue is empty.  Returns 0 once everything is written, or the
 * first negative IOBUF_* code from iobuf_write_ssl().
 */
int
iobuf_flush_ssl(struct iobuf *io, void *ssl)
{
	ssize_t	r;

	for (;;) {
		if (io->queued == 0)
			return (0);
		r = iobuf_write_ssl(io, ssl);
		if (r < 0)
			return (r);
	}
}

ssize_t
iobuf_write_ssl(struct iobuf *io, void *ssl)
{
	struct ioqbuf	*q;
	int		 r;
	ssize_t		 n;

	q = io->outq;
	n = SSL_write(ssl, q->buf + q->rpos, q->wpos - q->rpos);
	if (n <= 0) {
		switch ((r = SSL_get_error(ssl, n))) {
		case SSL_ERROR_WANT_READ:
			return (IOBUF_WANT_READ);
		case SSL_ERROR_WANT_WRITE:
			return (IOBUF_WANT_WRITE);
		case SSL_ERROR_ZERO_RETURN: /* connection closed */
			return (IOBUF_CLOSED);
		case SSL_ERROR_SYSCALL:
			if (ERR_peek_last_error())
				return (IOBUF_SSLERROR);
			if (r == 0)
				errno = EPIPE;
			return (IOBUF_ERROR);
		default:
			return (IOBUF_SSLERROR);
		}
	}
	iobuf_drain(io, n);

	return (n);
}

/*
 * Read from the TLS session "ssl" (an SSL *) into the free space after
 * wpos and account for the new bytes.  Returns the byte count,
 * IOBUF_CLOSED when SSL_read() returns 0, IOBUF_WANT_READ or
 * IOBUF_WANT_WRITE when the TLS layer needs more I/O, IOBUF_ERROR on a
 * system-call failure (errno left as set), IOBUF_SSLERROR otherwise.
 */
ssize_t
iobuf_read_ssl(struct iobuf *io, void *ssl)
{
	size_t	len;
	ssize_t	n;

	/* SSL_read() takes an int count; request at most INT_MAX bytes. */
	len = iobuf_left(io);
	if (len > INT_MAX)
		len = INT_MAX;

	n = SSL_read(ssl, io->buf + io->wpos, (int)len);
	if (n == 0)
		return (IOBUF_CLOSED);
	if (n < 0) {
		/*
		 * The former "r == 0 -> EPIPE" check was unreachable: in this
		 * branch the error code is SSL_ERROR_SYSCALL and n < 0.
		 */
		switch (SSL_get_error(ssl, (int)n)) {
		case SSL_ERROR_WANT_READ:
			return (IOBUF_WANT_READ);
		case SSL_ERROR_WANT_WRITE:
			return (IOBUF_WANT_WRITE);
		case SSL_ERROR_SYSCALL:
			if (ERR_peek_last_error())
				return (IOBUF_SSLERROR);
			return (IOBUF_ERROR);
		default:
			return (IOBUF_SSLERROR);
		}
	}

	io->wpos += n;

	return (n);
}

#endif /* IO_SSL */