mirror of
				https://github.com/cldellow/sqlite-parquet-vtable.git
				synced 2025-11-04 02:39:56 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			112 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			112 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
#!/usr/bin/python3
 | 
						|
 | 
						|
import sqlite3
 | 
						|
import os.path
 | 
						|
import random
 | 
						|
 | 
						|
# Comparison operators the statement generator may place between a column
# and a value when it builds a WHERE clause.
ops = ['=', '<>', '>', '>=', '<', '<=', 'IS', 'IS NOT', 'LIKE']
 | 
						|
 | 
						|
def get_column_values(conn, table, column):
 | 
						|
    cursor = conn.execute('SELECT DISTINCT quote({}) FROM {}'.format(column, table))
 | 
						|
    rv = []
 | 
						|
    for row in cursor:
 | 
						|
        rv.append(row[0])
 | 
						|
    cursor.close()
 | 
						|
    return rv
 | 
						|
 | 
						|
def get_columns(conn, table):
    """Return ``'rowid'`` followed by the column names of ``table``.

    The names are read from the cursor description of a one-row ``SELECT *``.

    Args:
        conn: an open ``sqlite3`` connection.
        table: table name; interpolated into the SQL text unescaped.

    Returns:
        A list of column-name strings whose first element is ``'rowid'``.
    """
    cursor = conn.execute('SELECT * FROM {} LIMIT 1'.format(table))
    try:
        return ['rowid'] + [description[0] for description in cursor.description]
    finally:
        # Close the cursor even when reading the description raises.
        cursor.close()
 | 
						|
 | 
						|
def generate_statement(conn, table, column_values, all_values):
    """Build one random ``SELECT`` statement against ``table``.

    The statement selects 1-20 randomly chosen columns (repeats allowed),
    each wrapped in ``quote()``, followed by 0-2 ``WHERE``/``AND``
    comparisons and, one time in four, a ``LIMIT`` that in turn gets an
    ``OFFSET`` one time in four.

    Args:
        conn: unused; kept so the signature matches the other helpers.
        table: table name placed after ``FROM``.
        column_values: mapping of column name to the list of candidate
            values for that column; values are spliced into the SQL as-is.
        all_values: pool of candidate values drawn from every column.

    Returns:
        The generated SQL text, without a trailing semicolon.

    Raises:
        ValueError: if a comparison is generated but the chosen value pool
            is empty.
    """
    names = list(column_values)

    num_fields = random.randint(1, 20)
    select_list = ', '.join(
        'quote(' + random.choice(names) + ')' for _ in range(num_fields)
    )
    parts = ['SELECT ', select_list, ' FROM ', table]

    num_clauses = random.randint(0, 2)
    for i in range(num_clauses):
        parts.append(' WHERE ' if i == 0 else ' AND ')

        field = random.choice(names)
        op = random.choice(ops)
        # Half of the time compare against a value taken from the pool of
        # all columns instead of the chosen column's own values.
        values = column_values[field]
        if random.randint(0, 1) == 0:
            values = all_values
        if not values:
            raise ValueError(
                'no values available to compare {} against'.format(field))
        value = random.choice(values)
        # Negate roughly one comparison in six.
        if random.randint(0, 5) == 0:
            parts.append(' NOT ')
        parts.append('(' + field + ' ' + op + ' ' + str(value) + ')')

    if random.randint(0, 3) == 0:
        parts.append(' LIMIT {}'.format(random.randint(0, 15)))

        if random.randint(0, 3) == 0:
            parts.append(' OFFSET {}'.format(random.randint(0, 30)))

    return ''.join(parts)
 | 
						|
 | 
						|
