api-automation-test/ApiAutomationTest/app/cores/case/sql/request.py

167 lines
4.9 KiB
Python

# coding=utf-8
import time
import pymysql
from pymysql.cursors import DictCursor
class SQLRequest:
    """Run a (possibly multi-statement) SQL script through pymysql.

    After ``send()``/``execute()`` the instance exposes the results of the
    last statement (``affected_rows``, ``all_rows``), the elapsed time in
    milliseconds (``elapsed_time``) and four metadata dicts
    (``request_headers``, ``request_body``, ``response_headers``,
    ``response_body``).

    The instance is its own context manager: ``with self as cursor`` opens a
    fresh connection, yields a ``DictCursor``, and commits (or rolls back on
    error) and closes everything on exit.
    """

    def __init__(self, host, user, password, port, charset, connect_timeout, db_type, sql):
        """
        :param host: database host name
        :param user: user name
        :param password: password
        :param port: port number
        :param charset: connection character set
        :param connect_timeout: connection timeout, in seconds
        :param db_type: database type label; it is only copied into the
            request/response metadata dicts -- the connection itself is
            always made with pymysql
        :param sql: SQL script; may contain several ';'-separated statements
        """
        self.host = host
        self.user = user
        self.password = password
        self.port = port
        self.charset = charset
        self.connect_timeout = connect_timeout
        self.db_type = db_type
        self.sql = sql
        # Request / response metadata, filled in by _handle_response().
        self.response_headers = None  # type: dict
        self.response_body = None  # type: dict
        self.request_headers = None  # type: dict
        self.request_body = None  # type: dict
        self.affected_rows = 0  # return value of cursor.execute for the last statement
        self.all_rows = ()  # fetchall() result of the last statement
        # Execution time in ms (includes connecting, committing and closing).
        self.elapsed_time = 0
        # Created in __enter__ and released in __exit__.
        self.connection = None
        self.cursor = None

    def send(self):
        """Execute the SQL script given to the constructor."""
        self.execute(query=self.sql)

    def execute(self, query, args=None):
        """Split *query* into statements and run them on one connection.

        All statements share a single connection/transaction, which is
        committed when every statement succeeds and rolled back (with the
        exception propagated) when one raises. ``affected_rows`` and
        ``all_rows`` hold the results of the LAST statement only.

        :param query: single- or multi-line SQL text
        :param args: optional parameters forwarded to ``cursor.execute`` for
            every statement
        """
        sql_list = self._sql_split(query=query)
        # perf_counter() measures wall-clock time. time.clock() was removed in
        # Python 3.8 and, on Unix, measured CPU time rather than DB wait time.
        start = time.perf_counter()
        with self as cursor:
            for sql in sql_list:
                self.affected_rows = cursor.execute(query=sql, args=args)
                self.all_rows = cursor.fetchall()
        end = time.perf_counter()
        # The +1 guarantees a reported duration of at least 1 ms.
        self.elapsed_time = int((end - start) * 1000) + 1
        self._handle_response()

    def _sql_split(self, query):
        """Split SQL text into a list of single-line statements.

        Blank lines and lines whose stripped text starts with ``--`` are
        dropped. The remaining text is split on ``;``, and newlines inside
        each statement are replaced by spaces. Empty or whitespace-only
        statements (e.g. from ``;;`` or a ``;`` on its own line) are skipped.

        NOTE(review): the split is purely textual. A ';' inside a string
        literal, or a ``-- comment`` trailing code on the same line, is not
        handled.

        :param query: single- or multi-line SQL text
        :type query: str
        :return: the statements, in order
        :rtype: list
        """
        kept_lines = [
            line for line in query.split('\n')
            if line.strip() and not line.strip().startswith('--')
        ]
        sql_list = []
        for statement in '\n'.join(kept_lines).split(';'):
            statement = statement.replace('\n', ' ')
            # Check emptiness AFTER the newline replacement. Otherwise a
            # segment like '\n' would be sent to the server as an empty query.
            if statement.strip():
                sql_list.append(statement)
        return sql_list

    def __enter__(self):
        """Open a new connection and return a dict-returning cursor.

        :return: cursor
        :rtype: DictCursor
        """
        self.connection = pymysql.connect(
            host=self.host,
            user=self.user,
            password=self.password,
            port=self.port,
            charset=self.charset,
            connect_timeout=self.connect_timeout
        )
        self.cursor = self.connection.cursor(cursor=DictCursor)
        return self.cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Commit on success or roll back on error, then close everything.

        Returns None, so an exception raised in the ``with`` body propagates.
        """
        try:
            if exc_type:
                self.connection.rollback()
            else:
                self.connection.commit()
        finally:
            # Release resources even if commit/rollback itself raises.
            try:
                self.cursor.close()
            finally:
                self.connection.close()

    def _handle_response(self):
        """Fill in the response headers, request body, request headers and response body."""
        self._handle_response_headers()
        self._handle_request_body()
        self._handle_request_headers()
        self._handle_response_body()

    def _handle_response_headers(self):
        # NOTE(review): the plaintext password is copied into this dict;
        # confirm it is never logged or shown in reports.
        self.response_headers = {
            'host': self.host,
            'port': self.port,
            'user': self.user,
            'password': self.password,
            'connect_timeout': self.connect_timeout,
            'sql': self.sql,
            'charset': self.charset,
            'db_type': self.db_type,
            'elapsed_time': self.elapsed_time,
        }

    def _handle_request_body(self):
        self.request_body = {
            'sql': self.sql,
        }

    def _handle_request_headers(self):
        # NOTE(review): same plaintext-password concern as the response headers.
        self.request_headers = {
            'host': self.host,
            'port': self.port,
            'user': self.user,
            'password': self.password,
            'connect_timeout': self.connect_timeout,
            'sql': self.sql,
            'charset': self.charset,
            'db_type': self.db_type,
        }

    def _handle_response_body(self):
        self.response_body = {
            'affected_rows': self.affected_rows,
            'all_rows': self.all_rows,
        }
def make_request(host, port, user, password, connect_timeout, db_type, charset, sql):
    """Build an ``SQLRequest``, coercing ``port`` and ``connect_timeout`` to int.

    :param host: database host name
    :param port: port number (anything ``int()`` accepts)
    :param user: user name
    :param password: password
    :param connect_timeout: connection timeout in seconds (anything ``int()`` accepts)
    :param db_type: database type label, stored on the request
    :param charset: connection character set
    :param sql: SQL script
    :return: a configured, not-yet-sent request
    """
    options = dict(
        host=host,
        user=user,
        password=password,
        port=int(port),
        charset=charset,
        connect_timeout=int(connect_timeout),
        db_type=db_type,
        sql=sql,
    )
    return SQLRequest(**options)
if __name__ == "__main__":
    # The module defines no standalone entry point; running it does nothing.
    ...