101 lines
3.1 KiB
Plaintext
101 lines
3.1 KiB
Plaintext
|
#!/usr/bin/python3
|
||
|
|
||
|
import sqlite3
|
||
|
import os.path
|
||
|
import random
|
||
|
|
||
|
# Comparison operators that generate_statement may place in a WHERE clause.
ops = [
    '=',
    '<>',
    '>',
    '>=',
    '<',
    '<=',
    'IS',
    'IS NOT',
    'LIKE',
]
|
||
|
|
||
|
def get_column_values(conn, table, column):
|
||
|
cursor = conn.execute('SELECT DISTINCT quote({}) FROM {}'.format(column, table))
|
||
|
rv = []
|
||
|
for row in cursor:
|
||
|
rv.append(row[0])
|
||
|
cursor.close()
|
||
|
return rv
|
||
|
|
||
|
def get_columns(conn, table):
    """Return the column names of a table, with 'rowid' placed first.

    Args:
        conn: connection object; only its execute() method is used.
        table: table name, interpolated into the SQL text.

    Returns:
        A list starting with 'rowid', followed by the first element of each
        entry in the cursor's description for ``SELECT * FROM table``.
    """
    # LIMIT 1 keeps the probe cheap; only the result metadata is needed.
    cursor = conn.execute('SELECT * FROM {} LIMIT 1'.format(table))
    try:
        return ['rowid'] + [description[0] for description in cursor.description]
    finally:
        # Release the cursor even when reading the description fails.
        cursor.close()
|
||
|
|
||
|
def generate_statement(conn, table, column_values, all_values):
    """Build the text of a random SELECT statement against `table`.

    The statement selects between 1 and 20 randomly chosen columns, each
    wrapped in quote(), and carries 0 to 2 randomly generated comparisons
    joined with AND. All randomness comes from the module-level `random`
    generator, so the output depends on that generator's current state.

    Args:
        conn: unused; kept so the signature stays the same for callers.
        table: table name placed after FROM.
        column_values: dict mapping each column name to a list of candidate
            comparison values for that column.
        all_values: list of candidate values pooled across all columns; each
            comparison uses this pool instead of the column's own values
            with probability one half.

    Returns:
        The statement as a string, without a trailing semicolon.

    Raises:
        ValueError: if a name or value has to be drawn from an empty list.
    """
    def pick(seq):
        # One randint draw per pick; fail with a readable message rather
        # than randint's "empty range" error when there is nothing to pick.
        if not seq:
            raise ValueError('cannot pick from an empty sequence')
        return seq[random.randint(0, len(seq) - 1)]

    names = list(column_values.keys())

    num_fields = random.randint(1, 20)
    fields = ['quote(' + pick(names) + ')' for _ in range(num_fields)]
    query = 'SELECT ' + ', '.join(fields) + ' FROM ' + table

    num_clauses = random.randint(0, 2)
    clauses = []
    for _ in range(num_clauses):
        field = pick(names)
        op = pick(ops)
        values = column_values[field]
        # Half of the time compare against a value taken from any column.
        if random.randint(0, 1) == 0:
            values = all_values
        value = pick(values)
        clauses.append(field + ' ' + op + ' ' + str(value))

    if clauses:
        query += ' WHERE ' + ' AND '.join(clauses)

    return query
|
||
|
|
||
|
def _write_rows(path, rows):
    """Write the default string form of each row to `path`, one per line."""
    with open(path, 'w') as f:
        f.writelines('{}\n'.format(row) for row in rows)


def test_statement(conn, table, column_values, all_values):
    """Run one random statement and compare its rows across table variants.

    A statement is generated and executed as-is to obtain the reference
    rows. It is then re-run three times with every occurrence of the text
    'nulls' replaced by 'nulls1', 'nulls2' and 'nulls3' in turn. On the
    first mismatch, three files are written to the current directory
    (testcase-cmds.txt, testcase-expected.txt, testcase-out.txt) and an
    error is raised.

    Args:
        conn: connection object; only its execute() method is used here.
        table: table name, forwarded to generate_statement.
        column_values: dict of column name to candidate values, forwarded.
        all_values: pooled candidate values, forwarded.

    Raises:
        ValueError: if a rewritten statement returns different rows than
            the reference statement.
    """
    query = generate_statement(conn, table, column_values, all_values)

    gold = list(conn.execute(query))
    print('{} rows: {}'.format(len(gold), query))
    for parquet in ['nulls1', 'nulls2', 'nulls3']:
        # Plain substring replacement: it rewrites 'nulls' wherever it
        # occurs in the statement text, so 'no_nulls' becomes 'no_nulls1'.
        new_query = query.replace('nulls', parquet)
        rv = list(conn.execute(new_query))
        if gold == rv:
            continue

        # Leave behind the failing statement plus both result sets.
        with open('testcase-cmds.txt', 'w') as f:
            f.write('.load parquet/libparquet\n.testcase query\n.bail on\n{};\n.output\n'.format(new_query))
        _write_rows('testcase-expected.txt', gold)
        _write_rows('testcase-out.txt', rv)

        raise ValueError('ruhroh')
|
||
|
|
||
|
|
||
|
|
||
|
def test_table(conn, table, iterations=1000, seed=0):
    """Run a batch of random statements against one table.

    Collects the table's column names and the distinct quoted values of
    each column, seeds the module-level `random` generator, then calls
    test_statement repeatedly.

    Args:
        conn: connection object passed through to the helpers.
        table: name of the table to exercise.
        iterations: number of random statements to test (default 1000).
        seed: value passed to random.seed() before generating statements
            (default 0).

    Raises:
        ValueError: propagated from test_statement on the first mismatch.
    """
    column_names = get_columns(conn, table)
    print('Table {}: {}'.format(table, column_names))
    column_values = {
        name: get_column_values(conn, table, name) for name in column_names
    }

    # Seeding after the metadata queries and before any statement is
    # generated fixes the sequence of statements for a given seed.
    random.seed(seed)

    # Pool of every column's values; extend() avoids re-copying the whole
    # list for each column.
    all_values = []
    for values in column_values.values():
        all_values.extend(values)

    for _ in range(iterations):
        test_statement(conn, table, column_values, all_values)
|
||
|
|
||
|
def test_db(db_file, tables):
    """Open a database, load the parquet extension and test each table.

    Args:
        db_file: path of the SQLite database file to open.
        tables: iterable of table names; test_table is called for each.

    Raises:
        ValueError: propagated from test_table on the first mismatch.
    """
    conn = sqlite3.connect(db_file)
    try:
        # Extension loading is enabled only for the duration of the load.
        conn.enable_load_extension(True)
        conn.load_extension('parquet/libparquet.so')
        conn.enable_load_extension(False)
        for table in tables:
            test_table(conn, table)
    finally:
        # Always release the connection, including when a test fails.
        conn.close()
|
||
|
|
||
|
if __name__ == '__main__':
    # Two parent steps from this file's path: the first drops the file name,
    # the second climbs out of the directory that holds the script.
    db_file = os.path.abspath(
        os.path.join(__file__, os.pardir, os.pardir, 'test.db')
    )
    test_db(db_file, ['nulls', 'no_nulls'])
|