Mirror of https://github.com/cldellow/sqlite-parquet-vtable.git (synced 2025-10-31 02:19:56 +00:00)
			
		
		
		
	
		
			
				
	
	
		
			114 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			114 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
| #!/usr/bin/python3
 | |
| 
 | |
| import sqlite3
 | |
| import os.path
 | |
| import random
 | |
| 
 | |
# Comparison operators the fuzzer may place between a column and a literal
# when it builds a WHERE clause in generate_statement.
ops = ['=', '<>', '>', '>=', '<', '<=', 'IS', 'IS NOT', 'LIKE']
 | |
| 
 | |
| def get_column_values(conn, table, column):
 | |
|     cursor = conn.execute('SELECT DISTINCT quote({}) FROM {}'.format(column, table))
 | |
|     rv = []
 | |
|     for row in cursor:
 | |
|         rv.append(row[0])
 | |
|     cursor.close()
 | |
|     return rv
 | |
| 
 | |
def get_columns(conn, table):
    """Return ``['rowid']`` followed by the column names of ``table``.

    The names are read from the cursor description of a ``SELECT *`` query.
    ``'rowid'`` is prepended explicitly because ``SELECT *`` does not include it.
    ``table`` is interpolated unescaped, so it must be a trusted identifier.
    """
    cursor = conn.execute('SELECT * FROM {} LIMIT 1'.format(table))
    try:
        return ['rowid'] + [description[0] for description in cursor.description]
    finally:
        # Close the cursor even if reading the description raises.
        cursor.close()
 | |
| 
 | |
def generate_statement(conn, table, column_values, all_values):
    """Build one random ``SELECT`` query against ``table``.

    The query selects 1-20 ``quote(<column>)`` expressions (columns may
    repeat). It then adds 0-2 ``AND``-joined predicates, each negated with
    probability 1/6. With probability 1/4 it adds a ``LIMIT`` (0-15), which in
    turn gets an ``OFFSET`` (0-30) with probability 1/4.

    Args:
        conn: Unused; kept so existing callers keep working.
        table: Table name, placed verbatim after ``FROM``.
        column_values: Maps column name -> list of SQL literal strings from
            that column (as returned by ``get_column_values``).
        all_values: Literals pooled from all columns. Each predicate uses this
            pool half of the time, so columns are also compared against values
            drawn from other columns.

    Returns:
        The query text, without a trailing semicolon.
    """
    names = list(column_values)

    selected = [random.choice(names) for _ in range(random.randint(1, 20))]
    query = 'SELECT ' + ', '.join('quote(' + name + ')' for name in selected)
    query += ' FROM ' + table

    clauses = []
    for _ in range(random.randint(0, 2)):
        field = random.choice(names)
        op = random.choice(ops)
        values = column_values[field]
        if random.randint(0, 1) == 0:
            values = all_values
        # An empty pool (e.g. the table has no rows) would make the draw
        # raise; compare against NULL instead.
        value = random.choice(values) if values else 'NULL'
        clause = '(' + field + ' ' + op + ' ' + str(value) + ')'
        if random.randint(0, 5) == 0:
            clause = 'NOT ' + clause
        clauses.append(clause)
    if clauses:
        query += ' WHERE ' + ' AND '.join(clauses)

    if random.randint(0, 3) == 0:
        query += ' LIMIT {}'.format(random.randint(0, 15))
        # SQLite only accepts OFFSET as part of a LIMIT clause.
        if random.randint(0, 3) == 0:
            query += ' OFFSET {}'.format(random.randint(0, 30))

    return query
 | |
| 
 | |
def _write_failure_files(query, expected, actual):
    """Write a sqlite3-shell command script and the expected/actual rows to the CWD."""
    with open('testcase-cmds.txt', 'w') as f:
        f.write('.load build/linux/libparquet\n.testcase query\n.bail on\n{};\n.output\n'.format(query))
    with open('testcase-expected.txt', 'w') as f:
        f.writelines('{}\n'.format(row) for row in expected)
    with open('testcase-out.txt', 'w') as f:
        f.writelines('{}\n'.format(row) for row in actual)


def test_statement(conn, table, column_values, all_values):
    """Run one random query against ``table`` and compare it with three copies.

    The rows from ``table`` act as the oracle. The same query is re-run with
    its FROM target renamed (``'nulls'`` -> ``'nulls1'``/``'nulls2'``/``'nulls3'``
    within the table name), and the results must match exactly, including row
    order.

    On the first mismatch, repro files are written to the current directory
    and ``ValueError`` is raised.
    """
    query = generate_statement(conn, table, column_values, all_values)

    gold = list(conn.execute(query))
    print('{} rows: {}'.format(len(gold), query))
    for parquet in ['nulls1', 'nulls2', 'nulls3']:
        # Rename only the FROM target. A blanket query.replace('nulls', ...)
        # would also rewrite string literals that happen to contain 'nulls'.
        parquet_table = table.replace('nulls', parquet)
        new_query = query.replace(' FROM ' + table, ' FROM ' + parquet_table, 1)
        rv = list(conn.execute(new_query))
        if gold != rv:
            _write_failure_files(new_query, gold, rv)
            raise ValueError('ruhroh: results differ for query: {}'.format(new_query))
 | |
| 
 | |
| 
 | |
| 
 | |
def test_table(conn, table, iterations=1000, seed=None):
    """Fuzz ``table`` with ``iterations`` random queries via ``test_statement``.

    Args:
        conn: Open sqlite3 connection. ``test_statement`` also queries the
            renamed parquet copies, so the extension must already be loaded.
        table: Name of the table whose results act as the oracle.
        iterations: Number of random queries to run (default 1000).
        seed: If not None, ``random.seed(seed)`` is called first, so a
            failing run can be reproduced.
    """
    # Don't include the floating point columns in random tests - sqlite itself stores doubles, so
    # it can't act as an oracle for the FP stuff.
    column_names = [x for x in get_columns(conn, table) if not x.startswith('float_')]
    print('Table {}: {}'.format(table, column_names))
    column_values = {name: get_column_values(conn, table, name) for name in column_names}

    if seed is not None:
        random.seed(seed)
    # Flatten in one pass instead of re-copying the accumulated list per column.
    all_values = [value for values in column_values.values() for value in values]
    for _ in range(iterations):
        test_statement(conn, table, column_values, all_values)
 | |
| 
 | |
def test_db(db_file, extension_file, tables):
    """Open ``db_file``, load the SQLite extension at ``extension_file``, and fuzz each table.

    Each name in ``tables`` is passed to ``test_table``. The connection is
    closed afterwards, even if a test raises.
    """
    conn = sqlite3.connect(db_file)
    try:
        conn.enable_load_extension(True)
        conn.load_extension(extension_file)
        # Turn extension loading back off once the needed extension is loaded.
        conn.enable_load_extension(False)
        for table in tables:
            test_table(conn, table)
    finally:
        conn.close()
 | |
| 
 | |
if __name__ == '__main__':
    # Both inputs are resolved against the parent of this script's directory.
    base_dir = os.path.abspath(os.path.join(__file__, '..', '..'))
    db_file = os.path.join(base_dir, 'test.db')
    extension_file = os.path.join(base_dir, 'build', 'linux', 'libparquet')
    test_db(db_file, extension_file, tables=['nulls', 'no_nulls'])
 | 
