/*
 * px-crypt.c
 *		Wrapper for various crypt algorithms.
 *
 * Copyright (c) 2001 Marko Kreen
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *	  notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *	  notice, this list of conditions and the following disclaimer in the
 *	  documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.	IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * $Id: px-crypt.c,v 1.2 2001/09/23 04:12:44 momjian Exp $
 */

#include <postgres.h>
#include "px.h"
#include "px-crypt.h"


#ifndef PX_SYSTEM_CRYPT

/*
 * run_crypt_des - run px_crypt_des() and copy its result into buf.
 *
 * px_crypt_des() takes no output buffer, so its result is copied into the
 * caller's buf, whose capacity is len bytes including the terminating NUL.
 *
 * Returns buf on success, or NULL if px_crypt_des() returned NULL or the
 * result (plus NUL) does not fit in len bytes.
 */
static char *
run_crypt_des(const char *psw, const char *salt,
			 char *buf, unsigned len)
{
	char	   *res;

	res = px_crypt_des(psw, salt);

	/*
	 * Treat a NULL result as failure, like the PX_SYSTEM_CRYPT variant does
	 * for crypt().  Compare with ">= len" instead of "> len - 1": with
	 * len == 0 the latter wraps to UINT_MAX and would let strcpy() write
	 * past the end of buf.
	 */
	if (res == NULL || strlen(res) >= len)
		return NULL;
	strcpy(buf, res);
	return buf;
}

/*
 * run_crypt_md5 - adapter for the "$1$" entry of the dispatch table.
 *
 * px_crypt_md5() already accepts the caller's buffer and its length, so
 * the arguments are forwarded as-is and its return value is handed back
 * unchanged.
 */
static char *
run_crypt_md5(const char *psw, const char *salt,
			 char *buf, unsigned len)
{
	return px_crypt_md5(psw, salt, buf, len);
}

/*
 * run_crypt_bf - adapter for the "$2a$" entry of the dispatch table.
 *
 * Forwards all arguments to _crypt_blowfish_rn(), which writes into the
 * caller-supplied buffer, and returns whatever it returns.
 */
static char *
run_crypt_bf(const char *psw, const char *salt,
			char *buf, unsigned len)
{
	return _crypt_blowfish_rn(psw, salt, buf, len);
}

/*
 * Dispatch table: salt prefix -> crypt implementation.
 *
 * px_crypt() scans this in order and stops at the first entry whose id is
 * a prefix of the salt, so longer prefixes must precede shorter ones.  The
 * entry with id_len == 0 matches every salt and therefore acts as the
 * fallback (routed to run_crypt_des, as is "_").  A NULL crypt pointer
 * marks a recognised but unsupported format ("$2$").  The final all-NULL
 * entry terminates the table.
 *
 * The table is never modified, and the ids are string literals, so both
 * are const-qualified.
 */
static const struct
{
	const char *id;				/* salt prefix to match */
	unsigned	id_len;			/* number of prefix chars compared */
	char	   *(*crypt) (const char *psw, const char *salt,
									  char *buf, unsigned len);
}			px_crypt_list[] =

{
	{ "$2a$", 4, run_crypt_bf },
	{ "$2$", 3, NULL },							/* N/A */
	{ "$1$", 3, run_crypt_md5 },
	{ "_", 1, run_crypt_des },
	{ "", 0, run_crypt_des },
	{ NULL, 0, NULL }
};

/*
 * px_crypt - hash psw with the algorithm selected by the prefix of salt.
 *
 * The matching px_crypt_list entry's implementation is called with the
 * caller's buf/len, and its return value is passed through.  Returns NULL
 * when the selected entry has no implementation.
 */
char *
px_crypt(const char *psw, const char *salt, char *buf, unsigned len)
{
	int			idx = 0;

	/*
	 * Advance past entries whose prefix does not match.  The scan also
	 * stops at a zero-length id (catch-all) or at the NULL terminator.
	 */
	while (px_crypt_list[idx].id != NULL
		   && px_crypt_list[idx].id_len != 0
		   && strncmp(salt, px_crypt_list[idx].id,
					  px_crypt_list[idx].id_len) != 0)
		idx++;

	if (px_crypt_list[idx].crypt == NULL)
		return NULL;

	return px_crypt_list[idx].crypt(psw, salt, buf, len);
}

