#!/usr/bin/python3

import sqlite3
import os.path
import random

# Comparison operators that generate_statement may place between a column
# name and a value when it builds a WHERE condition.
ops = [
    '=', '<>', '>', '>=', '<', '<=',
    'IS', 'IS NOT', 'LIKE',
]

def get_column_values(conn, table, column):
    cursor = conn.execute('SELECT DISTINCT quote({}) FROM {}'.format(column, table))
    rv = []
    for row in cursor:
        rv.append(row[0])
    cursor.close()
    return rv

def get_columns(conn, table):
    """Return the column names of ``table``, with 'rowid' listed first.

    The names are read from the cursor description of a one-row
    ``SELECT *``; 'rowid' is always prepended to them.

    Args:
        conn: an open sqlite3 connection.
        table: name of the table to inspect (interpolated into the SQL
            text unescaped, so it must be a trusted identifier).

    Returns:
        A list of column name strings starting with 'rowid'.
    """
    cur = conn.execute('SELECT * FROM {} LIMIT 1'.format(table))
    names = ['rowid']
    # The first item of each description entry is the column name.
    names.extend(entry[0] for entry in cur.description)
    cur.close()
    return names

def generate_statement(conn, table, column_values, all_values):
    """Build a random SELECT statement against ``table``.

    The statement selects between 1 and 20 ``quote(<column>)`` expressions
    (a column may appear more than once), is optionally filtered by up to
    two AND-ed ``<column> <op> <value>`` conditions, and one time in four
    gets a LIMIT of 0-15. A statement with a LIMIT gets an OFFSET of 0-30
    one time in four.

    Args:
        conn: accepted for symmetry with the other helpers; not used here.
        table: name of the table to select from.
        column_values: dict mapping each column name to the list of values
            that may be compared against it.
        all_values: list of values pooled across every column. Each
            condition draws from this pool instead of the column's own
            values with probability 1/2.

    Returns:
        The generated SQL text.
    """
    def pick(seq):
        # Uniformly chosen element of seq; seq must not be empty.
        return seq[random.randint(0, len(seq) - 1)]

    names = list(column_values.keys())

    num_fields = random.randint(1, 20)
    select_list = ', '.join(
        'quote(' + pick(names) + ')' for _ in range(num_fields))
    parts = ['SELECT ', select_list, ' FROM ', table]

    num_clauses = random.randint(0, 2)
    for index in range(num_clauses):
        parts.append(' AND ' if index else ' WHERE ')

        # Draw order matters for reproducibility under a fixed seed:
        # column, operator, pool selector, then the value itself.
        column = pick(names)
        operator = pick(ops)
        own_values = column_values[column]
        pool = all_values if random.randint(0, 1) == 0 else own_values
        parts.append(column + ' ' + operator + ' ' + str(pick(pool)))

    if random.randint(0, 3) == 0:
        parts.append(' LIMIT {}'.format(random.randint(0, 15)))

        # OFFSET is only ever emitted together with a LIMIT.
        if random.randint(0, 3) == 0:
            parts.append(' OFFSET {}'.format(random.randint(0, 30)))

    return ''.join(parts)

def test_statement(conn, table, column_values, all_values):
    """Run one random statement and check that its variants return the same rows.

    A statement from generate_statement is executed unchanged to obtain the
    reference rows. It is then re-run three times, with every occurrence of
    the text 'nulls' in the statement replaced by 'nulls1', 'nulls2' and
    'nulls3' in turn (so 'no_nulls' becomes 'no_nulls1', and so on).
    NOTE(review): the suffixed names are presumably the tables served by
    the loaded extension -- confirm against the test database.

    On the first variant whose rows differ from the reference, three files
    are written to the current directory and ValueError is raised:
      testcase-cmds.txt      commands that re-run the failing statement
      testcase-expected.txt  the reference rows, one per line
      testcase-out.txt       the rows the variant returned, one per line
    """
    query = generate_statement(conn, table, column_values, all_values)

    expected = list(conn.execute(query))
    print('{} rows: {}'.format(len(expected), query))

    for variant in ('nulls1', 'nulls2', 'nulls3'):
        variant_query = query.replace('nulls', variant)
        actual = list(conn.execute(variant_query))
        if expected == actual:
            continue

        # Mismatch: leave behind everything needed to reproduce it.
        with open('testcase-cmds.txt', 'w') as cmds:
            cmds.write('.load build/linux/libparquet\n.testcase query\n.bail on\n{};\n.output\n'.format(variant_query))
        with open('testcase-expected.txt', 'w') as out:
            out.writelines('{}\n'.format(row) for row in expected)
        with open('testcase-out.txt', 'w') as out:
            out.writelines('{}\n'.format(row) for row in actual)

        raise ValueError('ruhroh')



def test_table(conn, table):
    """Run 1000 randomly generated statements against ``table``.

    The table's column names and the distinct values of every column are
    collected once up front, then handed to test_statement for each
    iteration. An exception raised by test_statement propagates and ends
    the run.
    """
    columns = get_columns(conn, table)
    print('Table {}: {}'.format(table, columns))
    column_values = {
        column: get_column_values(conn, table, column) for column in columns
    }

    # Uncomment to make the sequence of generated statements repeatable.
    #random.seed(0)

    # Every column's values pooled into one flat list, in column order.
    all_values = [
        value for values in column_values.values() for value in values
    ]
    for _ in range(1000):
        test_statement(conn, table, column_values, all_values)

def test_db(db_file, extension_file, tables):
    """Open ``db_file``, load ``extension_file`` into it and test each table.

    Extension loading is enabled only for the duration of the
    load_extension call and switched off again before any statement runs.

    Args:
        db_file: path of the SQLite database to open.
        extension_file: path of the loadable extension to load.
        tables: iterable of table names, each passed to test_table.

    Raises:
        Whatever test_table or the sqlite3 calls raise; the connection is
        closed before the exception propagates.
    """
    conn = sqlite3.connect(db_file)
    try:
        conn.enable_load_extension(True)
        conn.load_extension(extension_file)
        conn.enable_load_extension(False)
        for table in tables:
            test_table(conn, table)
    finally:
        # Close explicitly: a sqlite3 connection is not closed on its own
        # when a test failure raises out of the loop.
        conn.close()

if __name__ == '__main__':
    # Both paths hang off the parent of the directory that holds this
    # script: <root>/test.db and <root>/build/linux/libparquet.so.
    root = os.path.abspath(os.path.join(__file__, '..', '..'))
    db_file = os.path.join(root, 'test.db')
    extension_file = os.path.join(root, 'build', 'linux', 'libparquet.so')
    test_db(db_file, extension_file, ['nulls', 'no_nulls'])