4242import json
4343import requests
4444import syslog
45- import sqlite3
4645import passlib .hash
4746import time
4847import traceback
4948import datetime
5049import yaml
50+ import sqlite3
51+ import mysql .connector
5152import re
5253
54+ SQLite = True
5355
5456def _get_config (argv ):
5557 """
@@ -59,6 +61,7 @@ def _get_config(argv):
5961 :param argv:
6062 :return: dictionary with the parameters
6163 """
64+ global SQLite
6265 config = {}
6366 argv .pop (0 )
6467 if len (argv ) == 1 and "config_file" in argv [0 ]:
@@ -78,20 +81,17 @@ def _get_config(argv):
7881 config ["users" ] = []
7982 # SQL Connection type/default
8083 if config .get ("mysql" ) is not None :
84+ SQLite = False
8185 mysql_settings = re .match ("mysql://([^:]+):([^@]+)@([^:/]+):([0-9]+)/(.+)" , config .get ("mysql" ))
8286 config ["sql" ] = {
83- 'lite' : False ,
8487 'user' : mysql_settings .group (1 ),
8588 'password' : mysql_settings .group (2 ),
8689 'host' : mysql_settings .group (3 ),
8790 'port' : mysql_settings .group (4 ),
8891 'database' : mysql_settings .group (5 )
8992 }
9093 else :
91- config ["sql" ] = {
92- 'lite' : True ,
93- 'file' : config .get ("sqlfile" , "/etc/privacyidea/pam.sqlite" )
94- }
94+ config ["sql" ] = config .get ("sqlfile" , "/etc/privacyidea/pam.sqlite" )
9595 return config
9696
9797
@@ -246,10 +246,12 @@ def offline_refill(self, serial, password):
246246 startdb (self .sql )
247247 refilltoken = None
248248 # get all possible serial/tokens for a user
249- for row in c .execute ("SELECT refilltoken FROM refilltokens WHERE serial=?" ,
250- (serial , )):
249+ c .execute (sql_abstract ("SELECT refilltoken FROM refilltokens WHERE serial=?" ),
250+ (serial , ))
251+ for row in c .fetchall ():
251252 refilltoken = row [0 ]
252253 syslog .syslog ("Doing refill with token {0!s}" .format (refilltoken ))
254+
253255 closedb ()
254256
255257 if refilltoken :
@@ -484,7 +486,6 @@ def pam_sm_authenticate(pamh, flags, argv):
484486 return pamh .PAM_AUTHINFO_UNAVAIL
485487
486488 try :
487-
488489 if grace_time is not None :
489490 syslog .syslog (syslog .LOG_DEBUG ,
490491 "Grace period in minutes: %s " % (str (grace_time )))
@@ -494,7 +495,6 @@ def pam_sm_authenticate(pamh, flags, argv):
494495 rval = pamh .PAM_SUCCESS
495496
496497 if rval != pamh .PAM_SUCCESS :
497-
498498 # Check if user has tokens
499499 Auth .check_user_tokens (Auth .user )
500500
@@ -571,15 +571,18 @@ def check_offline_otp(sql_params, user, otp, window=10, refill=True):
571571 # get all possible serial/tokens for a user
572572 serials = []
573573 matching_serial = None
574- for row in c .execute ("SELECT serial, user FROM authitems WHERE user=?"
575- "GROUP by serial" , (user ,)):
574+
575+ c .execute (sql_abstract ("SELECT serial, user FROM authitems WHERE user=?"
576+ "GROUP by serial" ), (user ,))
577+ for row in c .fetchall ():
576578 serials .append (row [0 ])
577579
578580 for serial in serials :
579- for row in c .execute ("SELECT counter, user, otp, serial FROM authitems "
581+ c .execute ( sql_abstract ("SELECT counter, user, otp, serial FROM authitems "
580582 "WHERE user=? and serial=? ORDER by counter "
581- "LIMIT ?" ,
582- (user , serial , window )):
583+ "LIMIT ?" ),
584+ (user , serial , window ))
585+ for row in c .fetchall ():
583586 hash_value = row [2 ]
584587 if passlib .hash .pbkdf2_sha512 .verify (otp , hash_value ):
585588 res = True
@@ -589,7 +592,7 @@ def check_offline_otp(sql_params, user, otp, window=10, refill=True):
589592
590593 # We found a matching password, so we remove the old entries
591594 if res :
592- c .execute ("DELETE from authitems WHERE counter <= ? and serial = ?" ,
595+ c .execute (sql_abstract ( "DELETE from authitems WHERE counter <= ? and serial = ?" ) ,
593596 (matching_counter , matching_serial ))
594597
595598 closedb ()
@@ -624,17 +627,17 @@ def save_auth_item(sql_params, user, serial, tokentype, authitem):
624627 tokenowner = offline .get ("username" )
625628 for counter , otphash in offline .get ("response" ).items ():
626629 # Insert the OTP hash
627- c .execute ("INSERT INTO authitems (counter, user, serial,"
628- "tokenowner, otp) VALUES (?,?,?,?,?)" ,
630+ c .execute (sql_abstract ( "INSERT INTO authitems (counter, user, serial,"
631+ "tokenowner, otp) VALUES (?,?,?,?,?)" ) ,
629632 (counter , user , serial , tokenowner , otphash ))
630633
631634 refilltoken = offline .get ("refilltoken" )
632635 # delete old refilltoken
633636 try :
634- c .execute (' DELETE FROM refilltokens WHERE serial=?' , (serial ,))
637+ c .execute (sql_abstract ( " DELETE FROM refilltokens WHERE serial=?" ) , (serial ,))
635638 except sqlite3 .OperationalError :
636639 pass
637- c .execute ("INSERT INTO refilltokens (serial, refilltoken) VALUES (?,?)" ,
640+ c .execute (sql_abstract ( "INSERT INTO refilltokens (serial, refilltoken) VALUES (?,?)" ) ,
638641 (serial , refilltoken ))
639642
640643 closedb ()
@@ -660,11 +663,12 @@ def check_last_history(sql_params, user, rhost, grace_time, window=10):
660663 res = False
661664 events = []
662665
663- for row in c .execute ("SELECT user, rhost, serial, last_success, last_error "
666+ c .execute ( sql_abstract ("SELECT user, rhost, serial, last_success, last_error "
664667 "FROM history "
665668 "WHERE user=? AND rhost=? ORDER by last_success "
666- "LIMIT ?" ,
667- (user , rhost , window )):
669+ "LIMIT ?" ),
670+ (user , rhost , window ))
671+ for row in c .fetchall ():
668672 events .append (row )
669673
670674 if len (events )> 0 :
@@ -714,37 +718,35 @@ def save_history_item(sql_params, user, rhost, serial, success):
714718 __name__ , ("success" if success else "error" )))
715719 if success :
716720 # Insert the Event
717- c .execute ("INSERT OR REPLACE INTO history (user, rhost, serial,"
718- "error_counter, last_success) VALUES (?,?,?,?,?)" ,
721+ c .execute (sql_abstract ( " REPLACE INTO history (user, rhost, serial,"
722+ "error_counter, last_success) VALUES (?,?,?,?,?)" ) ,
719723 (user , rhost , serial , 0 , datetime .datetime .now ()))
720724 else :
721725 # Insert the Event
722- c .execute ("UPDATE history SET error_counter = error_counter + 1, "
726+ c .execute (sql_abstract ( "UPDATE history SET error_counter = error_counter + 1, "
723727 " serial = ? , last_error = ? "
724- " WHERE user = ? AND rhost = ? " ,
728+ " WHERE user = ? AND rhost = ? " ) ,
725729 (serial , datetime .datetime .now (), user , rhost ))
726730
727731 syslog .syslog (syslog .LOG_DEBUG ,"Rows affected : %d " % c .rowcount )
728732 if c .rowcount == 0 :
729- c .execute ("INSERT INTO history (user, rhost, serial,"
730- "error_counter, last_error) VALUES (?,?,?,?,?)" ,
733+ c .execute (sql_abstract ( "INSERT INTO history (user, rhost, serial,"
734+ "error_counter, last_error) VALUES (?,?,?,?,?)" ) ,
731735 (user , rhost , serial , 1 , datetime .datetime .now ()))
732736
733737 closedb ()
734738
735-
736739# Start connection and create cursor
737740def startdb (sql_params ):
738741 global conn , c
739742 # Create connection
740- if sql_params ["lite" ]:
741- conn = sqlite3 .connect (sql_params ["file" ], detect_types = sqlite3 .PARSE_DECLTYPES )
742- # Create a cursor object
743- c = conn .cursor ()
743+ if SQLite :
744+ conn = sqlite3 .connect (sql_params , detect_types = sqlite3 .PARSE_DECLTYPES )
744745 else :
745- print ("Mysql" )
746- # mysql.connector.connect(**connection_config_dict)
746+ conn = mysql .connector .connect (** sql_params )
747747
748+ # Create a cursor object
749+ c = conn .cursor ()
748750 # Create table if does not exist
749751 _create_table ()
750752
@@ -767,7 +769,26 @@ def _create_table():
767769 c .execute ("CREATE TABLE IF NOT EXISTS refilltokens (serial text, refilltoken text)" )
768770 # create history table
769771 c .execute ("CREATE TABLE IF NOT EXISTS history "
770- "(user text , rhost text , serial text, error_counter int, "
772+ "(user varchar(50) , rhost varchar(50) , serial text, error_counter int, "
771773 "last_success timestamp, last_error timestamp)" )
772- c .execute ("CREATE UNIQUE INDEX idx_user "
773- "ON history (user, rhost);" )
774+ try :
775+ # create history table
776+ c .execute ("CREATE TABLE IF NOT EXISTS history "
777+ "(user text, rhost text, serial text, error_counter int, "
778+ "last_success timestamp, last_error timestamp)" )
779+ c .execute ("CREATE UNIQUE INDEX idx_user "
780+ "ON history (user, rhost);" )
781+ except mysql .connector .Error as err :
782+ if err .errno == mysql .connector .errorcode .ER_DUP_KEYNAME :
783+ pass
784+ else :
785+ raise
786+ except sqlite3 .OperationalError :
787+ pass
788+
789+ # Convert an SQLite statement to MySQL
790+ def sql_abstract (sql_statement ):
791+ if SQLite :
792+ return sql_statement
793+ else :
794+ return sql_statement .replace ('?' ,'%s' )
0 commit comments