diff --git a/tests/test-all b/tests/test-all index 594dc26..5e32be3 100755 --- a/tests/test-all +++ b/tests/test-all @@ -8,3 +8,4 @@ set -x "$here"/test-unsupported "$here"/test-supported "$here"/test-queries +"$here"/test-random diff --git a/tests/test-random b/tests/test-random new file mode 100755 index 0000000..8862615 --- /dev/null +++ b/tests/test-random @@ -0,0 +1,100 @@ +#!/usr/bin/python3 + +import sqlite3 +import os.path +import random + +ops = ['=', '<>', '>', '>=', '<', '<=', 'IS', 'IS NOT', 'LIKE']; + +def get_column_values(conn, table, column): + cursor = conn.execute('SELECT DISTINCT quote({}) FROM {}'.format(column, table)) + rv = [] + for row in cursor: + rv.append(row[0]) + cursor.close() + return rv + +def get_columns(conn, table): + cursor = conn.execute('SELECT * FROM {} LIMIT 1'.format(table)) + rv = ['rowid'] + [description[0] for description in cursor.description] + cursor.close() + return rv + +def generate_statement(conn, table, column_values, all_values): + names = list(column_values.keys()) + + num_fields = random.randint(1, 20) + + query = 'SELECT ' + for i in range(num_fields): + if i > 0: + query += ', ' + + query += 'quote(' + names[random.randint(0, len(names) - 1)] + ')' + query += ' FROM ' + query += table + + num_clauses = random.randint(0, 2) + for i in range(num_clauses): + if i == 0: + query += ' WHERE ' + else: + query += ' AND ' + + field = names[random.randint(0, len(names) - 1)] + op = ops[random.randint(0, len(ops) - 1)] + values = column_values[field] + if random.randint(0, 1) == 0: + values = all_values + value = values[random.randint(0, len(values) - 1)] + query += field + ' ' + op + ' ' + str(value) + + return query + +def test_statement(conn, table, column_values, all_values): + query = generate_statement(conn, table, column_values, all_values) + + gold = [row for row in conn.execute(query)] + print('{} rows: {}'.format(len(gold), query)) + for parquet in ['nulls1', 'nulls2', 'nulls3']: + new_query = query.replace('nulls', parquet) + rv = [row for row in conn.execute(new_query)] + if gold != rv: + with open('testcase-cmds.txt', 'w') as f: + f.write('.load parquet/libparquet\n.testcase query\n.bail on\n{};\n.output\n'.format(new_query)) + with open('testcase-expected.txt', 'w') as f: + for row in gold: + f.write('{}\n'.format(row)) + with open('testcase-out.txt', 'w') as f: + for row in rv: + f.write('{}\n'.format(row)) + + raise ValueError('ruhroh') + + + +def test_table(conn, table): + column_names = get_columns(conn, table) + print('Table {}: {}'.format(table, column_names)) + column_values = {} + for name in column_names: + column_values[name] = get_column_values(conn, table, name) + + random.seed(0) + all_values = [] + for values in column_values.values(): + all_values = all_values + values + for i in range(1000): + test_statement(conn, table, column_values, all_values) + +def test_db(db_file, tables): + conn = sqlite3.connect(db_file) + conn.enable_load_extension(True) + conn.load_extension('parquet/libparquet.so') + conn.enable_load_extension(False) + for table in tables: + test_table(conn, table) + +if __name__ == '__main__': + db_file = os.path.abspath(os.path.join(__file__, '..', '..', 'test.db')) + test_db(db_file, ['nulls', 'no_nulls'])