/* $OpenBSD: math_2n.c,v 1.26 2007/04/16 13:01:39 moritz Exp $	 */
/* $EOM: math_2n.c,v 1.15 1999/04/20 09:23:30 niklas Exp $	 */

/*
 * Copyright (c) 1998 Niels Provos.  All rights reserved.
 * Copyright (c) 1999 Niklas Hallqvist.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

/*
 * This code was written under funding by Ericsson Radio Systems.
 */

/*
 * B2N is a module for doing arithmetic on the field GF(2**n), which is
 * isomorphic to the ring of polynomials GF(2)[x]/p(x), where p(x) is an
 * irreducible polynomial over GF(2) of degree n.
 *
 * First we need functions which operate on GF(2)[x]; operations
 * on GF(2)[x]/p(x) can then be done as for Z_p.
 */

#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#include "math_2n.h"
#include "util.h"

/* Forward declaration of the file-local hex digit decoder defined below.  */
static u_int8_t hex2int(char);

/*
 * Table of single-bit masks: entry i has only bit i set (0x01, 0x02, ...).
 * The table holds CHUNK_BITS entries; the preprocessor conditionals add
 * the initializers for bits 8..15 and 16..31 only when a chunk is wide
 * enough to contain them.
 */
CHUNK_TYPE      b2n_mask[CHUNK_BITS] = {
	0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
#if CHUNK_BITS > 8
	0x0100, 0x0200, 0x0400, 0x0800, 0x1000, 0x2000, 0x4000, 0x8000,
#if CHUNK_BITS > 16
	0x00010000, 0x00020000, 0x00040000, 0x00080000,
	0x00100000, 0x00200000, 0x00400000, 0x00800000,
	0x01000000, 0x02000000, 0x04000000, 0x08000000,
	0x10000000, 0x20000000, 0x40000000, 0x80000000,
#endif
#endif
};

/*
 * Convert a single hexadecimal digit to its integer value (0..15).
 *
 * Accepts '0'-'9', 'a'-'f' and 'A'-'F'.  Every other character yields 0.
 * b2n_set_str() relies on that for odd-length input: it deliberately
 * feeds the 'x'/'X' of the "0x" prefix through this function as a
 * leading zero nibble.
 */
static u_int8_t
hex2int(char c)
{
	if (c >= '0' && c <= '9')
		return c - '0';
	if (c >= 'a' && c <= 'f')
		return 10 + c - 'a';
	if (c >= 'A' && c <= 'F')
		return 10 + c - 'A';

	return 0;
}

/*
 * Fill N with random data of at most BITS bits.
 *
 * N is resized to the number of chunks needed to hold BITS bits, the
 * whole chunk array is handed to getrandom(), and the surplus high bits
 * of the top chunk are masked off when BITS is not a multiple of the
 * chunk width.  Returns 0 on success, -1 if the resize fails.
 */
int
b2n_random(b2n_ptr n, u_int32_t bits)
{
	u_int32_t	top_bits = bits & CHUNK_MASK;

	if (b2n_resize(n, (CHUNK_MASK + bits) >> CHUNK_SHIFTS) != 0)
		return -1;

	getrandom((u_int8_t *) n->limp, CHUNK_BYTES * n->chunks);

	/* Keep only the low top_bits bits of the most significant chunk. */
	if (top_bits != 0) {
		CHUNK_TYPE keep = (((1 << (top_bits - 1)) - 1) << 1) | 1;

		n->limp[n->chunks - 1] &= keep;
	}

	/* The cached bit count no longer describes the contents. */
	n->dirty = 1;
	return 0;
}

/* b2n management functions */

/*
 * Put N into a well-defined empty state: no chunk storage, and a cached
 * bit count of zero that is marked valid.  All four fields are set so
 * that nothing reads an indeterminate value before the first
 * b2n_resize().
 */
void
b2n_init(b2n_ptr n)
{
	n->chunks = 0;
	n->limp = NULL;
	n->bits = 0;
	n->dirty = 0;
}

/*
 * Release the chunk storage owned by N and return N to the same empty
 * state b2n_init() produces, so that a second b2n_clear() or a later
 * b2n_resize() does not operate on a dangling pointer.
 */
