/*
 * PAM module for MySQL 
 *
 * Original Version written by: Gunay ARSLAN <arslan@gunes.medyatext.com.tr>
 * This version by: James O'Kane <jo2y@midnightlinux.com>
 * Modifications by Steve Brown, <steve@electronic.co.uk>
 *
 */

#define _GNU_SOURCE

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <syslog.h>
#include <stdarg.h>
#include <alloca.h>
#include <string.h>

#include <mysql/mysql.h>

/*
 * Define the PAM_SM_* macros for the externally accessible functions
 * implemented in this file before including the modules header: they
 * instruct that header to declare the matching prototypes.  These
 * definitions are required for static modules and strongly encouraged
 * in general.
 */

#define PAM_SM_AUTH
#define PAM_SM_ACCOUNT
#define PAM_SM_SESSION
#define PAM_SM_PASSWORD

#define PAM_MODULE_NAME  "pam_mysql"
#define PLEASE_ENTER_PASSWORD "Password:"
/* #define DEBUG */

#include <security/pam_modules.h>
#include <security/pam_misc.h>

struct optionstruct {
	char host[257];
	char where[257];
	char database[17];
	char dbuser[17];
	char dbpasswd[17];
	char table[17];
	char usercolumn[17];
	char passwdcolumn[17];
	int crypt;
};				/*Max length for most MySQL fields is 16 */


/* Global Variables */

/*
 * Handle returned by mysql_real_connect(); NULL while no connection is
 * open.  Set by db_connect() and cleared by db_close().
 */
MYSQL *mysql_auth = NULL;

/*
 * Default configuration, listed in optionstruct field order.  Individual
 * values are replaced by "name=value" module arguments.
 */
struct optionstruct options =
{
	"localhost",	/* host */
	"",		/* where: no extra condition */
	"mysql",	/* database */
	"nobody",	/* dbuser */
	"",		/* dbpasswd */
	"user",		/* table */
	"User",		/* usercolumn */
	"Password",	/* passwdcolumn */
	0		/* crypt: plaintext comparison */
};

/* Prototypes */

/* Run the application's conversation function with nargs messages. */
int converse (pam_handle_t * pamh,
	      int nargs,
	      struct pam_message **message,
	      struct pam_response **response);

/* NOTE(review): no definition of _set_auth_tok is visible in this file --
 * confirm whether this prototype is still needed. */
int _set_auth_tok (pam_handle_t * pamh,
		   int flags, int argc,
		   const char **argv);

/* Open the database connection described by the global options. */
int db_connect (MYSQL * auth_sql_server);
/* Close the connection held in mysql_auth, if any. */
void db_close( void );
/* Prompt for a password and store it as PAM_AUTHTOK. */
int askForPassword(pam_handle_t *pamh);

/*
 * Close the connection held in mysql_auth and forget the handle.
 * Does nothing when no connection is open, so it may be called repeatedly.
 */
void db_close ( void )
{
	if ( mysql_auth != NULL )
	{
		mysql_close(mysql_auth);
		mysql_auth = NULL;
	}
}

/* MySQL access functions */
/*
 * Connect to the MySQL server described by the global options and select
 * the configured database.
 *
 * auth_sql_server: caller-owned handle storage to initialise and connect.
 *
 * Returns PAM_SUCCESS when a connection is (or already was) open and the
 * database is selected, PAM_AUTH_ERR otherwise.  On success mysql_auth
 * holds the live handle and the caller must eventually call db_close().
 * On failure the handle is released here and mysql_auth is left NULL.
 */
