Diffstat (limited to 'scripts/queue_loader.py')
-rwxr-xr-x  scripts/queue_loader.py  531
1 file changed, 531 insertions(+), 0 deletions(-)
diff --git a/scripts/queue_loader.py b/scripts/queue_loader.py
new file mode 100755
index 00000000..c14ecd01
--- /dev/null
+++ b/scripts/queue_loader.py
@@ -0,0 +1,531 @@
+#!/usr/bin/env python
+
+"""Load data from queue into tables, with optional partitioning."""
+
+import sys, skytools
+
+from pgq.cascade.worker import CascadedWorker
+
+from skytools import quote_ident, quote_fqident, UsageError
+
+# todo: auto table detect
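+# A sketch of a possible config, based on the options read below
+# (job, database and table names are hypothetical):
+#
+#   [queue_loader]
+#   job_name = loader_targetdb
+#   db = dbname=targetdb
+#   queue_name = source_queue
+#   logfile = ~/log/%(job_name)s.log
+#   pidfile = ~/pid/%(job_name)s.pid
+#
+#   # per-table section, named after the table
+#   [public.mytable]
+#   table_mode = split
+#   row_mode = keep_all
+#   split_mode = by-event-time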
+
+# BulkLoader load method
+METH_CORRECT = 0
+METH_DELETE = 1
+METH_MERGED = 2
+LOAD_METHOD = METH_CORRECT
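+# METH_CORRECT applies updates as real UPDATE statements; METH_DELETE
+# rewrites them as DELETE + COPY; METH_MERGED additionally folds fresh
+# inserts into the same DELETE + COPY pass (see BulkLoader.flush).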
+# BulkLoader hacks
+AVOID_BIZGRES_BUG = False
+USE_LONGLIVED_TEMP_TABLES = True
+
+class BasicLoader:
+ """Apply events as-is."""
+ def __init__(self, table_name, parent_name, log):
+ self.table_name = table_name
+ self.parent_name = parent_name
+ self.sql_list = []
+ self.log = log
+
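+    # For reference, a sketch of the statements add_row() builds below
+    # (hypothetical table and columns; the exact quoting is done by the
+    # skytools mk_*_sql helpers):
+    #   I -> insert into public.mytable (id, name) values ('1', 'a');
+    #   U -> update only public.mytable set name = 'a' where id = '1';
+    #   D -> delete from only public.mytable where id = '1';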
+ def add_row(self, op, data, pkey_list):
+ if op == 'I':
+ sql = skytools.mk_insert_sql(data, self.table_name, pkey_list)
+ elif op == 'U':
+ sql = skytools.mk_update_sql(data, self.table_name, pkey_list)
+ elif op == 'D':
+ sql = skytools.mk_delete_sql(data, self.table_name, pkey_list)
+ else:
+ raise Exception('bad operation: '+op)
+ self.sql_list.append(sql)
+
+ def flush(self, curs):
+ if len(self.sql_list) > 0:
+ curs.execute("\n".join(self.sql_list))
+ self.sql_list = []
+
+class KeepLatestLoader(BasicLoader):
+ """Keep latest row version.
+
+ Updates are changed to delete + insert, deletes are ignored.
+ Makes sense only for partitioned tables.
+ """
+ def add_row(self, op, data, pkey_list):
+ if op == 'U':
+ BasicLoader.add_row(self, 'D', data, pkey_list)
+ BasicLoader.add_row(self, 'I', data, pkey_list)
+ elif op == 'I':
+ BasicLoader.add_row(self, 'I', data, pkey_list)
+ else:
+ pass
+
+
+class KeepAllLoader(BasicLoader):
+ """Keep all row versions.
+
+ Updates are changed to inserts, deletes are ignored.
+ Makes sense only for partitioned tables.
+ """
+ def add_row(self, op, data, pkey_list):
+ if op == 'U':
+ op = 'I'
+ elif op == 'D':
+ return
+ BasicLoader.add_row(self, op, data, pkey_list)
+
+
+class BulkEvent(object):
+ """Helper class for BulkLoader to store relevant data."""
+ __slots__ = ('op', 'data', 'pk_data')
+ def __init__(self, op, data, pk_data):
+ self.op = op
+ self.data = data
+ self.pk_data = pk_data
+
+class BulkLoader(BasicLoader):
+ """Instead of statement-per event, load all data with one
+ big COPY, UPDATE or DELETE statement.
+ """
+ fake_seq = 0
+ def __init__(self, table_name, parent_name, log):
+ """Init per-batch table data cache."""
+ BasicLoader.__init__(self, table_name, parent_name, log)
+
+ self.pkey_list = None
+ self.dist_fields = None
+ self.col_list = None
+
+ self.ev_list = []
+ self.pkey_ev_map = {}
+
+ def reset(self):
+ self.ev_list = []
+ self.pkey_ev_map = {}
+
+ def add_row(self, op, data, pkey_list):
+ """Store new event."""
+
+ # get pkey value
+ if self.pkey_list is None:
+ self.pkey_list = pkey_list
+ if len(self.pkey_list) > 0:
+            # must be hashable - used as pkey_ev_map key below
+            pk_data = tuple(data[k] for k in self.pkey_list)
+ elif op == 'I':
+ # fake pkey, just to get them spread out
+ pk_data = self.fake_seq
+ self.fake_seq += 1
+ else:
+ raise Exception('non-pk tables not supported: %s' % self.table_name)
+
+ # get full column list, detect added columns
+ if not self.col_list:
+ self.col_list = data.keys()
+        elif self.col_list != data.keys():
+            # CPython yields keys() in the same order for dicts with an
+            # identical key set, so a difference means changed columns
+            self.col_list = data.keys()
+
+ # add to list
+ ev = BulkEvent(op, data, pk_data)
+ self.ev_list.append(ev)
+
+ # keep all versions of row data
+ if ev.pk_data in self.pkey_ev_map:
+ self.pkey_ev_map[ev.pk_data].append(ev)
+ else:
+ self.pkey_ev_map[ev.pk_data] = [ev]
+
+ def prepare_data(self):
+ """Got all data, prepare for insertion."""
+
+ del_list = []
+ ins_list = []
+ upd_list = []
+ for ev_list in self.pkey_ev_map.values():
+ # rewrite list of I/U/D events to
+ # optional DELETE and optional INSERT/COPY command
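+            # e.g. the sequence U,D for one pkey folds into a single
+            # DELETE; I,...,D folds into nothing (short-lived row)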
+ exists_before = -1
+ exists_after = 1
+ for ev in ev_list:
+ if ev.op == "I":
+ if exists_before < 0:
+ exists_before = 0
+ exists_after = 1
+ elif ev.op == "U":
+ if exists_before < 0:
+ exists_before = 1
+                    #exists_after = 1  # should not be needed
+ elif ev.op == "D":
+ if exists_before < 0:
+ exists_before = 1
+ exists_after = 0
+ else:
+ raise Exception('unknown event type: %s' % ev.op)
+
+ # skip short-lived rows
+ if exists_before == 0 and exists_after == 0:
+ continue
+
+ # take last event
+ ev = ev_list[-1]
+
+ # generate needed commands
+ if exists_before and exists_after:
+ upd_list.append(ev.data)
+ elif exists_before:
+ del_list.append(ev.data)
+ elif exists_after:
+ ins_list.append(ev.data)
+
+ return ins_list, upd_list, del_list
+
+ def flush(self, curs):
+ ins_list, upd_list, del_list = self.prepare_data()
+
+ # reorder cols
+ col_list = self.pkey_list[:]
+ for k in self.col_list:
+ if k not in self.pkey_list:
+ col_list.append(k)
+
+ real_update_count = len(upd_list)
+
+ #self.log.debug("process_one_table: %s (I/U/D = %d/%d/%d)" % (
+ # tbl, len(ins_list), len(upd_list), len(del_list)))
+
+        # merged load method: route inserts through the delete+copy pass
+ if LOAD_METHOD == METH_MERGED:
+ upd_list += ins_list
+ ins_list = []
+
+ # fetch distribution fields
+ if self.dist_fields is None:
+ self.dist_fields = self.find_dist_fields(curs)
+
+ key_fields = self.pkey_list[:]
+ for fld in self.dist_fields:
+ if fld not in key_fields:
+ key_fields.append(fld)
+ #self.log.debug("PKey fields: %s Extra fields: %s" % (
+ # ",".join(cache.pkey_list), ",".join(extra_fields)))
+
+ # create temp table
+ temp = self.create_temp_table(curs)
+ tbl = self.table_name
+
+ # where expr must have pkey and dist fields
+ klist = []
+ for pk in key_fields:
+ exp = "%s.%s = %s.%s" % (quote_fqident(tbl), quote_ident(pk),
+ quote_fqident(temp), quote_ident(pk))
+ klist.append(exp)
+ whe_expr = " and ".join(klist)
+
+ # create del sql
+ del_sql = "delete from only %s using %s where %s" % (
+ quote_fqident(tbl), quote_fqident(temp), whe_expr)
+
+ # create update sql
+ slist = []
+ for col in col_list:
+ if col not in key_fields:
+ exp = "%s = %s.%s" % (quote_ident(col), quote_fqident(temp), quote_ident(col))
+ slist.append(exp)
+ upd_sql = "update only %s set %s from %s where %s" % (
+ quote_fqident(tbl), ", ".join(slist), quote_fqident(temp), whe_expr)
+
+ # insert sql
+ colstr = ",".join([quote_ident(c) for c in col_list])
+ ins_sql = "insert into %s (%s) select %s from %s" % (
+ quote_fqident(tbl), colstr, colstr, quote_fqident(temp))
+
+ temp_used = False
+
+ # process deleted rows
+ if len(del_list) > 0:
+ #self.log.info("Deleting %d rows from %s" % (len(del_list), tbl))
+ # delete old rows
+ q = "truncate %s" % quote_fqident(temp)
+ self.log.debug(q)
+ curs.execute(q)
+ # copy rows
+ self.log.debug("COPY %d rows into %s" % (len(del_list), temp))
+ skytools.magic_insert(curs, temp, del_list, col_list)
+ # delete rows
+ self.log.debug(del_sql)
+ curs.execute(del_sql)
+ self.log.debug("%s - %d" % (curs.statusmessage, curs.rowcount))
+ self.log.debug(curs.statusmessage)
+ if len(del_list) != curs.rowcount:
+ self.log.warning("Delete mismatch: expected=%s updated=%d"
+ % (len(del_list), curs.rowcount))
+ temp_used = True
+
+ # process updated rows
+ if len(upd_list) > 0:
+ #self.log.info("Updating %d rows in %s" % (len(upd_list), tbl))
+ # delete old rows
+ q = "truncate %s" % quote_fqident(temp)
+ self.log.debug(q)
+ curs.execute(q)
+ # copy rows
+ self.log.debug("COPY %d rows into %s" % (len(upd_list), temp))
+ skytools.magic_insert(curs, temp, upd_list, col_list)
+ temp_used = True
+ if LOAD_METHOD == METH_CORRECT:
+ # update main table
+ self.log.debug(upd_sql)
+ curs.execute(upd_sql)
+ self.log.debug(curs.statusmessage)
+ # check count
+ if len(upd_list) != curs.rowcount:
+ self.log.warning("Update mismatch: expected=%s updated=%d"
+ % (len(upd_list), curs.rowcount))
+ else:
+ # delete from main table
+ self.log.debug(del_sql)
+ curs.execute(del_sql)
+ self.log.debug(curs.statusmessage)
+ # check count
+ if real_update_count != curs.rowcount:
+ self.log.warning("Update mismatch: expected=%s deleted=%d"
+ % (real_update_count, curs.rowcount))
+ # insert into main table
+ if AVOID_BIZGRES_BUG:
+ # copy again, into main table
+ self.log.debug("COPY %d rows into %s" % (len(upd_list), tbl))
+ skytools.magic_insert(curs, tbl, upd_list, col_list)
+ else:
+                    # better way, but does not work due to a bizgres bug
+ self.log.debug(ins_sql)
+ curs.execute(ins_sql)
+ self.log.debug(curs.statusmessage)
+
+ # process new rows
+ if len(ins_list) > 0:
+ self.log.info("Inserting %d rows into %s" % (len(ins_list), tbl))
+ skytools.magic_insert(curs, tbl, ins_list, col_list)
+
+ # delete remaining rows
+ if temp_used:
+ if USE_LONGLIVED_TEMP_TABLES:
+ q = "truncate %s" % quote_fqident(temp)
+ else:
+                # avoids problems seen with long-lived temp tables
+ q = "drop table %s" % quote_fqident(temp)
+ self.log.debug(q)
+ curs.execute(q)
+
+ self.reset()
+
+ def create_temp_table(self, curs):
+ # create temp table for loading
+ tempname = self.table_name.replace('.', '_') + "_loadertmp"
+
+ # check if exists
+ if USE_LONGLIVED_TEMP_TABLES:
+ if skytools.exists_temp_table(curs, tempname):
+ self.log.debug("Using existing temp table %s" % tempname)
+ return tempname
+
+        # bizgres crashes on "on commit delete rows", so preserve rows
+        arg = "on commit preserve rows"
+ # create temp table for loading
+ q = "create temp table %s (like %s) %s" % (
+ quote_fqident(tempname), quote_fqident(self.table_name), arg)
+ self.log.debug("Creating temp table: %s" % q)
+ curs.execute(q)
+ return tempname
+
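+    # Distribution columns exist on Greenplum/Bizgres only; including
+    # them in the temp-table join presumably lets it run segment-local.
+    # On plain PostgreSQL the catalog table is absent and this returns [].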
+ def find_dist_fields(self, curs):
+ if not skytools.exists_table(curs, "pg_catalog.mpp_distribution_policy"):
+ return []
+ schema, name = skytools.fq_name_parts(self.table_name)
+ q = "select a.attname"\
+ " from pg_class t, pg_namespace n, pg_attribute a,"\
+ " mpp_distribution_policy p"\
+ " where n.oid = t.relnamespace"\
+ " and p.localoid = t.oid"\
+ " and a.attrelid = t.oid"\
+ " and a.attnum = any(p.attrnums)"\
+ " and n.nspname = %s and t.relname = %s"
+ curs.execute(q, [schema, name])
+ res = []
+ for row in curs.fetchall():
+ res.append(row[0])
+ return res
+
+
+class TableHandler:
+ """Basic partitioned loader.
+ Splits events into partitions, if requested.
+ Then applies them without further processing.
+ """
+ def __init__(self, rowhandler, table_name, table_mode, cf, log):
+ self.part_map = {}
+ self.rowhandler = rowhandler
+ self.table_name = table_name
+ self.quoted_name = quote_fqident(table_name)
+ self.log = log
+ if table_mode == 'direct':
+ self.split = False
+ elif table_mode == 'split':
+ self.split = True
+ smode = cf.get('split_mode', 'by-batch-time')
+ sfield = None
+ if smode.find(':') > 0:
+ smode, sfield = smode.split(':', 1)
+ self.split_field = sfield
+ self.split_part = cf.get('split_part', '%(table_name)s_%(year)s_%(month)s_%(day)s')
+ self.split_part_template = cf.get('split_part_template', '')
+ if smode == 'by-batch-time':
+ self.split_format = self.split_date_from_batch
+ elif smode == 'by-event-time':
+ self.split_format = self.split_date_from_event
+ elif smode == 'by-date-field':
+ self.split_format = self.split_date_from_field
+ else:
+ raise UsageError('Bad value for split_mode: '+smode)
+ self.log.debug("%s: split_mode=%s, split_field=%s, split_part=%s" % (
+ self.table_name, smode, self.split_field, self.split_part))
+ elif table_mode == 'ignore':
+ pass
+ else:
+ raise UsageError('Bad value for table_mode: '+table_mode)
+
+ def split_date_from_batch(self, ev, data, batch_info):
+ d = batch_info['batch_end']
+ vals = {
+ 'table_name': self.table_name,
+ 'year': "%04d" % d.year,
+ 'month': "%02d" % d.month,
+ 'day': "%02d" % d.day,
+ 'hour': "%02d" % d.hour,
+ }
+ dst = self.split_part % vals
+ return dst
+
+ def split_date_from_event(self, ev, data, batch_info):
+ d = ev.ev_date
+ vals = {
+ 'table_name': self.table_name,
+ 'year': "%04d" % d.year,
+ 'month': "%02d" % d.month,
+ 'day': "%02d" % d.day,
+ 'hour': "%02d" % d.hour,
+ }
+ dst = self.split_part % vals
+ return dst
+
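+    # Example: a field value of '2011-03-07 12:45:00' with the default
+    # split_part '%(table_name)s_%(year)s_%(month)s_%(day)s' maps table
+    # 'public.log' to partition 'public.log_2011_03_07'.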
+ def split_date_from_field(self, ev, data, batch_info):
+ val = data[self.split_field]
+        date_str, time_str = val.split(' ', 1)
+        y, m, d = date_str.split('-')
+        h, rest = time_str.split(':', 1)
+ vals = {
+ 'table_name': self.table_name,
+ 'year': y,
+ 'month': m,
+ 'day': d,
+ 'hour': h,
+ }
+ dst = self.split_part % vals
+ return dst
+
+ def add(self, curs, ev, batch_info):
+ data = skytools.db_urldecode(ev.data)
+ op, pkeys = ev.type.split(':', 1)
+ pkey_list = pkeys.split(',')
+ if self.split:
+ dst = self.split_format(ev, data, batch_info)
+ if dst not in self.part_map:
+ self.check_part(curs, dst, pkey_list)
+ else:
+ dst = self.table_name
+
+ if dst not in self.part_map:
+ self.part_map[dst] = self.rowhandler(dst, self.table_name, self.log)
+
+ p = self.part_map[dst]
+ p.add_row(op, data, pkey_list)
+
+ def flush(self, curs):
+ for part in self.part_map.values():
+ part.flush(curs)
+
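+    # A sketch of a possible split_part_template (ConfigParser doubles
+    # the percent signs; the actual DDL is deployment-specific):
+    #   split_part_template =
+    #       create table %%(part)s (like %%(parent)s);
+    #       alter table only %%(part)s add primary key (%%(pkey)s);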
+ def check_part(self, curs, dst, pkey_list):
+ if skytools.exists_table(curs, dst):
+ return
+ if not self.split_part_template:
+ raise UsageError('Partition %s does not exist and split_part_template not specified' % dst)
+
+ vals = {
+ 'dest': quote_fqident(dst),
+ 'part': quote_fqident(dst),
+ 'parent': quote_fqident(self.table_name),
+ 'pkey': ",".join(pkey_list), # quoting?
+ }
+ sql = self.split_part_template % vals
+ curs.execute(sql)
+
+
+class IgnoreTable(TableHandler):
+ """Do-nothing."""
+ def add(self, curs, ev, batch_info):
+ pass
+
+
+class QueueLoader(CascadedWorker):
+ """Loader script."""
+ table_state = {}
+
+ def reset(self):
+ """Drop our caches on error."""
+ self.table_state = {}
+ CascadedWorker.reset(self)
+
+ def init_state(self, tbl):
+ cf = self.cf
+ if tbl in cf.cf.sections():
+ cf = cf.clone(tbl)
+ table_mode = cf.get('table_mode', 'ignore')
+ row_mode = cf.get('row_mode', 'plain')
+ if table_mode == 'ignore':
+ tblhandler = IgnoreTable
+ else:
+ tblhandler = TableHandler
+
+ if row_mode == 'plain':
+ rowhandler = BasicLoader
+ elif row_mode == 'keep_latest':
+ rowhandler = KeepLatestLoader
+ elif row_mode == 'keep_all':
+ rowhandler = KeepAllLoader
+ elif row_mode == 'bulk':
+ rowhandler = BulkLoader
+ else:
+ raise UsageError('Bad row_mode: '+row_mode)
+ self.table_state[tbl] = tblhandler(rowhandler, tbl, table_mode, cf, self.log)
+
+ def process_remote_event(self, src_curs, dst_curs, ev):
+ t = ev.type[:2]
+ if t not in ('I:', 'U:', 'D:'):
+ CascadedWorker.process_remote_event(self, src_curs, dst_curs, ev)
+ return
+
+ tbl = ev.extra1
+ if tbl not in self.table_state:
+ self.init_state(tbl)
+ st = self.table_state[tbl]
+ st.add(dst_curs, ev, self._batch_info)
+ ev.tag_done()
+
+ def finish_remote_batch(self, src_db, dst_db, tick_id):
+ curs = dst_db.cursor()
+ for st in self.table_state.values():
+ st.flush(curs)
+ CascadedWorker.finish_remote_batch(self, src_db, dst_db, tick_id)
+
+if __name__ == '__main__':
+ script = QueueLoader('queue_loader', 'db', sys.argv[1:])
+ script.start()
+