void
b2n_clear(b2n_ptr n)
{
	free(n->limp);
	n->limp = NULL;
	n->chunks = 0;
	n->bits = 0;
	n->dirty = 0;
}

/*
 * Change the number of chunks held by N to CHUNKS (a request for zero
 * chunks is treated as a request for one).
 *
 * Nothing happens if the size is already right.  Otherwise the storage
 * is reallocated, newly gained chunks are zero-filled, n->bits is set to
 * the full capacity in bits and n->dirty is raised.
 * Returns 0 on success and -1 if realloc() fails, in which case N is
 * left untouched.
 */
int
b2n_resize(b2n_ptr n, unsigned int chunks)
{
	size_t		 prev_chunks = n->chunks;
	size_t		 nbytes;
	CHUNK_TYPE	*buf;

	if (chunks == 0)
		chunks = 1;
	if (chunks == prev_chunks)
		return 0;

	nbytes = CHUNK_BYTES * chunks;
	buf = realloc(n->limp, nbytes);
	if (buf == NULL)
		return -1;

	/* Growing: clear the tail that realloc() left uninitialized. */
	if (chunks > prev_chunks)
		memset(buf + prev_chunks, 0,
		    nbytes - CHUNK_BYTES * prev_chunks);

	n->limp = buf;
	n->chunks = chunks;
	n->bits = chunks << CHUNK_SHIFTS;
	n->dirty = 1;
	return 0;
}

/* Simple assignment functions.  */

/*
 * Copy S into D.  Assigning a value to itself is a no-op.
 *
 * b2n_sigbit() is called first so that s->bits gets a chance to be
 * refreshed; D is then sized from s->bits and that many chunks are
 * copied, together with the bits/dirty bookkeeping.
 * Returns 0 on success, -1 if D cannot be resized.
 */
int
b2n_set(b2n_ptr d, b2n_ptr s)
{
	if (d != s) {
		b2n_sigbit(s);
		if (b2n_resize(d, (CHUNK_MASK + s->bits) >> CHUNK_SHIFTS) != 0)
			return -1;

		memcpy(d->limp, s->limp, CHUNK_BYTES * d->chunks);
		d->bits = s->bits;
		d->dirty = s->dirty;
	}
	return 0;
}

/*
 * Set N to zero: shrink it to a single chunk, clear that chunk and
 * record a valid cached bit count of 0.
 * Returns 0 on success, -1 if the resize fails.
 */
int
b2n_set_null(b2n_ptr n)
{
	if (b2n_resize(n, 1) != 0)
		return -1;

	n->dirty = 0;
	n->bits = 0;
	n->limp[0] = 0;
	return 0;
}

/*
 * Set N to the unsigned integer VAL.
 *
 * When a chunk is narrower than 32 bits the value is spread over as many
 * chunks as an unsigned int occupies, least significant chunk first;
 * otherwise it is stored in a single chunk.  The cached bit count is
 * marked stale.  Returns 0 on success, -1 if the resize fails.
 */
int
b2n_set_ui(b2n_ptr n, unsigned int val)
{
#if CHUNK_BITS < 32
	int	needed = (CHUNK_BYTES - 1 + sizeof(val)) / CHUNK_BYTES;
	int	idx = 0;

	if (b2n_resize(n, needed) != 0)
		return -1;

	while (idx < needed) {
		n->limp[idx++] = val & CHUNK_BMASK;
		val >>= CHUNK_BITS;
	}
#else
	if (b2n_resize(n, 1) != 0)
		return -1;
	n->limp[0] = val;
#endif
	n->dirty = 1;
	return 0;
}

/*
 * Set N from the string STR.
 * XXX Only hexadecimal input with a leading "0x" (any case) is accepted
 * at the moment; anything else makes the function return -1.
 *
 * The digits are consumed two at a time (one byte), most significant
 * byte first, and packed into chunks from the top chunk downwards.
 * Returns 0 on success, -1 on a bad prefix or a failed resize.
 */
