sqlite-parquet-vtable/parquet-generator/parquets.py

154 lines
5.1 KiB
Python

from datetime import datetime, timedelta, timezone
import re
import pyarrow as pa
import pyarrow.parquet as pq
def make_99_rows():
    '''Build the 99-row fixture, one list per row, exercising all the main
    physical types.

    Every column is derived from the row index `i` (0..98). Column order:
    BOOLEAN, INT32/INT8, INT32/INT16, INT32/INT32, INT64/INT64, INT96
    (timezone-aware UTC datetime), DOUBLE, two BYTE_ARRAY/UTF8 strings,
    BYTE_ARRAY (1 to 5 bytes) and FIXED_LENGTH_BYTE_ARRAY (1 byte).
    '''
    first_day = datetime(1985, 7, 20, tzinfo=timezone.utc)
    million = 1000 * 1000
    rows = []
    for i in range(99):
        # Counts down from 50 so the integer columns cover both signs.
        centered = 50 - i
        single_byte = bytes([i])
        rows.append([
            i % 2 == 0,                        # BOOLEAN
            centered,                          # INT32/INT8
            100 * centered,                    # INT32/INT16
            million * centered,                # INT32/INT32
            million * 1000 * centered,         # INT64/INT64
            first_day + timedelta(days=i),     # INT96
            99.0 / (i + 1),                    # DOUBLE
            str(i),                            # BYTE_ARRAY/UTF8
            '{:03}'.format(i),                 # BYTE_ARRAY/UTF8
            single_byte * (1 + i % 5),         # BYTE_ARRAY
            single_byte,                       # FIXED_LENGTH_BYTE_ARRAY
        ])
        # No FLOAT column: the original author noted that pyarrow did not
        # support float at the time this generator was written.
    return rows
def get_99_rows_types():
    '''Return the pyarrow types for the columns of `make_99_rows`, in column
    order.'''
    integer_types = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
    string_types = [pa.string(), pa.string()]
    # binary(-1) is the variable-length column, binary(1) the 1-byte one.
    binary_types = [pa.binary(-1), pa.binary(1)]
    # A trailing pa.float32() column is deliberately left out, matching the
    # FLOAT column that make_99_rows does not emit.
    return (
        [pa.bool_()]
        + integer_types
        + [pa.timestamp('ns'), pa.float64()]
        + string_types
        + binary_types
    )
def write_parquet(file_name, rows, types, row_group_size):
    '''Write `rows` to a single parquet file at `file_name`.

    file_name: path of the parquet file to create.
    rows: row-major data; each row is a sequence with one value per entry
        in `types`. May be empty, which yields zero-length columns.
    types: pyarrow data types, one per column.
    row_group_size: passed through to `pq.write_table`.

    Column names are '<type>_<index>', with 'timestamp[ns]' shortened to
    'ts' and 'fixed_size_binary[1]' to 'binary'.
    '''
    # pivot to be column major, create arrow structures
    fields = [
        pa.chunked_array([[row[i] for row in rows]], type=column_type)
        for i, column_type in enumerate(types)
    ]

    def name_of(i):
        name = '{}_{}'.format(types[i], i)
        name = name.replace('timestamp[ns]', 'ts')
        name = name.replace('fixed_size_binary[1]', 'binary')
        return name

    names = [name_of(i) for i in range(len(fields))]
    if hasattr(pa, 'Column'):
        # Older pyarrow: keep the original Column-based construction.
        cols = [pa.Column.from_array(name, field)
                for name, field in zip(names, fields)]
        table = pa.Table.from_arrays(cols)
    else:
        # pa.Column no longer exists in newer pyarrow releases; build the
        # table from the chunked arrays plus an explicit list of names.
        table = pa.Table.from_arrays(fields, names=names)

    print('Writing {}'.format(file_name))
    pq.write_table(table,
                   file_name,
                   row_group_size=row_group_size,
                   use_deprecated_int96_timestamps=True)
