Updated:

3 minute read

개요

pymysql로 스레드마다 별도의 커넥션을 관리하는 MySQL 래퍼 클래스와, 이를 검증하는 unittest 예제입니다.

예제

  • 코드
    •  import pymysql
       import time
       import threading
       import unittest
              
              
       class MySQL:
           """Per-thread pymysql connection holder.

           Every thread that calls ``get_connection()`` gets its own
           ``pymysql`` connection, keyed by ``threading.get_ident()`` and
           created lazily from the settings given to ``__init__``. PyMySQL
           declares ``threadsafety = 1`` (connections must not be shared
           between threads), which is why one connection is kept per thread.

           NOTE: Python may reuse a thread identifier after a thread exits, so
           a new thread can inherit a finished thread's connection (including
           any database it selected with ``select_db``). Connections of
           finished threads are only closed by ``__del__``.
           """

           def __init__(self,
                        host,
                        port,
                        user,
                        password,
                        database=None,
                        charset='utf8',
                        autocommit=True):
               """Store connection settings; no connection is opened here.

               Args:
                   host, port, user, password, database, charset, autocommit:
                       passed unchanged to ``pymysql.connect`` for every
                       per-thread connection.
               """
               self.host = host
               self.port = port
               self.user = user
               self.password = password
               self.database = database
               self.charset = charset
               self.autocommit = autocommit

               # thread ident -> pymysql connection opened by that thread
               self.connection_infos = {}
               # guards mutation/lookup of connection_infos
               self.lock = threading.Lock()

           def __del__(self):
               """Best-effort close of every connection this object opened."""
               # __init__ may have raised before this attribute was assigned.
               connections = getattr(self, 'connection_infos', None)
               if not connections:
                   return

               for connection in list(connections.values()):
                   try:
                       connection.close()
                   except Exception:
                       # Destructor: one failing close (e.g. pymysql raises on
                       # an already-closed connection) must not prevent the
                       # remaining connections from being closed.
                       pass

               connections.clear()

           def get_connection(self):
               """Return the calling thread's connection, creating it if needed.

               The connection is pinged before it is returned; pymysql's
               ``ping()`` reconnects by default when the link was dropped.
               """
               ident = threading.get_ident()

               with self.lock:
                   connection = self.connection_infos.get(ident)

               if connection is None:
                   # Only this thread creates the entry for its own ident, so
                   # connecting outside the lock cannot race with another
                   # writer of the same key; it just avoids serialising every
                   # thread's connect behind one lock.
                   connection = pymysql.connect(
                       host=self.host,
                       port=self.port,
                       user=self.user,
                       password=self.password,
                       database=self.database,
                       charset=self.charset,
                       autocommit=self.autocommit)
                   with self.lock:
                       self.connection_infos[ident] = connection

               # Network round trip; kept outside the lock for the same reason.
               connection.ping()

               return connection

           def get_cursor(self, cursor=None):
               """Return a new cursor on the calling thread's connection.

               Args:
                   cursor: optional pymysql cursor class (e.g. ``DictCursor``);
                       ``None`` lets pymysql use its default cursor class.
               """
               return self.get_connection().cursor(cursor)
              
              
       class TestMySQL(unittest.TestCase):
           """Integration tests for the ``MySQL`` helper.

           NOTE(review): requires a reachable MySQL server at 127.0.0.1:3306
           with user/password root/root. A scratch database named
           ``temp_<unix seconds>`` is created in setUpClass and dropped in
           tearDownClass.
           """

           # Connection settings shared by every MySQL instance these tests create.
           CONNECTION_OPTIONS = {
               'host': '127.0.0.1',
               'port': 3306,
               'user': 'root',
               'password': 'root',
           }

           @classmethod
           def setUpClass(cls):
               """Create the scratch database and an empty ``temp`` table."""
               cls.database_name = 'temp_' + str(int(time.time()))
               cls.table_name = 'temp'

               cls.mysql = MySQL(**cls.CONNECTION_OPTIONS)

               with cls.mysql.get_cursor() as cursor:
                   cursor.execute('CREATE DATABASE IF NOT EXISTS ' +
                                  cls.database_name + ';')
                   # select_db() switches the database of the connection this
                   # cursor belongs to, so it applies to the statements below.
                   cls.mysql.get_connection().select_db(cls.database_name)
                   cursor.execute('DROP TABLE IF EXISTS ' + cls.table_name + ';')
                   cursor.execute('CREATE TABLE IF NOT EXISTS ' +
                                  cls.table_name +
                                  '(id INT(10), name VARCHAR(50));')

           @classmethod
           def tearDownClass(cls):
               """Drop the scratch database created in setUpClass."""
               with cls.mysql.get_cursor() as cursor:
                   cursor.execute('DROP DATABASE IF EXISTS ' +
                                  cls.database_name + ';')

           def setUp(self):
               """Give each test a fresh cursor on the main thread's connection."""
               self.cursor = self.mysql.get_cursor()

           def tearDown(self):
               """Empty the table and close the per-test cursor."""
               # A test that fails between begin() and commit()/rollback()
               # would otherwise leave the shared connection inside an open
               # transaction and leak its uncommitted state into the next test.
               self.mysql.get_connection().rollback()
               try:
                   self.cursor.execute('DELETE FROM ' + self.table_name + ';')
               finally:
                   self.cursor.close()

           def _assert_row_count(self, cursor, expected):
               """Assert ``SELECT *`` on the table sees ``expected`` rows via ``cursor``."""
               query = 'SELECT * FROM ' + self.table_name + ';'
               self.assertEqual(cursor.execute(query), expected)
               self.assertEqual(cursor.rowcount, expected)

           def _insert_one(self, cursor, row_id, name):
               """Insert one row via ``cursor`` and assert one row was affected."""
               query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(%s, %s);'
               self.assertEqual(cursor.execute(query, (row_id, name)), 1)

           def worker(self, mysql, count):
               """Insert ``count`` rows using the calling thread's connection.

               Meant to run in a worker thread: ``mysql.get_connection()`` gives
               each thread its own connection, and the shared ``self.mysql`` is
               created without a database, so it is selected here.
               """
               mysql.get_connection().select_db(self.database_name)
               query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(1, \'a\');'
               with mysql.get_cursor() as cursor:
                   for _ in range(count):
                       self.assertEqual(cursor.execute(query), 1)

           def test_multi_thread(self):
               """Concurrent inserts from many threads all land in the table."""
               count = 100
               thread_count = 100
               errors = []

               def run():
                   # Exceptions (including failed assertions) raised in a
                   # thread are not reported to unittest; collect them so the
                   # test fails with the real cause, not just a count mismatch.
                   try:
                       self.worker(self.mysql, count)
                   except Exception as exc:
                       errors.append(exc)

               threads = [threading.Thread(target=run) for _ in range(thread_count)]
               for thread in threads:
                   thread.start()
               for thread in threads:
                   thread.join()

               self.assertEqual(errors, [])

               query = 'SELECT * FROM ' + self.table_name + ' ORDER BY id ASC;'
               self.assertEqual(self.cursor.execute(query), thread_count * count)

           def test_insert(self):
               """execute()/executemany() report the number of inserted rows."""
               query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(1, \'a\');'
               self.assertEqual(self.cursor.execute(query), 1)

               query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(2, \'b\'), (3, \'c\');'
               self.assertEqual(self.cursor.execute(query), 2)

               query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(%s, %s);'
               self.assertEqual(self.cursor.execute(query, (4, 'd')), 1)

               query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(%s, %s);'
               args = [[5, 'e'], [6, 'f'], [7, 'g']]
               self.assertEqual(self.cursor.executemany(query, args), 3)

           def test_select(self):
               """Rows from test_insert read back via fetchone/fetchmany/fetchall."""
               self.test_insert()

               query = 'SELECT * FROM ' + self.table_name + ' ORDER BY id ASC;'
               self.assertEqual(self.cursor.execute(query), 7)

               self.assertEqual(self.cursor.description[0][0], 'id')
               self.assertEqual(self.cursor.description[1][0], 'name')
               self.assertEqual(self.cursor.rowcount, 7)

               row_id, name = self.cursor.fetchone()
               self.assertEqual(row_id, 1)
               self.assertEqual(name, 'a')

               expected = {2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g'}

               some_rows = self.cursor.fetchmany(2)
               self.assertEqual(len(some_rows), 2)
               for row_id, name in some_rows:
                   self.assertEqual(name, expected[row_id])

               remaining_rows = self.cursor.fetchall()
               self.assertEqual(len(remaining_rows), 4)
               for row_id, name in remaining_rows:
                   self.assertEqual(name, expected[row_id])

           def test_transaction_commit(self):
               """Rows written inside a transaction become visible to others on commit()."""
               # Keep this object referenced for the whole test: its __del__
               # closes the connection that cursor_another depends on.
               mysql_another = MySQL(database=self.database_name,
                                     **self.CONNECTION_OPTIONS)

               with mysql_another.get_cursor() as cursor_another:
                   self.mysql.get_connection().begin()

                   self._insert_one(self.cursor, 1, 'a')
                   self._assert_row_count(self.cursor, 1)
                   self._assert_row_count(cursor_another, 0)

                   self.mysql.get_connection().commit()

                   self._assert_row_count(self.cursor, 1)
                   self._assert_row_count(cursor_another, 1)

                   # Back in autocommit mode: the next insert is visible at once.
                   self._insert_one(self.cursor, 2, 'b')
                   self._assert_row_count(self.cursor, 2)
                   self._assert_row_count(cursor_another, 2)

           def test_transaction_rollback(self):
               """Rows written inside a transaction disappear on rollback()."""
               # Keep this object referenced for the whole test: its __del__
               # closes the connection that cursor_another depends on.
               mysql_another = MySQL(database=self.database_name,
                                     **self.CONNECTION_OPTIONS)

               with mysql_another.get_cursor() as cursor_another:
                   self.mysql.get_connection().begin()

                   self._insert_one(self.cursor, 1, 'a')
                   self._assert_row_count(self.cursor, 1)
                   self._assert_row_count(cursor_another, 0)

                   self.mysql.get_connection().rollback()

                   self._assert_row_count(self.cursor, 0)
                   self._assert_row_count(cursor_another, 0)

                   # Back in autocommit mode: the next insert is visible at once.
                   self._insert_one(self.cursor, 2, 'b')
                   self._assert_row_count(self.cursor, 1)
                   self._assert_row_count(cursor_another, 1)

           def test_mogrify(self):
               """mogrify() returns the SQL text that would be sent to the server."""
               query = 'SELECT * FROM test ORDER BY id ASC;'
               self.assertEqual(self.cursor.mogrify(query), query)

               query = 'INSERT INTO test(id, name) VALUES(%s, %s);'
               args = (1, 'a')
               self.assertEqual(self.cursor.mogrify(query, args),
                                'INSERT INTO test(id, name) VALUES(1, \'a\');')
              
              
       def main():
           """Script entry point: print a marker line and exit.

           The test suite is not run from here; it is executed through
           ``python -m unittest``.
           """
           print('main call')


       if __name__ == '__main__':
           main()
      
  • 실행 결과
    • python ./main.py
      •    main call
        
    • python -m unittest ./main.py
      •    ......
           ----------------------------------------------------------------------
           Ran 6 tests in 1.659s
                    
           OK