int
b2n_set_str(b2n_ptr n, char *str)
{
	int		pos, limb, left, nbytes, nlimbs;
	CHUNK_TYPE	acc;

	if (strncasecmp(str, "0x", 2) != 0)
		return -1;

	/*
	 * Make the digit count even.  For an odd count only the '0' is
	 * skipped, so the 'x' of the prefix becomes the leading digit.
	 */
	nbytes = strlen(str) - 2;
	if (nbytes & 1) {
		nbytes++;
		str++;
	} else
		str += 2;
	nbytes /= 2;

	nlimbs = (CHUNK_BYTES - 1 + nbytes) / CHUNK_BYTES;
	if (b2n_resize(n, nlimbs))
		return -1;
	bzero(n->limp, CHUNK_BYTES * n->chunks);

	pos = 0;
	for (limb = 0; limb < nlimbs; limb++) {
		/* Only the most significant chunk may be partially filled. */
		left = (limb == 0) ?
		    ((nbytes - 1) % CHUNK_BYTES) + 1 : CHUNK_BYTES;

		acc = 0;
		while (left-- > 0) {
			acc <<= 8;
			acc |= (hex2int(str[pos]) << 4) |
			    hex2int(str[pos + 1]);
			pos += 2;
		}
		n->limp[nlimbs - 1 - limb] = acc;
	}

	n->dirty = 1;
	return 0;
}

/* Arithmetic functions.  */

/*
 * Return the number of significant bits of N, i.e. the position of the
 * highest set bit plus one, or 0 if no bit is set.
 *
 * A cached value in n->bits is returned as-is unless n->dirty is set.
 * A freshly computed non-zero result is stored back into the cache; the
 * zero result is returned without touching the cache.
 */
u_int32_t
b2n_sigbit(b2n_ptr n)
{
	int	top, bit;

	if (!n->dirty)
		return n->bits;

	/* Find the highest non-zero chunk, stopping at chunk 0. */
	top = n->chunks - 1;
	while (top > 0 && !n->limp[top])
		top--;

	if (!n->limp[top])
		return 0;

	/* Find the highest set bit inside that chunk. */
	bit = CHUNK_MASK;
	while (bit > 0 && !(n->limp[top] & b2n_mask[bit]))
		bit--;

	n->bits = (top << CHUNK_SHIFTS) + bit + 1;
	n->dirty = 0;
	return n->bits;
}

/*
 * Addition on GF(2)[x]: D = A + B, which is simply a chunk-wise XOR.
 *
 * If either operand is zero the other one is copied.  D may be the same
 * object as A or B.  Returns 0 on success, -1 on allocation failure.
 */
int
b2n_add(b2n_ptr d, b2n_ptr a, b2n_ptr b)
{
	int	idx;
	b2n_ptr	small, large;

	if (b2n_cmp_null(a) == 0)
		return b2n_set(d, b);
	if (b2n_cmp_null(b) == 0)
		return b2n_set(d, a);

	small = B2N_MIN(a, b);
	large = B2N_MAX(a, b);

	if (b2n_resize(d, large->chunks))
		return -1;

	/* XOR the part both operands have in common. */
	idx = 0;
	while (idx < small->chunks) {
		d->limp[idx] = large->limp[idx] ^ small->limp[idx];
		idx++;
	}

	/* Accumulating into the larger operand: its upper chunks stay. */
	if (d == large) {
		d->dirty = 1;
		return 0;
	}

	/* Otherwise bring over the remaining chunks of the larger operand. */
	while (idx < large->chunks) {
		d->limp[idx] = large->limp[idx];
		idx++;
	}
	d->bits = large->bits;

	/* Conserve memory: a zero result is cut down to a single chunk. */
	if (b2n_cmp_null(d) == 0)
		return b2n_set_null(d);

	d->dirty = 1;
	return 0;
}

/*
 * Compare two polynomials.
 *
 * Returns 1 if N > M, -1 if N < M and 0 if they are equal.  The number
 * of significant bits decides first; only when it is the same for both
 * are the chunks compared, from the most significant one downwards.
 */
int
b2n_cmp(b2n_ptr n, b2n_ptr m)
{
	int	sn, sm;
	int	i;

	sn = b2n_sigbit(n);
	sm = b2n_sigbit(m);

	if (sn > sm)
		return 1;
	if (sn < sm)
		return -1;

	/* Both are zero. */
	if (sn == 0)
		return 0;

	/*
	 * Start at the chunk holding the top significant bit rather than at
	 * n->chunks - 1: N and M may have different allocation sizes, and
	 * indexing M by N's chunk count could read past the end of m->limp.
	 * Every chunk above this one is zero in both operands.
	 */
	for (i = (sn - 1) >> CHUNK_SHIFTS; i >= 0; i--)
		if (n->limp[i] > m->limp[i])
			return 1;
		else if (n->limp[i] < m->limp[i])
			return -1;

	return 0;
}

