在搞公司的SQL查询(MySQL)平台时,需要对用户查询SQL进行条数限制,默认是在配置文件中配置一个“limit = 1000”这样的参数。最自然想到的就是对用户通过web传入的SQL做处理,默认加上limit参数。这样一来就有这么几个问题需要处理:
1. 如果用户自己传入了limit 10这样的条件怎么办?
2. 如果用户自己传入了limit 10,2这样的条件怎么办?
3. 如果用户的查询比较复杂,有多个子查询并带有limit怎么办?
4. 如果用户查询字段有`limit`(不加“时的SQL会报语法错误)、及表名有limit这样的关键字怎么办?
测试通过代码如下,提供一个处理函数:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import json import re def replace_limit(sql, limit): """ 依次查找并处理limit offset,然后把limit关键字替换为special_flag 全部处理完后再把special_flag替换回limit :param sql: :param limit: :return: """ special_flag = '-*-*-' def fun(new_sql): """ :return: sql """ upper_sql = new_sql.upper() start_index = upper_sql.find(' LIMIT ') + len(' LIMIT ') end_index = start_index for i in range(start_index, len(upper_sql)): if bool(re.match(r'^[0-9]|,| ', upper_sql[i])): end_index += 1 else: break limit_str = upper_sql[start_index:end_index].strip() # 输入limit值大于默认limit值就进行替换成默认limit值 if ',' in limit_str: offsets = limit_str.split(',') if int(offsets[-1]) > limit: limit_str = '{}, {}'.format(offsets[0], limit) else: if int(limit_str) > limit: limit_str = '{}'.format(limit) limit_str = ' ' + limit_str + ' ' new_sql = new_sql.replace( new_sql[start_index:end_index], limit_str, 1 ) new_sql = new_sql.replace( new_sql[start_index - len(' LIMIT '):start_index], special_flag, 1 ) return new_sql # 原sql没有limit则在最后加上limit,并return if re.search(r'limit\s.*\d.*', sql, re.IGNORECASE) is None: sql = sql.rstrip(';') + ' limit %s' % int(limit) + ';' return sql # 分析limit语句 for i in re.findall(' limit ', sql, re.IGNORECASE): sql = fun(sql) # 替换回limit关键字 sql = sql.replace(special_flag, ' limit ') return sql |
这个函数接收两个参数,SQL语句和默认limit限制值,在平台中是SQL是从前端获取来的,limit值是从配置文件获取来的。
大概逻辑如下:
1. 如果字段中有`limit`或limittest关键字就不需要处理。
2. 如果用户输入没有limit限制就加上默认limit限制,然后直接返回sql。
3. 如果用户输入有limit限制就进行判断用户输入值是否大于默认值,如果大于就替换成默认值,否则不改动。
4. 最后把替换过的关键字再替换回来。
在这里测试,就可以直接调用函数即可,如下:
1 2 |
sql = "select limitest as `limit` from test limit 10, 100;" print(replace_limit(sql, 20)) |
结果如下:
1 |
select limitest as `limit` from test limit 10, 20; |
可以看到结果满足我们的需求,由于用户输入值大于默认值就替换成了默认limit值。