int db_connect (MYSQL * auth_sql_server)
{
	int retvalue = PAM_AUTH_ERR;

#ifdef DEBUG
	D (("called."));
#endif
	if ( mysql_auth != NULL )
		return PAM_SUCCESS;	/* already connected */

	if ( mysql_init (auth_sql_server) == NULL )
	{
		syslog(LOG_ERR, "pam_mysql: unable to initialise MySQL handle");
		return PAM_AUTH_ERR;
	}

	mysql_auth = mysql_real_connect (auth_sql_server,
					 options.host,
					 options.dbuser,
					 options.dbpasswd,
					 options.database, 0, NULL, 0);

	if ( mysql_auth != NULL
	     && mysql_select_db (auth_sql_server, options.database) == 0 )
	{
		retvalue = PAM_SUCCESS;
	}
	else
	{
		/* Log first: the error text lives inside the handle. */
		syslog(LOG_INFO, "pam_mysql: MySQL err %s\n", mysql_error(auth_sql_server));
		/*
		 * Release whatever mysql_init()/mysql_real_connect() set up.
		 * Without this a failed connect leaves mysql_auth NULL, so the
		 * caller's db_close() would never free the handle.
		 */
		mysql_close (auth_sql_server);
		mysql_auth = NULL;
	}
#ifdef DEBUG
	D (("returning %i.",retvalue));
#endif
	return retvalue;
}

static int db_checkpasswd (MYSQL * auth_sql_server, const char *user, const char *passwd)
{
	char *sql;
	char *escapeUser;	/* User provided stuff MUST be escaped */
	char *encryptedPass;
	char *salt;
#ifdef DEBUG
	char *sqllog;
#endif
	MYSQL_RES *result;
	MYSQL_ROW row;
	int retvalue = PAM_AUTH_ERR;
#ifdef DEBUG
	D (("called."));
#endif

	sql = (char *) malloc (110 + strlen(user) + strlen(options.where));
	if ( !sql )
		return PAM_BUF_ERR;

	escapeUser = malloc(sizeof(char) * (strlen(user) * 2) + 1);
	if ( escapeUser == NULL )
	{
		syslog(LOG_ERR, "pam_mysql: Insufficient memory to allocate user escape strings");
		syslog(LOG_ERR, "pam_mysql: UNABLE TO AUTHENTICATE");
		return PAM_BUF_ERR;
	}
#ifdef HAVE_MYSQL_REAL_ESCAPE_STRING
	mysql_real_escape_string(auth_sql_server, escapeUser, user, strlen(user));
#else
	mysql_escape_string(escapeUser, user, strlen(user));
#endif	   
	sprintf(sql, "SELECT %s FROM %s WHERE %s='%s'",
		options.passwdcolumn,options.table,
		options.usercolumn,escapeUser);
	if ( strlen(options.where) > 0 )
	{
		sprintf(sql, "%s AND %s", sql, options.where);
	}
#ifdef DEBUG
	syslog(LOG_ERR, "pam_mysql: where clause = %s", options.where);
#endif

#ifdef DEBUG
/* These lines are commented even though its inside a ifdef DEBUG
  Why? Because it exposes user names and passwords into syslog and is
  a MASSIVE security hole. Uncomment only if you REALLY need to see
  the query. Remember, usernames and attmpted passwords will be
  exposed into syslog though!
*/
/*
	sqllog = malloc(sizeof(char) * strlen(sql) + strlen("pam_mysql: ") + 1);
	sprintf(sqllog, "pam_mysql: %s", sql);
	syslog(LOG_ERR,sqllog);
	free(sqllog);
*/
	D ((sql));
#endif
	mysql_query (auth_sql_server, sql);
	free (sql);
	result = mysql_store_result (auth_sql_server);
	if (!result) {
		syslog(LOG_ERR, mysql_error (auth_sql_server));
		return PAM_AUTH_ERR;
	}
	if (mysql_num_rows (result) == 1) {
		encryptedPass = malloc(sizeof(char) * (strlen(passwd) + 31) + 1);
		/* Grab the password from RESULT_SET. */
		row = mysql_fetch_row(result);
		if (row == NULL) {
			syslog(LOG_ERR, mysql_error (auth_sql_server));
			return PAM_AUTH_ERR;
		}
		switch(options.crypt) {
			/* PLAIN */
			case 0: strcpy(encryptedPass, passwd);
				break;
			/* ENCRYPT */
			case 1: if (strlen(row[0]) < 12) { /* strlen() < 12 isn't a valid encrypted password. */
					syslog(LOG_ERR, "pam_mysql: select returned an invalid encrypted password" );
					break;
				}
				salt = malloc(sizeof(char) * strlen(row[0]) + 1);
				if (strncmp("$1$", row[0], 3) == 0) { /* A MD5 salt starts with "$1$" and is 12 bytes long. */
					strncpy(salt, row[0], 12);
					salt[12] = '\0';
				} else { /* If it's not MD5, assume DES and a 2 bytes salt. */
					strncpy(salt, row[0], 2);
					salt[2] = '\0';
				}
				strcpy(encryptedPass, crypt(passwd, salt));
				free (salt);
				break;
			/* PASSWORD */
			case 2: make_scrambled_password(encryptedPass, passwd);
				break;
		}
		if (!strcmp(row[0], encryptedPass)) {
			retvalue = PAM_SUCCESS;
		}
		free (encryptedPass);
	} else {
		syslog(LOG_ERR, "pam_mysql: select returned more than one result");
	}
#ifdef DEBUG
	D (("returning %i.",retvalue));
#endif
	return retvalue;
}