/*
 * Test A against zero.  Returns 1 as soon as a non-zero chunk is found
 * and 0 if every chunk is zero.  Chunk 0 is always inspected, whatever
 * a->chunks says.
 */
int
b2n_cmp_null(b2n_ptr a)
{
	int	idx = 0;

	for (;;) {
		if (a->limp[idx] != 0)
			return 1;
		idx++;
		if (!(idx < a->chunks))
			break;
	}

	return 0;
}

/*
 * Left shift, needed for polynomial multiplication: D = N << S.
 *
 * The shift is split into whole chunks (maj) and leftover bits (min).
 * D may be the same object as N.  Returns 0 on success, -1 if D cannot
 * be resized.
 */
int
b2n_lshift(b2n_ptr d, b2n_ptr n, unsigned int s)
{
	int             i, maj, min, chunks, add;
	/* Full width: b2n_sigbit() returns u_int32_t, do not truncate it. */
	u_int32_t       bits = b2n_sigbit(n);
	CHUNK_TYPE     *p, *op;

	if (!s)
		return b2n_set(d, n);

	maj = s >> CHUNK_SHIFTS;
	min = s & CHUNK_MASK;

	/*
	 * One extra chunk is requested when the top chunk is exactly full
	 * (or the value is zero), or when the bit shift would push bits out
	 * of the top chunk.
	 */
	add = (!(bits & CHUNK_MASK) ||
	    ((bits & CHUNK_MASK) + min) > CHUNK_MASK) ? 1 : 0;
	chunks = n->chunks;
	if (b2n_resize(d, chunks + maj + add))
		return -1;
	memmove(d->limp + maj, n->limp, CHUNK_BYTES * chunks);

	if (maj)
		bzero(d->limp, CHUNK_BYTES * maj);
	if (add)
		d->limp[d->chunks - 1] = 0;

	/*
	 * If !min there are no bit shifts and we are done.  The contents of
	 * D changed, so the cached bit count must be recomputed.
	 */
	if (!min) {
		d->dirty = 1;
		return 0;
	}

	/* Shift bit-wise from the top chunk down, pulling in the carry. */
	op = p = &d->limp[d->chunks - 1];
	for (i = d->chunks - 2; i >= maj; i--) {
		op--;
		*p = (*p << min) | (*op >> (CHUNK_BITS - min));
		p--;
	}
	*p <<= min;

	if (bits == 0) {
		/* Zero shifted is still zero; do not cache a bogus count. */
		d->dirty = 1;
	} else {
		d->dirty = 0;
		d->bits = bits + (maj << CHUNK_SHIFTS) + min;
	}
	return 0;
}

/*
 * Right shift, needed for polynomial division: D = N >> S.
 *
 * The shift is split into whole chunks (maj) and leftover bits (min) and
 * done chunk by chunk from the least significant end, which makes it
 * safe for D and N to be the same object.  If every significant bit is
 * shifted out, D becomes zero.  On success D holds n->chunks - maj
 * chunks and a valid cached bit count.
 * Returns 0 on success, -1 on allocation failure.
 */
int
b2n_rshift(b2n_ptr d, b2n_ptr n, unsigned int s)
{
	unsigned int	i, maj, min, size, newsize;
	u_int32_t	nbits;
	CHUNK_TYPE	v;

	if (!s)
		return b2n_set(d, n);

	nbits = b2n_sigbit(n);
	maj = s >> CHUNK_SHIFTS;
	min = s & CHUNK_MASK;
	size = n->chunks;

	/* Nothing significant survives the shift. */
	if (nbits <= s || maj >= size)
		return b2n_set_null(d);

	newsize = size - maj;

	/* A distinct destination must be sized before it is written to. */
	if (d != n && b2n_resize(d, newsize))
		return -1;

	/*
	 * Ascending order: chunk i is written only after the source chunks
	 * i + maj and i + maj + 1 have been read, so this works in place.
	 */
	for (i = 0; i < newsize; i++) {
		v = n->limp[i + maj] >> min;
		if (min && i + maj + 1 < size)
			v |= n->limp[i + maj + 1] << (CHUNK_BITS - min);
		d->limp[i] = v;
	}

	if (d == n) {
		/*
		 * Clear the vacated upper chunks first, so the value is right
		 * even if the shrinking resize below fails.
		 */
		if (maj)
			bzero(d->limp + newsize, CHUNK_BYTES * maj);
		if (b2n_resize(d, newsize))
			return -1;
	}

	d->bits = nbits - s;
	d->dirty = 0;
	return 0;
}

