#!/usr/bin/python3 import sqlite3 import os.path import random ops = ['=', '<>', '>', '>=', '<', '<=', 'IS', 'IS NOT', 'LIKE']; def get_column_values(conn, table, column): cursor = conn.execute('SELECT DISTINCT quote({}) FROM {}'.format(column, table)) rv = [] for row in cursor: rv.append(row[0]) cursor.close() return rv def get_columns(conn, table): cursor = conn.execute('SELECT * FROM {} LIMIT 1'.format(table)) rv = ['rowid'] + [description[0] for description in cursor.description] cursor.close() return rv def generate_statement(conn, table, column_values, all_values): names = list(column_values.keys()) num_fields = random.randint(1, 20) query = 'SELECT ' for i in range(num_fields): if i > 0: query += ', ' query += 'quote(' + names[random.randint(0, len(names) - 1)] + ')' query += ' FROM ' query += table num_clauses = random.randint(0, 2) for i in range(num_clauses): if i == 0: query += ' WHERE ' else: query += ' AND ' field = names[random.randint(0, len(names) - 1)] op = ops[random.randint(0, len(ops) - 1)] values = column_values[field] if random.randint(0, 1) == 0: values = all_values value = values[random.randint(0, len(values) - 1)] query += field + ' ' + op + ' ' + str(value) return query def test_statement(conn, table, column_values, all_values): query = generate_statement(conn, table, column_values, all_values) gold = [row for row in conn.execute(query)] print('{} rows: {}'.format(len(gold), query)) for parquet in ['nulls1', 'nulls2', 'nulls3']: new_query = query.replace('nulls', parquet) rv = [row for row in conn.execute(new_query)] if gold != rv: with open('testcase-cmds.txt', 'w') as f: f.write('.load parquet/libparquet\n.testcase query\n.bail on\n{};\n.output\n'.format(new_query)) with open('testcase-expected.txt', 'w') as f: for row in gold: f.write('{}\n'.format(row)) with open('testcase-out.txt', 'w') as f: for row in rv: f.write('{}\n'.format(row)) raise ValueError('ruhroh') def test_table(conn, table): column_names = get_columns(conn, table) print('Table {}: {}'.format(table, column_names)) column_values = {} for name in column_names: column_values[name] = get_column_values(conn, table, name) #random.seed(0) all_values = [] for values in column_values.values(): all_values = all_values + values for i in range(1000): test_statement(conn, table, column_values, all_values) def test_db(db_file, extension_file, tables): conn = sqlite3.connect(db_file) conn.enable_load_extension(True) conn.load_extension(extension_file) conn.enable_load_extension(False) for table in tables: test_table(conn, table) if __name__ == '__main__': db_file = os.path.abspath(os.path.join(__file__, '..', '..', 'test.db')) extension_file = os.path.abspath(os.path.join(__file__, '..', '..', 'parquet', 'libparquet.so')) test_db(db_file, extension_file, ['nulls', 'no_nulls'])