def write_unsupported_parquets():
    '''Write one empty, single-column parquet file for each Arrow type in
    the list below, named 'unsupported-<type>.parquet' after sanitizing.'''
    # Taken from https://arrow.apache.org/docs/python/api.html
    unsupported = [
        pa.decimal128(10),
        pa.null(),
        pa.uint8(),
        pa.uint16(),
        # per https://issues.apache.org/jira/browse/ARROW-436, I think
        # Parquet v1.0 can't serialize UINT32
        # pa.uint32(),
        pa.uint64(),
        # pa.float16() <-- not supported by us, but also not by pyarrow
        # TODO: list_, struct, dict
    ]
    for arrow_type in unsupported:
        raw_name = 'unsupported-{}.parquet'.format(arrow_type)
        # Anything outside [0-9a-z.-] becomes '-', then runs of '-' collapse.
        dashed_name = re.sub(r'[^0-9a-z.-]', '-', raw_name)
        target_name = re.sub(r'--*', '-', dashed_name)
        write_parquet(target_name, [], [arrow_type], row_group_size=1)
def write_csv(file_name, rows):
    r'''Write a TSV that can be imported to Postgres.

    Use "\N" for NULLs, tab literal for field separator.

    file_name: path of the text file to create (overwritten if present).
    rows: iterable of rows; each row is an iterable of column values.

    Each output line starts with the 0-based row index, followed by one
    field per column:
      - None     -> \N
      - bool     -> 1 or 0
      - bytes    -> X'<uppercase hex>'
      - datetime -> whole seconds since the epoch, multiplied by 1000
      - others   -> str(value)
    '''
    print('Writing {}'.format(file_name))
    with open(file_name, 'w') as f:
        for rowid, row in enumerate(rows):
            fields = [str(rowid)]
            for col in row:
                if col is None:
                    fields.append(r'\N')
                elif isinstance(col, bool):
                    # Match on type rather than `== True` / `== False`:
                    # 1.0 == True and 0.0 == False in Python, which would
                    # write float columns holding 1.0 or 0.0 as '1' / '0'.
                    fields.append('1' if col else '0')
                elif isinstance(col, bytes):
                    # Here we cheat and serialize a string that matches the
                    # output of quote(binary_field) in SQLite
                    hex_digits = ''.join('%0.2X' % b for b in col)
                    fields.append("X'" + hex_digits + "'")
                elif isinstance(col, datetime):
                    fields.append(str(1000 * int(col.timestamp())))
                else:
                    fields.append(str(col))
            f.write('\t'.join(fields) + '\n')
def main():
    '''Entrypoint: generate every parquet and CSV fixture in the current
    working directory.'''
    types = get_99_rows_types()
    rows = make_99_rows()

    # Fully populated data set.
    write_parquet('99-rows-1.parquet', rows, types, row_group_size=99)
    write_csv('no-nulls.csv', rows)
    write_parquet('99-rows-10.parquet', rows, types, row_group_size=10)
    write_parquet('99-rows-99.parquet', rows, types, row_group_size=1)

    # Null out, in place, every field of rows 10..19 and, from row 20
    # onwards, every field whose (row index + column index) is even.
    for i, row in enumerate(rows):
        for j in range(len(row)):
            if 10 <= i <= 19 or (i >= 20 and (i + j) % 2 == 0):
                row[j] = None

    # Same shape of data, now containing NULLs.
    write_parquet('99-rows-nulls-99.parquet', rows, types, row_group_size=99)
    write_parquet('99-rows-nulls-10.parquet', rows, types, row_group_size=10)
    write_parquet('99-rows-nulls-1.parquet', rows, types, row_group_size=1)
    write_csv('nulls.csv', rows)

    write_unsupported_parquets()


# Only generate the fixtures when run as a script, not when imported.
if __name__ == '__main__':
    main()