/*
 * Normal polynomial multiplication (shift and add): D = N * M.
 *
 * Zero and one operands are handled up front.  Otherwise the operands
 * are copied into temporaries, so D may be the same object as N or M,
 * and one of them is added into D for every set bit of the other while
 * being shifted left one position per bit.
 * Returns 0 on success, -1 on allocation failure.
 */
int
b2n_mul(b2n_ptr d, b2n_ptr n, b2n_ptr m)
{
	int	limb, bit;
	int	status = -1;
	b2n_t	addend, factor;

	if (!b2n_cmp_null(m) || !b2n_cmp_null(n))
		return b2n_set_null(d);
	if (b2n_sigbit(m) == 1)
		return b2n_set(d, n);
	if (b2n_sigbit(n) == 1)
		return b2n_set(d, m);

	b2n_init(addend);
	b2n_init(factor);

	if (b2n_set(addend, B2N_MAX(n, m)) != 0)
		goto out;
	if (b2n_set(factor, B2N_MIN(n, m)) != 0)
		goto out;
	if (b2n_set_null(d) != 0)
		goto out;

	for (limb = 0; limb < factor->chunks; limb++) {
		/* An all-zero chunk only moves the addend a chunk ahead. */
		if (!factor->limp[limb]) {
			if (b2n_lshift(addend, addend, CHUNK_BITS) != 0)
				goto out;
			continue;
		}

		for (bit = 0; bit < CHUNK_BITS; bit++) {
			if ((factor->limp[limb] & b2n_mask[bit]) &&
			    b2n_add(d, d, addend) != 0)
				goto out;
			if (b2n_lshift(addend, addend, 1) != 0)
				goto out;
		}
	}
	status = 0;

out:
	b2n_clear(addend);
	b2n_clear(factor);
	return status;
}

/*
 * Squaring in this polynomial ring is more efficient than normal
 * multiplication: bit i of N simply moves to bit 2*i of the result.
 *
 * The result is built in a temporary and swapped into D at the end, so
 * D may be the same object as N.
 * Returns 0 on success, -1 on allocation failure.
 */
int
b2n_square(b2n_ptr d, b2n_ptr n)
{
	int	src, bit, sig, limbs, rem, out_bit, out_chunk;
	b2n_t	sq;

	sig = b2n_sigbit(n);
	rem = sig & CHUNK_MASK;
	limbs = (sig + CHUNK_MASK) >> CHUNK_SHIFTS;

	b2n_init(sq);
	if (b2n_resize(sq,
	    2 * limbs + ((CHUNK_MASK + 2 * rem) >> CHUNK_SHIFTS)) != 0) {
		b2n_clear(sq);
		return -1;
	}

	out_chunk = 0;
	out_bit = 0;
	for (src = 0; src < limbs; src++) {
		/* A zero source chunk yields two zero result chunks. */
		if (!n->limp[src]) {
			out_chunk += 2;
			continue;
		}

		for (bit = 0; bit < CHUNK_BITS; bit++) {
			if (n->limp[src] & b2n_mask[bit])
				sq->limp[out_chunk] ^= b2n_mask[out_bit];

			/* Each source bit advances the output by two bits. */
			out_bit += 2;
			if (out_bit >= CHUNK_BITS) {
				out_chunk++;
				out_bit &= CHUNK_MASK;
			}
		}
	}

	sq->dirty = 1;
	B2N_SWAP(d, sq);
	b2n_clear(sq);
	return 0;
}

/*
 * Normal polynomial division.
 * These functions are far from optimal in speed.
 */

/*
 * Remainder only: R = N mod M.  Runs b2n_div() with a throw-away
 * quotient and passes its return value through.
 */
