167 lines
4.9 KiB
Python
167 lines
4.9 KiB
Python
# coding=utf-8
|
|
|
|
import time
|
|
import pymysql
|
|
from pymysql.cursors import DictCursor
|
|
|
|
|
|
class SQLRequest:
    """Run a (possibly multi-statement) SQL script against MySQL via pymysql.

    After :meth:`send` / :meth:`execute`, results are exposed through
    ``affected_rows``, ``all_rows``, ``elapsed_time`` and the four
    request/response dicts built by :meth:`_handle_response`.

    The instance is its own context manager. Entering it opens a connection
    and returns a ``DictCursor``. Leaving it commits (or rolls back if the
    with-body raised) and closes the cursor and the connection.
    """

    def __init__(self, host, user, password, port, charset, connect_timeout, db_type, sql):
        """
        :param host: host name
        :param user: user name
        :param password: password
        :param port: port
        :param charset: character set
        :param connect_timeout: connection timeout, in seconds
        :param db_type: database type label; it is only echoed into the
            header dicts and is not used to open the connection
        :param sql: SQL script
        """
        self.host = host
        self.user = user
        self.password = password
        self.port = port
        self.charset = charset
        self.connect_timeout = connect_timeout
        self.db_type = db_type
        self.sql = sql

        # Request / response data, filled in by _handle_response()
        self.response_headers = None  # type: dict
        self.response_body = None  # type: dict
        self.request_headers = None  # type: dict
        self.request_body = None  # type: dict
        self.affected_rows = 0  # affected row count of the last executed statement
        self.all_rows = ()  # rows fetched by the last executed statement
        # Execution time in ms
        self.elapsed_time = 0

    def send(self):
        """Execute the SQL script passed to the constructor."""
        self.execute(query=self.sql)

    def execute(self, query, args=None):
        """Split *query* into statements and run them on one connection.

        Every statement is executed with the same *args*. Only the affected
        row count and the fetched rows of the last statement are kept. If a
        statement raises, the transaction is rolled back (see ``__exit__``)
        and the exception propagates. In that case ``elapsed_time`` and the
        request/response dicts are not updated.

        :param query: single- or multi-line SQL text
        :param args: parameters passed to ``cursor.execute`` for every statement
        """
        # Split the multi-line script into individual statements
        sql_list = self._sql_split(query=query)

        # time.clock() no longer exists on Python 3.8+, and on Unix it
        # measured CPU time rather than wall time. perf_counter is a
        # wall clock; fall back to time.time where it is unavailable.
        timer = getattr(time, "perf_counter", time.time)

        # send request
        start = timer()
        with self as cursor:
            for sql in sql_list:
                self.affected_rows = cursor.execute(query=sql, args=args)
                self.all_rows = cursor.fetchall()
        # recv response
        end = timer()

        # Milliseconds measured around the whole with-block (connect,
        # statements, commit, close). The +1 rounds the truncated value up,
        # so sub-millisecond runs report 1 ms.
        self.elapsed_time = int((end - start) * 1000) + 1
        self._handle_response()

    def _sql_split(self, query):
        """
        Split a multi-line SQL text into a list of single-line statements.

        The steps are:

        - Drop blank lines and lines whose first non-blank characters are ``--``.
        - Split the remaining text on ``;``.
        - Discard whitespace-only fragments.
        - Replace newlines inside each remaining fragment with spaces.

        NOTE(review): the split is purely textual. A ``;`` inside a string
        literal, or a trailing ``-- comment`` after code on the same line, is
        not handled. Such input can produce broken statements, and a trailing
        comment can swallow the next line once newlines become spaces.

        :param query: single- or multi-line SQL text
        :type query: str
        :return: the split statements
        :rtype: list
        """
        # Remove comment lines and blank lines
        kept_lines = [
            line for line in query.split('\n')
            if line.strip() and not line.strip().startswith('--')
        ]

        statements = []
        for fragment in '\n'.join(kept_lines).split(';'):
            # Check blankness before the newline replacement. Otherwise a
            # fragment made only of newlines (e.g. the gap in "a;\n;b") would
            # be emitted as a blank statement.
            if not fragment.strip():
                continue
            statements.append(fragment.replace('\n', ' '))
        return statements

    def __enter__(self):
        """
        Open a connection and return a cursor on it.

        :return: a pymysql ``DictCursor`` on the new connection
        :rtype: DictCursor
        """
        self.connection = pymysql.connect(
            host=self.host,
            user=self.user,
            password=self.password,
            port=self.port,
            charset=self.charset,
            connect_timeout=self.connect_timeout,
        )
        try:
            self.cursor = self.connection.cursor(cursor=DictCursor)
        except Exception:
            # __exit__ is not called when __enter__ raises, so release the
            # connection opened above here instead of leaking it.
            self.connection.close()
            raise
        return self.cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Roll back if the with-body raised, otherwise commit; always close.

        The cursor and the connection are closed in ``finally`` blocks, so
        they are released even if commit/rollback itself raises. Returns
        None, so any exception from the with-body propagates.
        """
        try:
            if exc_type:
                self.connection.rollback()
            else:
                self.connection.commit()
        finally:
            try:
                self.cursor.close()
            finally:
                self.connection.close()

    def _handle_response(self):
        """Build the response headers, request body, request headers and response body."""
        self._handle_response_headers()
        self._handle_request_body()
        self._handle_request_headers()
        self._handle_response_body()

    def _handle_response_headers(self):
        # NOTE(review): this dict carries the plaintext password; consider
        # masking it if the headers are logged or displayed.
        self.response_headers = {
            'host': self.host,
            'port': self.port,
            'user': self.user,
            'password': self.password,
            'connect_timeout': self.connect_timeout,
            'sql': self.sql,
            'charset': self.charset,
            'db_type': self.db_type,
            'elapsed_time': self.elapsed_time,
        }

    def _handle_request_body(self):
        self.request_body = {
            'sql': self.sql,
        }

    def _handle_request_headers(self):
        # Same fields as the response headers minus elapsed_time.
        # NOTE(review): also includes the plaintext password.
        self.request_headers = {
            'host': self.host,
            'port': self.port,
            'user': self.user,
            'password': self.password,
            'connect_timeout': self.connect_timeout,
            'sql': self.sql,
            'charset': self.charset,
            'db_type': self.db_type,
        }

    def _handle_response_body(self):
        self.response_body = {
            'affected_rows': self.affected_rows,
            'all_rows': self.all_rows,
        }
|
|
|
|
|
|
def make_request(host, port, user, password, connect_timeout, db_type, charset, sql):
    """
    Create an SQLRequest, coercing the numeric arguments to ``int``.

    :param host: host name
    :param port: port; converted with ``int()``
    :param user: user name
    :param password: password
    :param connect_timeout: connection timeout; converted with ``int()``
    :param db_type: database type label, forwarded unchanged to SQLRequest
    :param charset: character set
    :param sql: SQL script
    :return: the configured (not yet executed) request
    """
    numeric_port = int(port)
    timeout_seconds = int(connect_timeout)
    return SQLRequest(
        host=host,
        user=user,
        password=password,
        port=numeric_port,
        charset=charset,
        connect_timeout=timeout_seconds,
        db_type=db_type,
        sql=sql,
    )
|
|
|
|
|
|
if __name__ == '__main__':
    # No command-line entry point is implemented; running this module
    # directly does nothing.
    pass
|