# Last modified: Sep 19, 2019, 04:18 AM
import sys
import time
import unittest

import mysql.connector
import prestodb

# Make the helper modules in the parent directory importable. Note that
# '../' is resolved against the current working directory, not this file's
# location, so the test must be launched from its own directory.
sys.path.append('../')
from airflow_api import AirflowAPI  # noqa: E402
from db_util import DBUtil  # noqa: E402
from constants import PRESTO_DB_PORT, MYSQL_DB_PORT  # noqa: E402


class TestPrestoToMySqlDag(unittest.TestCase):
    """Integration test for the ``presto_to_mysql`` Airflow DAG.

    ``setUp`` registers Airflow connections for Presto and MySQL, opens
    direct DB-API connections to both services on the minikube IP, creates
    the Presto source table ``region`` and the MySQL target table
    ``mysql_region``, and inserts one row into ``region``.

    Every resource is released through ``addCleanup`` immediately after it
    is acquired. Tables are therefore dropped and connections closed even
    when ``setUp`` fails part-way (``tearDown`` would not run in that case)
    or when one cleanup step raises (the remaining steps still run).
    """

    # Seconds to sleep between polls of the Airflow API while the DAG runs.
    DAG_POLL_INTERVAL_SECONDS = 2
    # Maximum seconds to wait for the DAG run before failing the test.
    # NOTE(review): 600 is a generous guess; tune it for the CI cluster.
    DAG_TIMEOUT_SECONDS = 600

    # Per-test DB-API connections; assigned in setUp.
    mysql_conn = None
    prest_conn = None

    def setUp(self):
        """Create Airflow connections, DB connections, tables and seed data."""
        presto_catalog = "blackhole"
        presto_schema = "default"
        mysql_database = "mysql"
        mysql_user = "mysql"
        mysql_password = "mysql"

        self.airflow_api = AirflowAPI()
        self.minikube_ip = str(self.airflow_api.get_minikube_ip())
        self.db_util = DBUtil()
        self.airflow_api.add_presto_connection(
            "presto-conn", presto_catalog, presto_schema)
        self.airflow_api.add_mysql_connection(
            "mysql-conn", mysql_database, mysql_user, mysql_password)

        self.mysql_conn = mysql.connector.connect(
            user=mysql_user,
            password=mysql_password,
            host=self.minikube_ip,
            port=MYSQL_DB_PORT,
            database=mysql_database,
            use_pure=False,
        )
        # Cleanups run LIFO: tables are dropped before connections close.
        self.addCleanup(self.mysql_conn.close)

        self.prest_conn = prestodb.dbapi.connect(
            host=self.minikube_ip,
            port=PRESTO_DB_PORT,
            user='admin',
            catalog=presto_catalog,
            schema=presto_schema,
        )
        self.addCleanup(self.prest_conn.close)

        create_mysql_table_sql = """
            CREATE TABLE IF NOT EXISTS mysql_region (
                name VARCHAR(50),count int(10)
            );
        """
        self.db_util.create_table(self.mysql_conn, create_mysql_table_sql)
        self.addCleanup(self.db_util.drop_table, self.mysql_conn,
                        "drop table mysql_region")

        # NOTE(review): these look like Presto blackhole-connector table
        # properties; with every count set to 1 a scan presumably yields a
        # single generated row. Confirm against the connector docs.
        create_presto_table_sql = """
            CREATE TABLE region (
                name varchar
            )
            WITH (
                split_count = 1,
                pages_per_split = 1,
                rows_per_page = 1,
                page_processing_delay = '5s'
            )"""
        self.db_util.create_table(self.prest_conn, create_presto_table_sql)
        self.addCleanup(self.db_util.drop_table, self.prest_conn,
                        "drop table region")

        insert_query_1 = "insert into region values('INDIA')"
        self.db_util.insert_into_table(self.prest_conn, insert_query_1)

    def _wait_for_dag_run(self, dag_id, execution_date):
        """Block until the DAG run stops running; fail the test on timeout.

        Polls ``is_dag_running`` immediately, then every
        ``DAG_POLL_INTERVAL_SECONDS``, giving up after
        ``DAG_TIMEOUT_SECONDS`` instead of spinning forever.
        """
        deadline = time.monotonic() + self.DAG_TIMEOUT_SECONDS
        is_running = self.airflow_api.is_dag_running(dag_id, execution_date)
        while is_running:
            if time.monotonic() >= deadline:
                self.fail("DAG %s (%s) still running after %s seconds"
                          % (dag_id, execution_date, self.DAG_TIMEOUT_SECONDS))
            time.sleep(self.DAG_POLL_INTERVAL_SECONDS)
            is_running = self.airflow_api.is_dag_running(dag_id,
                                                         execution_date)
        # Catch a non-boolean "not running" answer (e.g. None) as a failure.
        self.assertEqual(is_running, False)

    def test_presto_to_mysql_transfer(self):
        """Trigger ``presto_to_mysql`` and expect one row in ``mysql_region``."""
        execution_date = "2019-05-12T14:00:00+00:00"
        dag_id = "presto_to_mysql"
        self.airflow_api.trigger_dag(dag_id, execution_date)
        self._wait_for_dag_run(dag_id, execution_date)
        self.assertEqual(
            self.airflow_api.get_dag_status(dag_id, execution_date), "success")

        mysql_select_query = "SELECT name FROM mysql_region"
        row_count = self.db_util.get_row_count(self.mysql_conn,
                                               mysql_select_query)
        # NOTE(review): despite its name, get_row_count's result is passed to
        # len(), so it presumably returns the fetched rows, not an int.
        # Confirm in db_util.
        self.assertEqual(1, len(row_count))


if __name__ == "__main__":
    # Discover and run the TestCase classes defined in this module.
    unittest.main()