/* Global PAM functions stolen from other modules */


/*
 * Fetch the application's conversation structure (PAM_CONV) and call its
 * conversation function with the given messages.
 *
 * Returns the status of pam_get_item() if the structure cannot be
 * obtained, otherwise the status returned by the conversation function.
 * Failures are logged; PAM_CONV_AGAIN is passed back without logging.
 */
int converse(pam_handle_t *pamh, int nargs
		    , struct pam_message **message
		    , struct pam_response **response)
{
    struct pam_conv *conv;
    int status = pam_get_item( pamh, PAM_CONV, (const void **) &conv );

    if ( status != PAM_SUCCESS )
    {
	syslog(LOG_ERR, "pam_mysql: couldn't obtain coversation function [%s]"
		 , pam_strerror(pamh, status));
	return status;                  /* propagate error status */
    }

    status = conv->conv(nargs, ( const struct pam_message ** ) message
			, response, conv->appdata_ptr);
    if ( !(status == PAM_SUCCESS || status == PAM_CONV_AGAIN) )
    {
	syslog(LOG_DEBUG, "pam_mysql: conversation failure [%s]"
		 , pam_strerror(pamh, status));
    }
    return status;
}

/*
 * Prompt the user (echo off) for a password through the conversation
 * function and store the answer as PAM_AUTHTOK.
 *
 * Returns the status of pam_set_item() when an answer was received,
 * PAM_INCOMPLETE when the conversation asks to be retried
 * (PAM_CONV_AGAIN), and PAM_AUTHINFO_UNAVAIL when the conversation fails
 * or yields no answer.
 */
int askForPassword(pam_handle_t *pamh)
{
	struct pam_message msg;
	struct pam_message *mesg[1];
	struct pam_response *resp = NULL;
	int retval;

	/* The prompt is a constant; no heap copy is needed. */
	msg.msg_style = PAM_PROMPT_ECHO_OFF;
	msg.msg = PLEASE_ENTER_PASSWORD;
	mesg[0] = &msg;

	retval = converse(pamh, 1, mesg, &resp);
	if (retval != PAM_SUCCESS)
	{
		if (resp != NULL)
			_pam_drop_reply(resp, 1);
		return ((retval == PAM_CONV_AGAIN)
			? PAM_INCOMPLETE : PAM_AUTHINFO_UNAVAIL);
	}

	/* A successful conversation may still hand back no answer. */
	if (resp == NULL)
		return PAM_AUTHINFO_UNAVAIL;
	if (resp->resp == NULL)
	{
		_pam_drop_reply(resp, 1);
		return PAM_AUTHINFO_UNAVAIL;
	}

	/* we have a password so set AUTHTOK
	 */
	retval = pam_set_item(pamh, PAM_AUTHTOK, resp->resp);

	/*
	 * NOTE(review): this relies on pam_set_item() storing its own copy
	 * of the token -- confirm for the PAM implementation in use.  Given
	 * that, wipe and free the reply so the cleartext password does not
	 * linger on the heap.
	 */
	_pam_drop_reply(resp, 1);
	return retval;
}


/* PAM Authentication functions */