int
b2n_div_r(b2n_ptr r, b2n_ptr n, b2n_ptr m)
{
	b2n_t	quot;
	int	status;

	b2n_init(quot);
	status = b2n_div(quot, r, n, m);
	b2n_clear(quot);

	return status;
}

/*
 * Polynomial long division: N = Q * M + R.
 *
 * Returns 0 on success and -1 on division by zero or allocation failure.
 * The work is done on private copies; R receives the remainder through a
 * swap at the very end, so R may be the same object as N.
 */
int
b2n_div(b2n_ptr q, b2n_ptr r, b2n_ptr n, b2n_ptr m)
{
	int		i, j, len, bits;
	u_int32_t	sm, sn;
	b2n_t		rem, shift, mask;

	/* If the divisor is larger than the dividend, the quotient is 0. */
	if ((sm = b2n_sigbit(m)) > (sn = b2n_sigbit(n))) {
		if (b2n_set_null(q))
			return -1;
		return b2n_set(r, n);
	}
	if (sm == 0)
		/* Division by zero */
		return -1;
	else if (sm == 1) {
		/* Division by the one-element */
		if (b2n_set(q, n))
			return -1;
		return b2n_set_null(r);
	}
	b2n_init(rem);
	b2n_init(shift);
	b2n_init(mask);

	/*
	 * rem   - running remainder, starts out as the dividend
	 * shift - divisor aligned with the bit currently examined
	 * mask  - quotient bit that belongs to that alignment
	 */
	if (b2n_set(rem, n))
		goto fail;
	if (b2n_set(shift, m))
		goto fail;
	if (b2n_set_ui(mask, 1))
		goto fail;

	if (b2n_resize(q, (sn - sm + CHUNK_MASK) >> CHUNK_SHIFTS))
		goto fail;
	bzero(q->limp, CHUNK_BYTES * q->chunks);

	if (b2n_lshift(shift, shift, sn - sm))
		goto fail;
	if (b2n_lshift(mask, mask, sn - sm))
		goto fail;

	/* Index of the chunk holding the top significant bit */
	len = (sn - 1) >> CHUNK_SHIFTS;
	/* The first iteration is done over the relevant bits */
	bits = (CHUNK_MASK + sn) & CHUNK_MASK;
	for (i = len; i >= 0 && b2n_sigbit(rem) >= sm; i--)
		for (j = (i == len ? bits : CHUNK_MASK); j >= 0 &&
		    b2n_sigbit(rem) >= sm; j--) {
			if (rem->limp[i] & b2n_mask[j]) {
				if (b2n_sub(rem, rem, shift))
					goto fail;
				if (b2n_add(q, q, mask))
					goto fail;
			}
			if (b2n_rshift(shift, shift, 1))
				goto fail;
			if (b2n_rshift(mask, mask, 1))
				goto fail;
		}

	B2N_SWAP(r, rem);

	b2n_clear(rem);
	b2n_clear(shift);
	b2n_clear(mask);
	return 0;

fail:
	b2n_clear(rem);
	b2n_clear(shift);
	b2n_clear(mask);
	return -1;
}

/* Functions for Operation on GF(2**n) ~= GF(2)[x]/p(x).  */
/*
 * Reduce N modulo P and store the result in M.
 *
 * After taking the remainder, M is trimmed to the number of chunks its
 * significant bits need (at least one), and the bit count is cached.
 * Returns 0 on success, -1 if the division or the resize fails.
 */
int
b2n_mod(b2n_ptr m, b2n_ptr n, b2n_ptr p)
{
	int	sig, want;

	if (b2n_div_r(m, n, p) != 0)
		return -1;

	sig = b2n_sigbit(m);
	want = (CHUNK_MASK + sig) >> CHUNK_SHIFTS;
	want = want ? want : 1;

	/* Only ever shrink; a smaller-than-needed M cannot occur here. */
	if (m->chunks > want && b2n_resize(m, want) != 0)
		return -1;

	m->bits = sig;
	m->dirty = 0;
	return 0;
}

/*
 * Multiplicative inverse: calls b2n_div_mod() with a numerator of 1, so
 * that GA receives 1 / BE modulo P.
 * Returns 0 on success, -1 on failure.
 */