def _write_mismatch_artifacts(failing_query, expected_rows, actual_rows):
    """Write the failing query plus expected and actual rows to testcase-*.txt."""
    # Explicit UTF-8 so a non-UTF-8 locale cannot raise an encoding error
    # that would hide the real mismatch.
    with open('testcase-cmds.txt', 'w', encoding='utf-8') as f:
        f.write('.load build/linux/libparquet\n.testcase query\n.bail on\n{};\n.output\n'.format(failing_query))
    with open('testcase-expected.txt', 'w', encoding='utf-8') as f:
        f.writelines('{}\n'.format(row) for row in expected_rows)
    with open('testcase-out.txt', 'w', encoding='utf-8') as f:
        f.writelines('{}\n'.format(row) for row in actual_rows)


def test_statement(conn, table, column_values, all_values):
    """Run one random query and compare it against three rewritten variants.

    A statement is generated for ``table`` and executed to obtain the
    reference rows. The same text is then re-run with every occurrence of
    ``'nulls'`` replaced by ``'nulls1'``, ``'nulls2'`` and ``'nulls3'``
    (presumably the parquet-backed counterparts of the table), and each
    result must equal the reference rows.

    Args:
        conn: an open ``sqlite3`` connection.
        table: table name passed to ``generate_statement``.
        column_values: mapping of column name to candidate values.
        all_values: pool of candidate values from every column.

    Raises:
        ValueError: on the first variant whose rows differ; the failing
            query and both row sets are written to ``testcase-cmds.txt``,
            ``testcase-expected.txt`` and ``testcase-out.txt`` first.
    """
    query = generate_statement(conn, table, column_values, all_values)

    gold = conn.execute(query).fetchall()
    print('{} rows: {}'.format(len(gold), query))
    for parquet in ['nulls1', 'nulls2', 'nulls3']:
        new_query = query.replace('nulls', parquet)
        rv = conn.execute(new_query).fetchall()
        if gold != rv:
            _write_mismatch_artifacts(new_query, gold, rv)
            raise ValueError('ruhroh')
 | 
						|
 | 
						|
 | 
						|
 | 
						|
def test_table(conn, table, iterations=1000, seed=None):
    """Run randomly generated queries against ``table``.

    Collects the distinct quoted values of every column (including
    ``rowid``) and then calls ``test_statement`` ``iterations`` times.

    Args:
        conn: an open ``sqlite3`` connection.
        table: name of the table to test.
        iterations: number of random statements to run (default 1000).
        seed: when not ``None``, passed to ``random.seed`` before the
            statements are generated so a run can be reproduced.

    Raises:
        ValueError: propagated from ``test_statement`` on the first mismatch.
    """
    column_names = get_columns(conn, table)
    print('Table {}: {}'.format(table, column_names))
    column_values = {
        name: get_column_values(conn, table, name) for name in column_names
    }

    if seed is not None:
        random.seed(seed)

    # Flat pool of every column's values, built in one pass.
    all_values = [
        value for values in column_values.values() for value in values
    ]
    for _ in range(iterations):
        test_statement(conn, table, column_values, all_values)
 | 
						|
 | 
						|
def test_db(db_file, extension_file, tables):
    """Open ``db_file``, load the extension, and test every table.

    Args:
        db_file: path of the SQLite database to open.
        extension_file: path of the loadable extension to load into the
            connection.
        tables: iterable of table names, each passed to ``test_table``.

    Raises:
        ValueError: propagated from ``test_table`` on the first mismatch.
    """
    conn = sqlite3.connect(db_file)
    try:
        # Allow extension loading only for the one load call.
        conn.enable_load_extension(True)
        conn.load_extension(extension_file)
        conn.enable_load_extension(False)
        for table in tables:
            test_table(conn, table)
    finally:
        # Close the connection on success and on failure.
        conn.close()
 | 
						|
 | 
						|
if __name__ == '__main__':
    # Both paths are resolved relative to the parent of the directory that
    # contains this script.
    root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    db_file = os.path.join(root, 'test.db')
    extension_file = os.path.join(root, 'build', 'linux', 'libparquet.so')
    test_db(db_file, extension_file, ['nulls', 'no_nulls'])
 |