#else							/* PX_SYSTEM_CRYPT */

extern char *crypt(const char *psw, const char *salt);

/*
 * px_crypt - PX_SYSTEM_CRYPT variant: delegate to the platform crypt().
 *
 * The result of crypt() is copied into buf (capacity len bytes, including
 * the NUL).  Returns buf on success, or NULL if crypt() returned NULL or
 * the result does not fit.
 */
char *
px_crypt(const char *psw, const char *salt,
		 char *buf, unsigned len)
{
	char	   *hashed = crypt(psw, salt);

	if (hashed == NULL)
		return NULL;
	if (strlen(hashed) >= len)
		return NULL;

	/* strcpy() returns its destination, i.e. buf. */
	return strcpy(buf, hashed);
}
#endif

/*
 * salt generators
 */

/*
 * One salt generator as looked up by px_gen_salt().
 */
struct generator {
	const char *name;		/* algorithm name, compared case-insensitively */
	char *(*gen)(unsigned long count, const char *input, int size,
					char *output, int output_size);
	int input_len;			/* number of random bytes passed to gen */
	int def_rounds;			/* default rounds; 0 disables defaulting and
							 * range checking in px_gen_salt() */
	int min_rounds;			/* inclusive lower bound when def_rounds != 0 */
	int max_rounds;			/* inclusive upper bound when def_rounds != 0 */
};

/*
 * Salt generators known to px_gen_salt(), searched by name.  The list is
 * terminated by an entry whose name is NULL.
 *
 * input_len is the number of random bytes px_gen_salt() draws for the
 * generator; it must not exceed the size of px_gen_salt()'s random buffer.
 */
static struct generator gen_list [] = {
	{ "des", _crypt_gensalt_traditional_rn, 2, 0, 0, 0 },
	{ "md5", _crypt_gensalt_md5_rn, 6, 0, 0, 0 },
	{ "xdes", _crypt_gensalt_extended_rn, 3, PX_XDES_ROUNDS, 1, 0xFFFFFF },
	{ "bf", _crypt_gensalt_blowfish_rn, 16, PX_BF_ROUNDS, 4, 31 },
	{ NULL, NULL, 0, 0, 0, 0 }
};

/*
 * px_gen_salt - generate a fresh salt for the named algorithm.
 *
 * salt_type: generator name (case-insensitive) to look up in gen_list.
 * buf:       output buffer; passed to the generator with an output size of
 *            PX_MAX_SALT_LEN, so it must be at least that large.
 * rounds:    iteration count; 0 selects the generator's default.  It is only
 *            defaulted and range-checked for generators with a nonzero
 *            def_rounds; otherwise it is handed to the generator as-is.
 *
 * Returns the length of the salt written to buf, or 0 on any failure:
 * unknown salt_type, rounds out of range, input_len not fitting the random
 * buffer, a short read of random bytes, or the generator returning NULL.
 */
uint
px_gen_salt(const char *salt_type, char *buf, int rounds)
{
	int i, res;
	struct generator *g;
	char *p;
	char rbuf[16];

	for (i = 0; gen_list[i].name; i++) {
		g = &gen_list[i];
		if (strcasecmp(g->name, salt_type) != 0)
			continue;

		if (g->def_rounds) {
			if (rounds == 0)
				rounds = g->def_rounds;

			if (rounds < g->min_rounds || rounds > g->max_rounds)
				return 0;
		}

		/* Never let a table entry make us write past the end of rbuf. */
		if (g->input_len < 0 || (size_t) g->input_len > sizeof(rbuf))
			return 0;

		res = px_get_random_bytes(rbuf, g->input_len);

		/*
		 * Only run the generator on a full read.  rbuf is wiped on both the
		 * success and the short-read path so that random bytes do not
		 * linger on the stack.
		 * NOTE(review): a plain memset of a buffer that is dead afterwards
		 * may be elided by the optimizer; consider an explicit_bzero-style
		 * helper if one is available in this tree.
		 */
		p = NULL;
		if (res == g->input_len)
			p = g->gen(rounds, rbuf, g->input_len, buf, PX_MAX_SALT_LEN);
		memset(rbuf, 0, sizeof(rbuf));

		return p != NULL ? strlen(p) : 0;
	}

	return 0;
}