PAM_EXTERN int pam_sm_authenticate (pam_handle_t * pamh,
				    int flags,
				    int argc,
				    const char **argv)
{
	int retval, i;
	const char *user;
	char *passwd = NULL;
	MYSQL auth_sql_server;

#ifdef DEBUG
	D (("called."));
#endif

/* Parse arguments taken from pam_listfile.c */
	for (i = 0; i < argc; i++) {
		char *junk;
		char mybuf[256], myval[256];
		char *mj;

		junk = (char *) malloc (strlen (argv[i]) + 1);
		if (junk == NULL) {
#ifdef DEBUG
			D (("returning PAM_BUF_ERR."));
			return PAM_BUF_ERR;
#endif
		}
		strcpy (junk, argv[i]);
		if ((strchr (junk, (int) '=') != NULL)) {
			strncpy (mybuf, strtok (junk, "="), 255);
			strncpy (myval, strtok (NULL, "="), 255);
			free (junk);
			if (!strcasecmp ("host", mybuf)) {
				strncpy (options.host, myval, 255);
				D (("host changed."));
			} else if (!strcasecmp ("where", mybuf)) {
				while ( (mj = strtok(NULL,"=")) != NULL )
				{
					strcat(myval, "=");
					strcat(myval, mj);
				}
				strncpy (options.where, myval, 256);
				D (("where changed."));
#ifdef DEBUG
				syslog(LOG_ERR, "pam_mysql: where now is %s", options.where);
#endif
			} else if (!strcasecmp ("db", mybuf)) {
				strncpy (options.database, myval, 16);
				D (("database changed."));
			} else if (!strcasecmp ("user", mybuf)) {
				strncpy (options.dbuser, myval, 16);
				D (("dbuser changed."));
			} else if (!strcasecmp ("passwd", mybuf)) {
				strncpy (options.dbpasswd, myval, 16);
				D (("dbpasswd changed."));
			} else if (!strcasecmp ("table", mybuf)) {
				strncpy (options.table, myval, 16);
				D (("table changed."));
			} else if (!strcasecmp ("usercolumn", mybuf)) {
				strncpy (options.usercolumn, myval, 16);
				D (("usercolumn changed."));
			} else if (!strcasecmp ("passwdcolumn", mybuf)) {
				strncpy (options.passwdcolumn, myval, 16);
				D (("passwdcolumn changed."));
			} else if (!strcasecmp ("crypt", mybuf)) {
				if ((!strcmp (myval, "1")) ||
				    (!strcasecmp (myval, "Y"))) {
					options.crypt = 1;
				} else if ((!strcmp(myval, "2")) ||
				    (!strcasecmp(myval, "mysql"))) {
				    	options.crypt = 2;
				}
				else {
					options.crypt = 0;
				}
#ifdef DEBUG
				D (("crypt changed."));
#endif
			} else {
#ifdef DEBUG
				D (("Unknown option: %s=%s", mybuf, myval));
#endif
			}
		} else {
			char *error = (char *) malloc (20 + strlen(junk));
			if ( error )
			{
				sprintf (error, "Unknown option: %s", junk);
#ifdef DEBUG
				D ((error));
#endif
			}
		}
	}/* for loop */

	/* Get User */

	retval = pam_get_user (pamh, &user, NULL);
	if (retval != PAM_SUCCESS || user == NULL) {
		syslog (LOG_ERR, "pam_mysql: no user specified");
#ifdef DEBUG
		D (("returning."));
#endif
		return PAM_USER_UNKNOWN;
	} 
	
	retval = pam_get_item(pamh, PAM_AUTHTOK, (const void **) &passwd);
	if ( passwd == NULL )
	{
		askForPassword(pamh);
	}
	retval = pam_get_item(pamh, PAM_AUTHTOK, (const void **)&passwd);

	if ( passwd == NULL )
		return PAM_AUTHINFO_UNAVAIL;
		
	if ((retval = db_connect (&auth_sql_server)) != PAM_SUCCESS) {
		db_close();
		D (("returning %i after db_connect.",retval));
		return retval;
	}
	if ((retval = db_checkpasswd (&auth_sql_server, user, passwd)) != PAM_SUCCESS) {
		D (("returning %i after db_checkpasswd.",retval));
		db_close();
		return retval;
	}
#ifdef DEBUG
	D (("returning %i.",retval));
#endif
	db_close();
	return retval;

}/* pam_sm_authenticate */