int
b2n_mul_inv(b2n_ptr ga, b2n_ptr be, b2n_ptr p)
{
	b2n_t	one;
	int	status = -1;

	b2n_init(one);
	if (b2n_set_ui(one, 1) == 0 && b2n_div_mod(ga, one, be, p) == 0)
		status = 0;

	b2n_clear(one);
	return status;
}

/*
 * Modular division: GA = A / BE modulo P.
 *
 * Runs an extended-Euclid style iteration: (r0, r1) start out as
 * (P, BE) and are replaced by successive remainders, while (s0, s1)
 * start out as (0, A) and are updated with the matching quotients
 * modulo P.  When r1 reaches zero, s0 is swapped into GA.
 * A zero BE has no inverse; GA is then set to zero and the status of
 * that assignment is returned.
 * Returns 0 on success, -1 on failure.
 */
int
b2n_div_mod(b2n_ptr ga, b2n_ptr a, b2n_ptr be, b2n_ptr p)
{
	b2n_t	s0, s1, s2, q, r0, r1;
	int	status = -1;

	/* There is no multiplicative inverse to Null.  */
	if (b2n_cmp_null(be) == 0)
		return b2n_set_null(ga);

	b2n_init(s0);
	b2n_init(s1);
	b2n_init(s2);
	b2n_init(r0);
	b2n_init(r1);
	b2n_init(q);

	if (b2n_set(r0, p) != 0 || b2n_set(r1, be) != 0)
		goto done;
	if (b2n_set_null(s0) != 0 || b2n_set(s1, a) != 0)
		goto done;

	while (b2n_cmp_null(r1) != 0) {
		/* (r0, r1) <- (r1, r0 mod r1), keeping the quotient in q. */
		if (b2n_div(q, r0, r0, r1) != 0)
			goto done;
		B2N_SWAP(r0, r1);

		/* s2 = s0 - q * s1 (mod p) */
		if (b2n_mul(s2, q, s1) != 0 || b2n_mod(s2, s2, p) != 0 ||
		    b2n_sub(s2, s0, s2) != 0)
			goto done;

		/* (s0, s1) <- (s1, s2) */
		B2N_SWAP(s0, s1);
		B2N_SWAP(s1, s2);
	}
	B2N_SWAP(ga, s0);
	status = 0;

done:
	b2n_clear(s0);
	b2n_clear(s1);
	b2n_clear(s2);
	b2n_clear(r0);
	b2n_clear(r1);
	b2n_clear(q);
	return status;
}

/*
 * The halftrace yields the square root if the degree of the
 * irreducible polynomial is odd.
 *
 * With m = sigbit(P) - 1, the accumulator starts as A and is, (m - 1) / 2
 * times over, squared twice (reducing modulo P after each squaring) and
 * then has A added to it.  The result is swapped into HO.
 * Returns 0 on success, -1 on failure.
 */
int
b2n_halftrace(b2n_ptr ho, b2n_ptr a, b2n_ptr p)
{
	int	degree = b2n_sigbit(p) - 1;
	int	rounds = (degree - 1) / 2;
	int	round, pass;
	int	status = -1;
	b2n_t	acc;

	b2n_init(acc);
	if (b2n_set(acc, a) != 0)
		goto done;

	for (round = rounds; round > 0; round--) {
		/* acc = acc**4 mod p, as two modular squarings */
		for (pass = 0; pass < 2; pass++) {
			if (b2n_square(acc, acc) != 0)
				goto done;
			if (b2n_mod(acc, acc, p) != 0)
				goto done;
		}

		if (b2n_add(acc, acc, a) != 0)
			goto done;
	}

	B2N_SWAP(ho, acc);
	status = 0;

done:
	b2n_clear(acc);
	return status;
}

/*
 * Solving the equation: y**2 + y = b in GF(2**m) where ip is the
 * irreducible polynomial. If m is odd, use the half trace.
 *
 * For b == 0 the solution stored in ZO is zero.  For even m, random
 * elements p are tried until the accumulated w is non-zero; the matching
 * z is then swapped into ZO.
 * Returns 0 on success, -1 on failure.
 */
int
b2n_sqrt(b2n_ptr zo, b2n_ptr b, b2n_ptr ip)
{
	int	i, m = b2n_sigbit(ip) - 1;
	b2n_t	w, p, temp, z;

	/*
	 * The trivial case writes to the caller's output.  The local z is
	 * not initialized at this point and must not be touched.
	 */
	if (!b2n_cmp_null(b))
		return b2n_set_null(zo);

	if (m & 1)
		return b2n_halftrace(zo, b, ip);

	b2n_init(z);
	b2n_init(w);
	b2n_init(p);
	b2n_init(temp);

	do {
		if (b2n_random(p, m))
			goto fail;
		if (b2n_set_null(z))
			goto fail;
		if (b2n_set(w, p))
			goto fail;

		for (i = 1; i < m; i++) {
			if (b2n_square(z, z))	/* z**2 */
				goto fail;
			if (b2n_mod(z, z, ip))
				goto fail;

			if (b2n_square(w, w))	/* w**2 */
				goto fail;
			if (b2n_mod(w, w, ip))
				goto fail;

			if (b2n_mul(temp, w, b))	/* w**2 * b */
				goto fail;
			if (b2n_mod(temp, temp, ip))
				goto fail;
			if (b2n_add(z, z, temp))	/* z**2 + w**2 * b */
				goto fail;

			if (b2n_add(w, w, p))	/* w**2 + p */
				goto fail;
		}
	} while (!b2n_cmp_null(w));

	B2N_SWAP(zo, z);

	b2n_clear(w);
	b2n_clear(p);
	b2n_clear(temp);
	b2n_clear(z);
	return 0;

fail:
	b2n_clear(w);
	b2n_clear(p);
	b2n_clear(temp);
	b2n_clear(z);
	return -1;
}

/*
 * Low-level functions to speed up scalar multiplication with
 * elliptic curves.
 * b2n_3mul multiplies an ordinary integer by 3, using b2n_nadd for
 * the addition.
 */

/*
 * Normal addition behaves as Z_{2**n} and not F_{2**n}: D0 = A0 + B0
 * with carries propagating between chunks.
 *
 * The sum is built in a temporary that is one chunk longer than the
 * larger operand, so a carry out of the top chunk is kept, and is then
 * swapped into D0; D0 may therefore be the same object as A0 or B0.
 * Returns 0 on success, -1 on allocation failure.
 */
int
b2n_nadd(b2n_ptr d0, b2n_ptr a0, b2n_ptr b0)
{
	int		i, carry;
	b2n_ptr		a, b;
	b2n_t		d;
	CHUNK_TYPE	sum, out;

	if (!b2n_cmp_null(a0))
		return b2n_set(d0, b0);

	if (!b2n_cmp_null(b0))
		return b2n_set(d0, a0);

	b2n_init(d);
	a = B2N_MAX(a0, b0);
	b = B2N_MIN(a0, b0);

	if (b2n_resize(d, a->chunks + 1)) {
		b2n_clear(d);
		return -1;
	}
	for (carry = i = 0; i < b->chunks; i++) {
		/*
		 * Detect the wrap of each partial sum on its own.  Comparing
		 * only the final sum against a->limp[i] misses the case where
		 * b->limp[i] is all ones and there is a carry in.
		 */
		sum = a->limp[i] + b->limp[i];
		out = sum + carry;
		carry = (sum < a->limp[i] || out < sum) ? 1 : 0;
		d->limp[i] = out;
	}

	for (; i < a->chunks && carry; i++) {
		d->limp[i] = a->limp[i] + carry;
		carry = (d->limp[i] < a->limp[i] ? 1 : 0);
	}

	if (i < a->chunks)
		memcpy(d->limp + i, a->limp + i,
		    CHUNK_BYTES * (a->chunks - i));

	/* A carry out of the top chunk lands in the spare chunk. */
	d->limp[a->chunks] = carry;

	d->dirty = 1;
	B2N_SWAP(d0, d);

	b2n_clear(d);
	return 0;
}

/*
 * D0 = 3 * E for an ordinary integer E, computed as (E << 1) + E with
 * the carry-propagating b2n_nadd().
 * Returns 0 on success, -1 on failure.
 */
int
b2n_3mul(b2n_ptr d0, b2n_ptr e)
{
	b2n_t	twice;
	int	status = -1;

	b2n_init(twice);
	if (b2n_lshift(twice, e, 1) == 0 && b2n_nadd(d0, twice, e) == 0)
		status = 0;

	b2n_clear(twice);
	return status;
}