/* --- account management functions --- */
/*
 * Account management hook: not implemented.  Always reports PAM_SUCCESS;
 * a note is logged in DEBUG builds only.  All arguments are ignored.
 */
PAM_EXTERN int pam_sm_acct_mgmt (pam_handle_t * pamh, int flags, int argc
				 ,const char **argv)
{
	(void) pamh;
	(void) flags;
	(void) argc;
	(void) argv;

#ifdef DEBUG
	syslog (LOG_INFO, "pam_mysql: acct_mgmt called but not implemented. Dont panic though :)");
#endif
	return PAM_SUCCESS;
}

/*
 * Credential-setting hook: not implemented.  Always reports PAM_SUCCESS;
 * a note is logged in DEBUG builds only.  All arguments are ignored.
 */
PAM_EXTERN
int pam_sm_setcred(pam_handle_t *pamh,int flags,int argc
		   ,const char **argv)
{
	(void) pamh;
	(void) flags;
	(void) argc;
	(void) argv;

#ifdef DEBUG
	syslog(LOG_INFO, "pam_mysql: setcred called but not implemented.");
#endif
	return PAM_SUCCESS;
}

/* --- password management --- */

/*
 * Password-change hook: not implemented.  The password is NOT changed;
 * the call is logged unconditionally and PAM_SUCCESS is returned.
 * NOTE(review): reporting success for a change that did not happen may
 * mislead callers -- confirm this is the intended status.
 */
PAM_EXTERN
int pam_sm_chauthtok(pam_handle_t *pamh,int flags,int argc
		     ,const char **argv)
{
	(void) pamh;
	(void) flags;
	(void) argc;
	(void) argv;

	syslog(LOG_INFO, "pam_mysql: chauthtok called but not implemented. Password NOT CHANGED!");
	return PAM_SUCCESS;
}

/* --- session management --- */

/*
 * Session-open hook: not implemented.  Always reports PAM_SUCCESS;
 * a note is logged in DEBUG builds only.  All arguments are ignored.
 */
PAM_EXTERN
int pam_sm_open_session(pam_handle_t *pamh,int flags,int argc
			,const char **argv)
{
	(void) pamh;
	(void) flags;
	(void) argc;
	(void) argv;

#ifdef DEBUG
	syslog(LOG_INFO, "pam_mysql: open_session called but not implemented.");
#endif
	return PAM_SUCCESS;
}

/*
 * Session-close hook: not implemented.  Always reports PAM_SUCCESS and
 * ignores all arguments.  The note is logged in DEBUG builds only, as
 * pam_sm_open_session() does, so that normal builds do not write a
 * syslog line for every session that ends.
 */
PAM_EXTERN
int pam_sm_close_session(pam_handle_t *pamh,int flags,int argc
			 ,const char **argv)
{
	(void) pamh;
	(void) flags;
	(void) argc;
	(void) argv;

#ifdef DEBUG
	syslog(LOG_INFO, "pam_mysql: close_session called but not implemented.");
#endif
	return PAM_SUCCESS;
}

/* end of module definition */

#ifdef PAM_STATIC

/* static module data */

/*
 * Entry-point table used when the module is linked statically.  The
 * symbol and the name string identify THIS module (PAM_MODULE_NAME is
 * "pam_mysql"); they previously carried pam_permit's identifiers.
 * NOTE(review): a same-named symbol in a static build that also links
 * pam_permit would presumably clash -- confirm against the static
 * module loader in use.
 */
struct pam_module _pam_mysql_modstruct = {
    PAM_MODULE_NAME,
    pam_sm_authenticate,
    pam_sm_setcred,
    pam_sm_acct_mgmt,
    pam_sm_open_session,
    pam_sm_close_session,
    pam_sm_chauthtok
};

#endif
