Diffstat (limited to 'python')
33 files changed, 5690 insertions, 0 deletions
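This patch adds the first Python layer of Skytools: sample configuration files, the londiste.py launcher, the londiste replication package (playback, setup, copy, compare, repair) and the pgq consumer framework. For orientation only, the launcher's usage string implies a workflow roughly like the sketch below; the sketch is not part of the patch, public.mytable is a placeholder table name, and daemon/option handling comes from skytools.DBScript, which is outside this diff.

    # install schema and create the event queue on the provider,
    # install the subscriber-side schema
    ./londiste.py conf/londiste.ini provider install
    ./londiste.py conf/londiste.ini subscriber install

    # attach a table on both sides
    ./londiste.py conf/londiste.ini provider add public.mytable
    ./londiste.py conf/londiste.ini subscriber add public.mytable

    # replay queued events to the subscriber; the replay process itself
    # launches the internal "copy" command for newly added tables
    ./londiste.py conf/londiste.ini replay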
diff --git a/python/conf/londiste.ini b/python/conf/londiste.ini new file mode 100644 index 00000000..a1506a32 --- /dev/null +++ b/python/conf/londiste.ini @@ -0,0 +1,16 @@ + +[londiste] +job_name = test_to_subcriber + +provider_db = dbname=provider port=6000 host=127.0.0.1 +subscriber_db = dbname=subscriber port=6000 host=127.0.0.1 + +# it will be used as sql ident so no dots/spaces +pgq_queue_name = londiste.replika + +logfile = ~/log/%(job_name)s.log +pidfile = ~/pid/%(job_name)s.pid + +# both events and ticks will be copied there +#mirror_queue = replika_mirror + diff --git a/python/conf/pgqadm.ini b/python/conf/pgqadm.ini new file mode 100644 index 00000000..a2e92f6b --- /dev/null +++ b/python/conf/pgqadm.ini @@ -0,0 +1,18 @@ + +[pgqadm] + +job_name = pgqadm_somedb + +db = dbname=provider port=6000 host=127.0.0.1 + +# how often to run maintenance [minutes] +maint_delay_min = 5 + +# how often to check for activity [secs] +loop_delay = 0.1 + +logfile = ~/log/%(job_name)s.log +pidfile = ~/pid/%(job_name)s.pid + +use_skylog = 0 + diff --git a/python/conf/skylog.ini b/python/conf/skylog.ini new file mode 100644 index 00000000..150ef934 --- /dev/null +++ b/python/conf/skylog.ini @@ -0,0 +1,76 @@ +; notes: +; - 'args' is mandatory in [handler_*] sections +; - in lists there must not be spaces + +; +; top-level config +; + +; list of all loggers +[loggers] +keys=root +; root logger sees everything. there can be per-job configs by +; specifing loggers with job_name of the script + +; list of all handlers +[handlers] +;; seems logger module immidiately initalized all handlers, +;; whether they are actually used or not. so better +;; keep this list in sync with actual handler list +;keys=stderr,logdb,logsrv,logfile +keys=stderr + +; list of all formatters +[formatters] +keys=short,long,none + +; +; map specific loggers to specifig handlers +; +[logger_root] +level=DEBUG +;handlers=stderr,logdb,logsrv,logfile +handlers=stderr + +; +; configure formatters +; +[formatter_short] +format=%(asctime)s %(levelname)s %(message)s +datefmt=%H:%M + +[formatter_long] +format=%(asctime)s %(process)s %(levelname)s %(message)s + +[formatter_none] +format=%(message)s + +; +; configure handlers +; + +; file. args: stream +[handler_stderr] +class=StreamHandler +args=(sys.stderr,) +formatter=short + +; log into db. args: conn_string +[handler_logdb] +class=skylog.LogDBHandler +args=("host=127.0.0.1 port=5432 user=logger dbname=logdb",) +formatter=none +level=INFO + +; JSON messages over UDP. args: host, port +[handler_logsrv] +class=skylog.UdpLogServerHandler +args=('127.0.0.1', 6666) +formatter=none + +; rotating logfile. 
args: filename, maxsize, maxcount +[handler_logfile] +class=skylog.EasyRotatingFileHandler +args=('~/log/%(job_name)s.log', 100*1024*1024, 3) +formatter=long + diff --git a/python/conf/wal-master.ini b/python/conf/wal-master.ini new file mode 100644 index 00000000..5ae8cb2b --- /dev/null +++ b/python/conf/wal-master.ini @@ -0,0 +1,18 @@ +[wal-master] +logfile = master.log +use_skylog = 0 + +master_db = dbname=template1 +master_data = /var/lib/postgresql/8.0/main +master_config = /etc/postgresql/8.0/main/postgresql.conf + + +slave = slave:/var/lib/postgresql/walshipping + +completed_wals = %(slave)s/logs.complete +partial_wals = %(slave)s/logs.partial +full_backup = %(slave)s/data.master + +# syncdaemon update frequency +loop_delay = 10.0 + diff --git a/python/conf/wal-slave.ini b/python/conf/wal-slave.ini new file mode 100644 index 00000000..912bf756 --- /dev/null +++ b/python/conf/wal-slave.ini @@ -0,0 +1,15 @@ +[wal-slave] +logfile = slave.log +use_skylog = 0 + +slave_data = /var/lib/postgresql/8.0/main +slave_stop_cmd = /etc/init.d/postgresql-8.0 stop +slave_start_cmd = /etc/init.d/postgresql-8.0 start + +slave = /var/lib/postgresql/walshipping +completed_wals = %(slave)s/logs.complete +partial_wals = %(slave)s/logs.partial +full_backup = %(slave)s/data.master + +keep_old_logs = 0 + diff --git a/python/londiste.py b/python/londiste.py new file mode 100755 index 00000000..9e5684ea --- /dev/null +++ b/python/londiste.py @@ -0,0 +1,130 @@ +#! /usr/bin/env python + +"""Londiste launcher. +""" + +import sys, os, optparse, skytools + +# python 2.3 will try londiste.py first... +import sys, os.path +if os.path.exists(os.path.join(sys.path[0], 'londiste.py')) \ + and not os.path.exists(os.path.join(sys.path[0], 'londiste')): + del sys.path[0] + +from londiste import * + +command_usage = """ +%prog [options] INI CMD [subcmd args] + +commands: + provider install installs modules, creates queue + provider add TBL ... add table to queue + provider remove TBL ... remove table from queue + provider tables show all tables linked to queue + provider seqs show all sequences on provider + + subscriber install installs schema + subscriber add TBL ... add table to subscriber + subscriber remove TBL ... remove table from subscriber + subscriber tables list tables subscriber has attached to + subscriber seqs list sequences subscriber is interested + subscriber missing list tables subscriber has not yet attached to + subscriber link QUE create mirror queue + subscriber unlink dont mirror queue + + replay replay events to subscriber + + switchover switch the roles between provider & subscriber + compare [TBL ...] compare table contents on both sides + repair [TBL ...] repair data on subscriber + copy full copy of table, internal cmd + subscriber check compare table structure on both sides + subscriber fkeys print out fkey drop/create commands + subscriber resync TBL ... do full copy again + subscriber register attaches subscriber to queue (also done by replay) + subscriber unregister detach subscriber from queue +""" + +"""switchover: +goal is to launch a replay with reverse config + +- should link be required? (link should guarantee all tables/seqs same?) +- should link auto-add tables for subscriber + +1. lock all tables on provider, in order specified by 'nr' +2. wait until old replay is past the point +3. sync seq +4. replace queue triggers on provider with deny triggers +5. replace deny triggers on subscriber with queue triggers +6. sync pgq tick seqs? change pgq config? 
+ +""" + +class Londiste(skytools.DBScript): + def __init__(self, args): + skytools.DBScript.__init__(self, 'londiste', args) + + if self.options.rewind or self.options.reset: + self.script = Replicator(args) + return + + if len(self.args) < 2: + print "need command" + sys.exit(1) + cmd = self.args[1] + + if cmd =="provider": + script = ProviderSetup(args) + elif cmd == "subscriber": + script = SubscriberSetup(args) + elif cmd == "replay": + method = self.cf.get('method', 'direct') + if method == 'direct': + script = Replicator(args) + elif method == 'file_write': + script = FileWrite(args) + elif method == 'file_write': + script = FileWrite(args) + else: + print "unknown method, quitting" + sys.exit(1) + elif cmd == "copy": + script = CopyTable(args) + elif cmd == "compare": + script = Comparator(args) + elif cmd == "repair": + script = Repairer(args) + elif cmd == "upgrade": + script = UpgradeV2(args) + else: + print "Unknown command '%s', use --help for help" % cmd + sys.exit(1) + + self.script = script + + def start(self): + self.script.start() + + def init_optparse(self, parser=None): + p = skytools.DBScript.init_optparse(self, parser) + p.set_usage(command_usage.strip()) + + g = optparse.OptionGroup(p, "expert options") + g.add_option("--force", action="store_true", + help = "add: ignore table differences, repair: ignore lag") + g.add_option("--expect-sync", action="store_true", dest="expect_sync", + help = "add: no copy needed", default=False) + g.add_option("--skip-truncate", action="store_true", dest="skip_truncate", + help = "copy: keep old data", default=False) + g.add_option("--rewind", action="store_true", + help = "replay: sync queue pos with subscriber") + g.add_option("--reset", action="store_true", + help = "replay: forget queue pos on subscriber") + p.add_option_group(g) + + return p + +if __name__ == '__main__': + script = Londiste(sys.argv[1:]) + script.start() + diff --git a/python/londiste/__init__.py b/python/londiste/__init__.py new file mode 100644 index 00000000..97d67433 --- /dev/null +++ b/python/londiste/__init__.py @@ -0,0 +1,12 @@ + +"""Replication on top of PgQ.""" + +from playback import * +from compare import * +from file_read import * +from file_write import * +from setup import * +from table_copy import * +from installer import * +from repair import * + diff --git a/python/londiste/compare.py b/python/londiste/compare.py new file mode 100644 index 00000000..0029665b --- /dev/null +++ b/python/londiste/compare.py @@ -0,0 +1,45 @@ +#! /usr/bin/env python + +"""Compares tables in replication set. + +Currently just does count(1) on both sides. +""" + +import sys, os, time, skytools + +__all__ = ['Comparator'] + +from syncer import Syncer + +class Comparator(Syncer): + def process_sync(self, tbl, src_db, dst_db): + """Actual comparision.""" + + src_curs = src_db.cursor() + dst_curs = dst_db.cursor() + + self.log.info('Counting %s' % tbl) + + q = "select count(1) from only _TABLE_" + q = self.cf.get('compare_sql', q) + q = q.replace('_TABLE_', tbl) + + self.log.debug("srcdb: " + q) + src_curs.execute(q) + src_row = src_curs.fetchone() + src_str = ", ".join(map(str, src_row)) + self.log.info("srcdb: res = %s" % src_str) + + self.log.debug("dstdb: " + q) + dst_curs.execute(q) + dst_row = dst_curs.fetchone() + dst_str = ", ".join(map(str, dst_row)) + self.log.info("dstdb: res = %s" % dst_str) + + if src_str != dst_str: + self.log.warning("%s: Results do not match!" 
% tbl) + +if __name__ == '__main__': + script = Comparator(sys.argv[1:]) + script.start() + diff --git a/python/londiste/file_read.py b/python/londiste/file_read.py new file mode 100644 index 00000000..2902bda5 --- /dev/null +++ b/python/londiste/file_read.py @@ -0,0 +1,52 @@ + +"""Reads events from file instead of db queue.""" + +import sys, os, re, skytools + +from playback import * +from table_copy import * + +__all__ = ['FileRead'] + +file_regex = r"^tick_0*([0-9]+)\.sql$" +file_rc = re.compile(file_regex) + + +class FileRead(CopyTable): + """Reads events from file instead of db queue. + + Incomplete implementation. + """ + + def __init__(self, args, log = None): + CopyTable.__init__(self, args, log, copy_thread = 0) + + def launch_copy(self, tbl): + # copy immidiately + self.do_copy(t) + + def work(self): + last_batch = self.get_last_batch(curs) + list = self.get_file_list() + + def get_list(self): + """Return list of (first_batch, full_filename) pairs.""" + + src_dir = self.cf.get('file_src') + list = os.listdir(src_dir) + list.sort() + res = [] + for fn in list: + m = file_rc.match(fn) + if not m: + self.log.debug("Ignoring file: %s" % fn) + continue + full = os.path.join(src_dir, fn) + batch_id = int(m.group(1)) + res.append((batch_id, full)) + return res + +if __name__ == '__main__': + script = Replicator(sys.argv[1:]) + script.start() + diff --git a/python/londiste/file_write.py b/python/londiste/file_write.py new file mode 100644 index 00000000..86e16aae --- /dev/null +++ b/python/londiste/file_write.py @@ -0,0 +1,67 @@ + +"""Writes events into file.""" + +import sys, os, skytools +from cStringIO import StringIO +from playback import * + +__all__ = ['FileWrite'] + +class FileWrite(Replicator): + """Writes events into file. + + Incomplete implementation. + """ + + last_successful_batch = None + + def load_state(self, batch_id): + # maybe check if batch exists on filesystem? 
+ self.cur_tick = self.cur_batch_info['tick_id'] + self.prev_tick = self.cur_batch_info['prev_tick_id'] + return 1 + + def process_batch(self, db, batch_id, ev_list): + pass + + def save_state(self, do_commit): + # nothing to save + pass + + def sync_tables(self, dst_db): + # nothing to sync + return 1 + + def interesting(self, ev): + # wants all of them + return 1 + + def handle_data_event(self, ev): + fmt = self.sql_command[ev.type] + sql = fmt % (ev.ev_extra1, ev.data) + row = "%s -- txid:%d" % (sql, ev.txid) + self.sql_list.append(row) + ev.tag_done() + + def handle_system_event(self, ev): + row = "-- sysevent:%s txid:%d data:%s" % ( + ev.type, ev.txid, ev.data) + self.sql_list.append(row) + ev.tag_done() + + def flush_sql(self): + self.sql_list.insert(0, "-- tick:%d prev:%s" % ( + self.cur_tick, self.prev_tick)) + self.sql_list.append("-- end_tick:%d\n" % self.cur_tick) + # store result + dir = self.cf.get("file_dst") + fn = os.path.join(dir, "tick_%010d.sql" % self.cur_tick) + f = open(fn, "w") + buf = "\n".join(self.sql_list) + f.write(buf) + f.close() + +if __name__ == '__main__': + script = Replicator(sys.argv[1:]) + script.start() + diff --git a/python/londiste/installer.py b/python/londiste/installer.py new file mode 100644 index 00000000..6f190ab2 --- /dev/null +++ b/python/londiste/installer.py @@ -0,0 +1,26 @@ + +"""Functions to install londiste and its depentencies into database.""" + +import os, skytools + +__all__ = ['install_provider', 'install_subscriber'] + +provider_object_list = [ + skytools.DBFunction('logtriga', 0, sql_file = "logtriga.sql"), + skytools.DBFunction('get_current_snapshot', 0, sql_file = "txid.sql"), + skytools.DBSchema('pgq', sql_file = "pgq.sql"), + skytools.DBSchema('londiste', sql_file = "londiste.sql") +] + +subscriber_object_list = [ + skytools.DBSchema('londiste', sql_file = "londiste.sql") +] + +def install_provider(curs, log): + """Installs needed code into provider db.""" + skytools.db_install(curs, provider_object_list, log) + +def install_subscriber(curs, log): + """Installs needed code into subscriber db.""" + skytools.db_install(curs, subscriber_object_list, log) + diff --git a/python/londiste/playback.py b/python/londiste/playback.py new file mode 100644 index 00000000..2bcb1bc7 --- /dev/null +++ b/python/londiste/playback.py @@ -0,0 +1,558 @@ +#! 
/usr/bin/env python + +"""Basic replication core.""" + +import sys, os, time +import skytools, pgq + +__all__ = ['Replicator', 'TableState', + 'TABLE_MISSING', 'TABLE_IN_COPY', 'TABLE_CATCHING_UP', + 'TABLE_WANNA_SYNC', 'TABLE_DO_SYNC', 'TABLE_OK'] + +# state # owner - who is allowed to change +TABLE_MISSING = 0 # main +TABLE_IN_COPY = 1 # copy +TABLE_CATCHING_UP = 2 # copy +TABLE_WANNA_SYNC = 3 # main +TABLE_DO_SYNC = 4 # copy +TABLE_OK = 5 # setup + +SYNC_OK = 0 # continue with batch +SYNC_LOOP = 1 # sleep, try again +SYNC_EXIT = 2 # nothing to do, exit skript + +class Counter(object): + """Counts table statuses.""" + + missing = 0 + copy = 0 + catching_up = 0 + wanna_sync = 0 + do_sync = 0 + ok = 0 + + def __init__(self, tables): + """Counts and sanity checks.""" + for t in tables: + if t.state == TABLE_MISSING: + self.missing += 1 + elif t.state == TABLE_IN_COPY: + self.copy += 1 + elif t.state == TABLE_CATCHING_UP: + self.catching_up += 1 + elif t.state == TABLE_WANNA_SYNC: + self.wanna_sync += 1 + elif t.state == TABLE_DO_SYNC: + self.do_sync += 1 + elif t.state == TABLE_OK: + self.ok += 1 + # only one table is allowed to have in-progress copy + if self.copy + self.catching_up + self.wanna_sync + self.do_sync > 1: + raise Exception('Bad table state') + +class TableState(object): + """Keeps state about one table.""" + def __init__(self, name, log): + self.name = name + self.log = log + self.forget() + self.changed = 0 + + def forget(self): + self.state = TABLE_MISSING + self.str_snapshot = None + self.from_snapshot = None + self.sync_tick_id = None + self.ok_batch_count = 0 + self.last_tick = 0 + self.changed = 1 + + def change_snapshot(self, str_snapshot, tag_changed = 1): + if self.str_snapshot == str_snapshot: + return + self.log.debug("%s: change_snapshot to %s" % (self.name, str_snapshot)) + self.str_snapshot = str_snapshot + if str_snapshot: + self.from_snapshot = skytools.Snapshot(str_snapshot) + else: + self.from_snapshot = None + + if tag_changed: + self.ok_batch_count = 0 + self.last_tick = None + self.changed = 1 + + def change_state(self, state, tick_id = None): + if self.state == state and self.sync_tick_id == tick_id: + return + self.state = state + self.sync_tick_id = tick_id + self.changed = 1 + self.log.debug("%s: change_state to %s" % (self.name, + self.render_state())) + + def render_state(self): + """Make a string to be stored in db.""" + + if self.state == TABLE_MISSING: + return None + elif self.state == TABLE_IN_COPY: + return 'in-copy' + elif self.state == TABLE_CATCHING_UP: + return 'catching-up' + elif self.state == TABLE_WANNA_SYNC: + return 'wanna-sync:%d' % self.sync_tick_id + elif self.state == TABLE_DO_SYNC: + return 'do-sync:%d' % self.sync_tick_id + elif self.state == TABLE_OK: + return 'ok' + + def parse_state(self, merge_state): + """Read state from string.""" + + state = -1 + if merge_state == None: + state = TABLE_MISSING + elif merge_state == "in-copy": + state = TABLE_IN_COPY + elif merge_state == "catching-up": + state = TABLE_CATCHING_UP + elif merge_state == "ok": + state = TABLE_OK + elif merge_state == "?": + state = TABLE_OK + else: + tmp = merge_state.split(':') + if len(tmp) == 2: + self.sync_tick_id = int(tmp[1]) + if tmp[0] == 'wanna-sync': + state = TABLE_WANNA_SYNC + elif tmp[0] == 'do-sync': + state = TABLE_DO_SYNC + + if state < 0: + raise Exception("Bad table state: %s" % merge_state) + + return state + + def loaded_state(self, merge_state, str_snapshot): + self.log.debug("loaded_state: %s: %s / %s" % ( + self.name, merge_state, 
str_snapshot)) + self.change_snapshot(str_snapshot, 0) + self.state = self.parse_state(merge_state) + self.changed = 0 + if merge_state == "?": + self.changed = 1 + + def interesting(self, ev, tick_id, copy_thread): + """Check if table wants this event.""" + + if copy_thread: + if self.state not in (TABLE_CATCHING_UP, TABLE_DO_SYNC): + return False + else: + if self.state != TABLE_OK: + return False + + # if no snapshot tracking, then accept always + if not self.from_snapshot: + return True + + # uninteresting? + if self.from_snapshot.contains(ev.txid): + return False + + # after couple interesting batches there no need to check snapshot + # as there can be only one partially interesting batch + if tick_id != self.last_tick: + self.last_tick = tick_id + self.ok_batch_count += 1 + + # disable batch tracking + if self.ok_batch_count > 3: + self.change_snapshot(None) + return True + +class SeqCache(object): + def __init__(self): + self.seq_list = [] + self.val_cache = {} + + def set_seq_list(self, seq_list): + self.seq_list = seq_list + new_cache = {} + for seq in seq_list: + val = self.val_cache.get(seq) + if val: + new_cache[seq] = val + self.val_cache = new_cache + + def resync(self, src_curs, dst_curs): + if len(self.seq_list) == 0: + return + dat = ".last_value, ".join(self.seq_list) + dat += ".last_value" + q = "select %s from %s" % (dat, ",".join(self.seq_list)) + src_curs.execute(q) + row = src_curs.fetchone() + for i in range(len(self.seq_list)): + seq = self.seq_list[i] + cur = row[i] + old = self.val_cache.get(seq) + if old != cur: + q = "select setval(%s, %s)" + dst_curs.execute(q, [seq, cur]) + self.val_cache[seq] = cur + +class Replicator(pgq.SerialConsumer): + """Replication core.""" + + sql_command = { + 'I': "insert into %s %s;", + 'U': "update only %s set %s;", + 'D': "delete from only %s where %s;", + } + + # batch info + cur_tick = 0 + prev_tick = 0 + + def __init__(self, args): + pgq.SerialConsumer.__init__(self, 'londiste', 'provider_db', 'subscriber_db', args) + + # tick table in dst for SerialConsumer(). keep londiste stuff under one schema + self.dst_completed_table = "londiste.completed" + + self.table_list = [] + self.table_map = {} + + self.copy_thread = 0 + self.maint_time = 0 + self.seq_cache = SeqCache() + self.maint_delay = self.cf.getint('maint_delay', 600) + self.mirror_queue = self.cf.get('mirror_queue', '') + + def process_remote_batch(self, src_db, batch_id, ev_list, dst_db): + "All work for a batch. Entry point from SerialConsumer." + + # this part can play freely with transactions + + dst_curs = dst_db.cursor() + + self.cur_tick = self.cur_batch_info['tick_id'] + self.prev_tick = self.cur_batch_info['prev_tick_id'] + + self.load_table_state(dst_curs) + self.sync_tables(dst_db) + + # now the actual event processing happens. + # they must be done all in one tx in dst side + # and the transaction must be kept open so that + # the SerialConsumer can save last tick and commit. + + self.handle_seqs(dst_curs) + self.handle_events(dst_curs, ev_list) + self.save_table_state(dst_curs) + + def handle_seqs(self, dst_curs): + if self.copy_thread: + return + + q = "select * from londiste.subscriber_get_seq_list(%s)" + dst_curs.execute(q, [self.pgq_queue_name]) + seq_list = [] + for row in dst_curs.fetchall(): + seq_list.append(row[0]) + + self.seq_cache.set_seq_list(seq_list) + + src_curs = self.get_database('provider_db').cursor() + self.seq_cache.resync(src_curs, dst_curs) + + def sync_tables(self, dst_db): + """Table sync loop. 
+ + Calls appropriate handles, which is expected to + return one of SYNC_* constants.""" + + self.log.debug('Sync tables') + while 1: + cnt = Counter(self.table_list) + if self.copy_thread: + res = self.sync_from_copy_thread(cnt, dst_db) + else: + res = self.sync_from_main_thread(cnt, dst_db) + + if res == SYNC_EXIT: + self.log.debug('Sync tables: exit') + self.detach() + sys.exit(0) + elif res == SYNC_OK: + return + elif res != SYNC_LOOP: + raise Exception('Program error') + + self.log.debug('Sync tables: sleeping') + time.sleep(3) + dst_db.commit() + self.load_table_state(dst_db.cursor()) + dst_db.commit() + + def sync_from_main_thread(self, cnt, dst_db): + "Main thread sync logic." + + # + # decide what to do - order is imortant + # + if cnt.do_sync: + # wait for copy thread to catch up + return SYNC_LOOP + elif cnt.wanna_sync: + # copy thread wants sync, if not behind, do it + t = self.get_table_by_state(TABLE_WANNA_SYNC) + if self.cur_tick >= t.sync_tick_id: + self.change_table_state(dst_db, t, TABLE_DO_SYNC, self.cur_tick) + return SYNC_LOOP + else: + return SYNC_OK + elif cnt.catching_up: + # active copy, dont worry + return SYNC_OK + elif cnt.copy: + # active copy, dont worry + return SYNC_OK + elif cnt.missing: + # seems there is no active copy thread, launch new + t = self.get_table_by_state(TABLE_MISSING) + self.change_table_state(dst_db, t, TABLE_IN_COPY) + + # the copy _may_ happen immidiately + self.launch_copy(t) + + # there cannot be interesting events in current batch + # but maybe there's several tables, lets do them in one go + return SYNC_LOOP + else: + # seems everything is in sync + return SYNC_OK + + def sync_from_copy_thread(self, cnt, dst_db): + "Copy thread sync logic." + + # + # decide what to do - order is imortant + # + if cnt.do_sync: + # main thread is waiting, catch up, then handle over + t = self.get_table_by_state(TABLE_DO_SYNC) + if self.cur_tick == t.sync_tick_id: + self.change_table_state(dst_db, t, TABLE_OK) + return SYNC_EXIT + elif self.cur_tick < t.sync_tick_id: + return SYNC_OK + else: + self.log.error("copy_sync: cur_tick=%d sync_tick=%d" % ( + self.cur_tick, t.sync_tick_id)) + raise Exception('Invalid table state') + elif cnt.wanna_sync: + # wait for main thread to react + return SYNC_LOOP + elif cnt.catching_up: + # is there more work? + if self.work_state: + return SYNC_OK + + # seems we have catched up + t = self.get_table_by_state(TABLE_CATCHING_UP) + self.change_table_state(dst_db, t, TABLE_WANNA_SYNC, self.cur_tick) + return SYNC_LOOP + elif cnt.copy: + # table is not copied yet, do it + t = self.get_table_by_state(TABLE_IN_COPY) + self.do_copy(t) + + # forget previous value + self.work_state = 1 + + return SYNC_LOOP + else: + # nothing to do + return SYNC_EXIT + + def handle_events(self, dst_curs, ev_list): + "Actual event processing happens here." 
+ + ignored_events = 0 + self.sql_list = [] + mirror_list = [] + for ev in ev_list: + if not self.interesting(ev): + ignored_events += 1 + ev.tag_done() + continue + + if ev.type in ('I', 'U', 'D'): + self.handle_data_event(ev, dst_curs) + else: + self.handle_system_event(ev, dst_curs) + + if self.mirror_queue: + mirror_list.append(ev) + + # finalize table changes + self.flush_sql(dst_curs) + self.stat_add('ignored', ignored_events) + + # put events into mirror queue if requested + if self.mirror_queue: + self.fill_mirror_queue(mirror_list, dst_curs) + + def handle_data_event(self, ev, dst_curs): + fmt = self.sql_command[ev.type] + sql = fmt % (ev.extra1, ev.data) + self.sql_list.append(sql) + if len(self.sql_list) > 200: + self.flush_sql(dst_curs) + ev.tag_done() + + def flush_sql(self, dst_curs): + if len(self.sql_list) == 0: + return + + buf = "\n".join(self.sql_list) + self.sql_list = [] + + dst_curs.execute(buf) + + def interesting(self, ev): + if ev.type not in ('I', 'U', 'D'): + return 1 + t = self.get_table_by_name(ev.extra1) + if t: + return t.interesting(ev, self.cur_tick, self.copy_thread) + else: + return 0 + + def handle_system_event(self, ev, dst_curs): + "System event." + + if ev.type == "T": + self.log.info("got new table event: "+ev.data) + # check tables to be dropped + name_list = [] + for name in ev.data.split(','): + name_list.append(name.strip()) + + del_list = [] + for tbl in self.table_list: + if tbl.name in name_list: + continue + del_list.append(tbl) + + # separate loop to avoid changing while iterating + for tbl in del_list: + self.log.info("Removing table %s from set" % tbl.name) + self.remove_table(tbl, dst_curs) + + ev.tag_done() + else: + self.log.warning("Unknows op %s" % ev.type) + ev.tag_failed("Unknown operation") + + def remove_table(self, tbl, dst_curs): + del self.table_map[tbl.name] + self.table_list.remove(tbl) + q = "select londiste.subscriber_remove_table(%s, %s)" + dst_curs.execute(q, [self.pgq_queue_name, tbl.name]) + + def load_table_state(self, curs): + """Load table state from database. + + Todo: if all tables are OK, there is no need + to load state on every batch. 
+ """ + + q = """select table_name, snapshot, merge_state + from londiste.subscriber_get_table_list(%s) + """ + curs.execute(q, [self.pgq_queue_name]) + + new_list = [] + new_map = {} + for row in curs.dictfetchall(): + t = self.get_table_by_name(row['table_name']) + if not t: + t = TableState(row['table_name'], self.log) + t.loaded_state(row['merge_state'], row['snapshot']) + new_list.append(t) + new_map[t.name] = t + + self.table_list = new_list + self.table_map = new_map + + def save_table_state(self, curs): + """Store changed table state in database.""" + + for t in self.table_list: + if not t.changed: + continue + merge_state = t.render_state() + self.log.info("storing state of %s: copy:%d new_state:%s" % ( + t.name, self.copy_thread, merge_state)) + q = "select londiste.subscriber_set_table_state(%s, %s, %s, %s)" + curs.execute(q, [self.pgq_queue_name, + t.name, t.str_snapshot, merge_state]) + t.changed = 0 + + def change_table_state(self, dst_db, tbl, state, tick_id = None): + tbl.change_state(state, tick_id) + self.save_table_state(dst_db.cursor()) + dst_db.commit() + + self.log.info("Table %s status changed to '%s'" % ( + tbl.name, tbl.render_state())) + + def get_table_by_state(self, state): + "get first table with specific state" + + for t in self.table_list: + if t.state == state: + return t + raise Exception('No table was found with state: %d' % state) + + def get_table_by_name(self, name): + if name.find('.') < 0: + name = "public.%s" % name + if name in self.table_map: + return self.table_map[name] + return None + + def fill_mirror_queue(self, ev_list, dst_curs): + # insert events + rows = [] + fields = ['ev_type', 'ev_data', 'ev_extra1'] + for ev in mirror_list: + rows.append((ev.type, ev.data, ev.extra1)) + pgq.bulk_insert_events(dst_curs, rows, fields, self.mirror_queue) + + # create tick + q = "select pgq.ticker(%s, %s)" + dst_curs.execute(q, [self.mirror_queue, self.cur_tick]) + + def launch_copy(self, tbl_stat): + self.log.info("Launching copy process") + script = sys.argv[0] + conf = self.cf.filename + if self.options.verbose: + cmd = "%s -d -v %s copy" + else: + cmd = "%s -d %s copy" + cmd = cmd % (script, conf) + self.log.debug("Launch args: "+repr(cmd)) + res = os.system(cmd) + self.log.debug("Launch result: "+repr(res)) + +if __name__ == '__main__': + script = Replicator(sys.argv[1:]) + script.start() + diff --git a/python/londiste/repair.py b/python/londiste/repair.py new file mode 100644 index 00000000..ec4bd404 --- /dev/null +++ b/python/londiste/repair.py @@ -0,0 +1,284 @@ + +"""Repair data on subscriber. + +Walks tables by primary key and searcher +missing inserts/updates/deletes. 
+""" + +import sys, os, time, psycopg, skytools + +from syncer import Syncer + +__all__ = ['Repairer'] + +def unescape(s): + return skytools.unescape_copy(s) + +def get_pkey_list(curs, tbl): + """Get list of pkey fields in right order.""" + + oid = skytools.get_table_oid(curs, tbl) + q = """SELECT k.attname FROM pg_index i, pg_attribute k + WHERE i.indrelid = %s AND k.attrelid = i.indexrelid + AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped + ORDER BY k.attnum""" + curs.execute(q, [oid]) + list = [] + for row in curs.fetchall(): + list.append(row[0]) + return list + +def get_column_list(curs, tbl): + """Get list of columns in right order.""" + + oid = skytools.get_table_oid(curs, tbl) + q = """SELECT a.attname FROM pg_attribute a + WHERE a.attrelid = %s + AND a.attnum > 0 AND NOT a.attisdropped + ORDER BY a.attnum""" + curs.execute(q, [oid]) + list = [] + for row in curs.fetchall(): + list.append(row[0]) + return list + +class Repairer(Syncer): + """Walks tables in primary key order and checks if data matches.""" + + + def process_sync(self, tbl, src_db, dst_db): + """Actual comparision.""" + + src_curs = src_db.cursor() + dst_curs = dst_db.cursor() + + self.log.info('Checking %s' % tbl) + + self.common_fields = [] + self.pkey_list = [] + copy_tbl = self.gen_copy_tbl(tbl, src_curs, dst_curs) + + dump_src = tbl + ".src" + dump_dst = tbl + ".dst" + + self.log.info("Dumping src table: %s" % tbl) + self.dump_table(tbl, copy_tbl, src_curs, dump_src) + src_db.commit() + self.log.info("Dumping dst table: %s" % tbl) + self.dump_table(tbl, copy_tbl, dst_curs, dump_dst) + dst_db.commit() + + self.log.info("Sorting src table: %s" % tbl) + + s_in, s_out = os.popen4("sort --version") + s_ver = s_out.read() + del s_in, s_out + if s_ver.find("coreutils") > 0: + args = "-S 30%" + else: + args = "" + os.system("sort %s -T . -o %s.sorted %s" % (args, dump_src, dump_src)) + self.log.info("Sorting dst table: %s" % tbl) + os.system("sort %s -T . 
-o %s.sorted %s" % (args, dump_dst, dump_dst)) + + self.dump_compare(tbl, dump_src + ".sorted", dump_dst + ".sorted") + + os.unlink(dump_src) + os.unlink(dump_dst) + os.unlink(dump_src + ".sorted") + os.unlink(dump_dst + ".sorted") + + def gen_copy_tbl(self, tbl, src_curs, dst_curs): + self.pkey_list = get_pkey_list(src_curs, tbl) + dst_pkey = get_pkey_list(dst_curs, tbl) + if dst_pkey != self.pkey_list: + self.log.error('pkeys do not match') + sys.exit(1) + + src_cols = get_column_list(src_curs, tbl) + dst_cols = get_column_list(dst_curs, tbl) + field_list = [] + for f in self.pkey_list: + field_list.append(f) + for f in src_cols: + if f in self.pkey_list: + continue + if f in dst_cols: + field_list.append(f) + + self.common_fields = field_list + + tbl_expr = "%s (%s)" % (tbl, ",".join(field_list)) + + self.log.debug("using copy expr: %s" % tbl_expr) + + return tbl_expr + + def dump_table(self, tbl, copy_tbl, curs, fn): + f = open(fn, "w", 64*1024) + curs.copy_to(f, copy_tbl) + size = f.tell() + f.close() + self.log.info('Got %d bytes' % size) + + def get_row(self, ln): + t = ln[:-1].split('\t') + row = {} + for i in range(len(self.common_fields)): + row[self.common_fields[i]] = t[i] + return row + + def dump_compare(self, tbl, src_fn, dst_fn): + self.log.info("Comparing dumps: %s" % tbl) + self.cnt_insert = 0 + self.cnt_update = 0 + self.cnt_delete = 0 + self.total_src = 0 + self.total_dst = 0 + f1 = open(src_fn, "r", 64*1024) + f2 = open(dst_fn, "r", 64*1024) + src_ln = f1.readline() + dst_ln = f2.readline() + if src_ln: self.total_src += 1 + if dst_ln: self.total_dst += 1 + + fix = "fix.%s.sql" % tbl + if os.path.isfile(fix): + os.unlink(fix) + + while src_ln or dst_ln: + keep_src = keep_dst = 0 + if src_ln != dst_ln: + src_row = self.get_row(src_ln) + dst_row = self.get_row(dst_ln) + + cmp = self.cmp_keys(src_row, dst_row) + if cmp > 0: + # src > dst + self.got_missed_delete(tbl, dst_row) + keep_src = 1 + elif cmp < 0: + # src < dst + self.got_missed_insert(tbl, src_row) + keep_dst = 1 + else: + if self.cmp_data(src_row, dst_row) != 0: + self.got_missed_update(tbl, src_row, dst_row) + + if not keep_src: + src_ln = f1.readline() + if src_ln: self.total_src += 1 + if not keep_dst: + dst_ln = f2.readline() + if dst_ln: self.total_dst += 1 + + self.log.info("finished %s: src: %d rows, dst: %d rows,"\ + " missed: %d inserts, %d updates, %d deletes" % ( + tbl, self.total_src, self.total_dst, + self.cnt_insert, self.cnt_update, self.cnt_delete)) + + def got_missed_insert(self, tbl, src_row): + self.cnt_insert += 1 + fld_list = self.common_fields + val_list = [] + for f in fld_list: + v = unescape(src_row[f]) + val_list.append(skytools.quote_literal(v)) + q = "insert into %s (%s) values (%s);" % ( + tbl, ", ".join(fld_list), ", ".join(val_list)) + self.show_fix(tbl, q, 'insert') + + def got_missed_update(self, tbl, src_row, dst_row): + self.cnt_update += 1 + fld_list = self.common_fields + set_list = [] + whe_list = [] + for f in self.pkey_list: + self.addcmp(whe_list, f, unescape(src_row[f])) + for f in fld_list: + v1 = src_row[f] + v2 = dst_row[f] + if self.cmp_value(v1, v2) == 0: + continue + + self.addeq(set_list, f, unescape(v1)) + self.addcmp(whe_list, f, unescape(v2)) + + q = "update only %s set %s where %s;" % ( + tbl, ", ".join(set_list), " and ".join(whe_list)) + self.show_fix(tbl, q, 'update') + + def got_missed_delete(self, tbl, dst_row, pkey_list): + self.cnt_delete += 1 + whe_list = [] + for f in self.pkey_list: + self.addcmp(whe_list, f, unescape(dst_row[f])) + q = "delete from 
only %s where %s;" % (tbl, " and ".join(whe_list)) + self.show_fix(tbl, q, 'delete') + + def show_fix(self, tbl, q, desc): + #self.log.warning("missed %s: %s" % (desc, q)) + fn = "fix.%s.sql" % tbl + open(fn, "a").write("%s\n" % q) + + def addeq(self, list, f, v): + vq = skytools.quote_literal(v) + s = "%s = %s" % (f, vq) + list.append(s) + + def addcmp(self, list, f, v): + if v is None: + s = "%s is null" % f + else: + vq = skytools.quote_literal(v) + s = "%s = %s" % (f, vq) + list.append(s) + + def cmp_data(self, src_row, dst_row): + for k in self.common_fields: + v1 = src_row[k] + v2 = dst_row[k] + if self.cmp_value(v1, v2) != 0: + return -1 + return 0 + + def cmp_value(self, v1, v2): + if v1 == v2: + return 0 + + # try to work around tz vs. notz + z1 = len(v1) + z2 = len(v2) + if z1 == z2 + 3 and z2 >= 19 and v1[z2] == '+': + v1 = v1[:-3] + if v1 == v2: + return 0 + elif z1 + 3 == z2 and z1 >= 19 and v2[z1] == '+': + v2 = v2[:-3] + if v1 == v2: + return 0 + + return -1 + + def cmp_keys(self, src_row, dst_row): + """Compare primary keys of the rows. + + Returns 1 if src > dst, -1 if src < dst and 0 if src == dst""" + + # None means table is done. tag it larger than any existing row. + if src_row is None: + if dst_row is None: + return 0 + return 1 + elif dst_row is None: + return -1 + + for k in self.pkey_list: + v1 = src_row[k] + v2 = dst_row[k] + if v1 < v2: + return -1 + elif v1 > v2: + return 1 + return 0 + diff --git a/python/londiste/setup.py b/python/londiste/setup.py new file mode 100644 index 00000000..ed44b093 --- /dev/null +++ b/python/londiste/setup.py @@ -0,0 +1,580 @@ +#! /usr/bin/env python + +"""Londiste setup and sanity checker. + +""" +import sys, os, skytools +from installer import * + +__all__ = ['ProviderSetup', 'SubscriberSetup'] + +def find_column_types(curs, table): + table_oid = skytools.get_table_oid(curs, table) + if table_oid == None: + return None + + key_sql = """ + SELECT k.attname FROM pg_index i, pg_attribute k + WHERE i.indrelid = %d AND k.attrelid = i.indexrelid + AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped + """ % table_oid + + # find columns + q = """ + SELECT a.attname as name, + CASE WHEN k.attname IS NOT NULL + THEN 'k' ELSE 'v' END AS type + FROM pg_attribute a LEFT JOIN (%s) k ON (k.attname = a.attname) + WHERE a.attrelid = %d AND a.attnum > 0 AND NOT a.attisdropped + ORDER BY a.attnum + """ % (key_sql, table_oid) + curs.execute(q) + rows = curs.dictfetchall() + return rows + +def make_type_string(col_rows): + res = map(lambda x: x['type'], col_rows) + return "".join(res) + +class CommonSetup(skytools.DBScript): + def __init__(self, args): + skytools.DBScript.__init__(self, 'londiste', args) + self.set_single_loop(1) + self.pidfile = self.pidfile + ".setup" + + self.pgq_queue_name = self.cf.get("pgq_queue_name") + self.consumer_id = self.cf.get("pgq_consumer_id", self.job_name) + self.fake = self.cf.getint('fake', 0) + + if len(self.args) < 3: + self.log.error("need subcommand") + sys.exit(1) + + def run(self): + self.admin() + + def fetch_provider_table_list(self, curs): + q = """select table_name, trigger_name + from londiste.provider_get_table_list(%s)""" + curs.execute(q, [self.pgq_queue_name]) + return curs.dictfetchall() + + def get_provider_table_list(self): + src_db = self.get_database('provider_db') + src_curs = src_db.cursor() + list = self.fetch_provider_table_list(src_curs) + src_db.commit() + res = [] + for row in list: + res.append(row['table_name']) + return res + + def get_provider_seqs(self, curs): + q = """SELECT 
* from londiste.provider_get_seq_list(%s)""" + curs.execute(q, [self.pgq_queue_name]) + res = [] + for row in curs.fetchall(): + res.append(row[0]) + return res + + def get_all_seqs(self, curs): + q = """SELECT n.nspname || '.'|| c.relname + from pg_class c, pg_namespace n + where n.oid = c.relnamespace + and c.relkind = 'S' + order by 1""" + curs.execute(q) + res = [] + for row in curs.fetchall(): + res.append(row[0]) + return res + + def check_provider_queue(self): + src_db = self.get_database('provider_db') + src_curs = src_db.cursor() + q = "select count(1) from pgq.get_queue_info(%s)" + src_curs.execute(q, [self.pgq_queue_name]) + ok = src_curs.fetchone()[0] + src_db.commit() + + if not ok: + self.log.error('Event queue does not exist yet') + sys.exit(1) + + def fetch_subscriber_tables(self, curs): + q = "select * from londiste.subscriber_get_table_list(%s)" + curs.execute(q, [self.pgq_queue_name]) + return curs.dictfetchall() + + def get_subscriber_table_list(self): + dst_db = self.get_database('subscriber_db') + dst_curs = dst_db.cursor() + list = self.fetch_subscriber_tables(dst_curs) + dst_db.commit() + res = [] + for row in list: + res.append(row['table_name']) + return res + + def init_optparse(self, parser=None): + p = skytools.DBScript.init_optparse(self, parser) + p.add_option("--expect-sync", action="store_true", dest="expect_sync", + help = "no copy needed", default=False) + p.add_option("--force", action="store_true", + help="force", default=False) + return p + + +# +# Provider commands +# + +class ProviderSetup(CommonSetup): + + def admin(self): + cmd = self.args[2] + if cmd == "tables": + self.provider_show_tables() + elif cmd == "add": + self.provider_add_tables(self.args[3:]) + elif cmd == "remove": + self.provider_remove_tables(self.args[3:]) + elif cmd == "add-seq": + for seq in self.args[3:]: + self.provider_add_seq(seq) + self.provider_notify_change() + elif cmd == "remove-seq": + for seq in self.args[3:]: + self.provider_remove_seq(seq) + self.provider_notify_change() + elif cmd == "install": + self.provider_install() + elif cmd == "seqs": + self.provider_list_seqs() + else: + self.log.error('bad subcommand') + sys.exit(1) + + def provider_list_seqs(self): + src_db = self.get_database('provider_db') + src_curs = src_db.cursor() + list = self.get_provider_seqs(src_curs) + src_db.commit() + + for seq in list: + print seq + + def provider_install(self): + src_db = self.get_database('provider_db') + src_curs = src_db.cursor() + install_provider(src_curs, self.log) + + # create event queue + q = "select pgq.create_queue(%s)" + self.exec_provider(q, [self.pgq_queue_name]) + + def provider_add_tables(self, table_list): + self.check_provider_queue() + + cur_list = self.get_provider_table_list() + for tbl in table_list: + if tbl.find('.') < 0: + tbl = "public." + tbl + if tbl not in cur_list: + self.log.info('Adding %s' % tbl) + self.provider_add_table(tbl) + else: + self.log.info("Table %s already added" % tbl) + self.provider_notify_change() + + def provider_remove_tables(self, table_list): + self.check_provider_queue() + + cur_list = self.get_provider_table_list() + for tbl in table_list: + if tbl.find('.') < 0: + tbl = "public." 
+ tbl + if tbl not in cur_list: + self.log.info('%s already removed' % tbl) + else: + self.log.info("Removing %s" % tbl) + self.provider_remove_table(tbl) + self.provider_notify_change() + + def provider_add_table(self, tbl): + q = "select londiste.provider_add_table(%s, %s)" + self.exec_provider(q, [self.pgq_queue_name, tbl]) + + def provider_remove_table(self, tbl): + q = "select londiste.provider_remove_table(%s, %s)" + self.exec_provider(q, [self.pgq_queue_name, tbl]) + + def provider_show_tables(self): + self.check_provider_queue() + list = self.get_provider_table_list() + for tbl in list: + print tbl + + def provider_notify_change(self): + q = "select londiste.provider_notify_change(%s)" + self.exec_provider(q, [self.pgq_queue_name]) + + def provider_add_seq(self, seq): + seq = skytools.fq_name(seq) + q = "select londiste.provider_add_seq(%s, %s)" + self.exec_provider(q, [self.pgq_queue_name, seq]) + + def provider_remove_seq(self, seq): + seq = skytools.fq_name(seq) + q = "select londiste.provider_remove_seq(%s, %s)" + self.exec_provider(q, [self.pgq_queue_name, seq]) + + def exec_provider(self, sql, args): + src_db = self.get_database('provider_db') + src_curs = src_db.cursor() + + src_curs.execute(sql, args) + + if self.fake: + src_db.rollback() + else: + src_db.commit() + +# +# Subscriber commands +# + +class SubscriberSetup(CommonSetup): + + def admin(self): + cmd = self.args[2] + if cmd == "tables": + self.subscriber_show_tables() + elif cmd == "missing": + self.subscriber_missing_tables() + elif cmd == "add": + self.subscriber_add_tables(self.args[3:]) + elif cmd == "remove": + self.subscriber_remove_tables(self.args[3:]) + elif cmd == "resync": + self.subscriber_resync_tables(self.args[3:]) + elif cmd == "register": + self.subscriber_register() + elif cmd == "unregister": + self.subscriber_unregister() + elif cmd == "install": + self.subscriber_install() + elif cmd == "check": + self.check_tables(self.get_provider_table_list()) + elif cmd == "fkeys": + self.collect_fkeys(self.get_provider_table_list()) + elif cmd == "seqs": + self.subscriber_list_seqs() + elif cmd == "add-seq": + self.subscriber_add_seq(self.args[3:]) + elif cmd == "remove-seq": + self.subscriber_remove_seq(self.args[3:]) + else: + self.log.error('bad subcommand: ' + cmd) + sys.exit(1) + + def collect_fkeys(self, table_list): + dst_db = self.get_database('subscriber_db') + dst_curs = dst_db.cursor() + + oid_list = [] + for tbl in table_list: + try: + oid = skytools.get_table_oid(dst_curs, tbl) + if oid: + oid_list.append(str(oid)) + except: + pass + if len(oid_list) == 0: + print "No tables" + return + oid_str = ",".join(oid_list) + + q = "SELECT n.nspname || '.' 
|| t.relname as tbl, c.conname as con,"\ + " pg_get_constraintdef(c.oid) as def"\ + " FROM pg_constraint c, pg_class t, pg_namespace n"\ + " WHERE c.contype = 'f' and c.conrelid in (%s)"\ + " AND t.oid = c.conrelid AND n.oid = t.relnamespace" % oid_str + dst_curs.execute(q) + res = dst_curs.dictfetchall() + dst_db.commit() + + print "-- dropping" + for row in res: + q = "ALTER TABLE ONLY %(tbl)s DROP CONSTRAINT %(con)s;" + print q % row + print "-- creating" + for row in res: + q = "ALTER TABLE ONLY %(tbl)s ADD CONSTRAINT %(con)s %(def)s;" + print q % row + + def check_tables(self, table_list): + src_db = self.get_database('provider_db') + src_curs = src_db.cursor() + dst_db = self.get_database('subscriber_db') + dst_curs = dst_db.cursor() + + failed = 0 + for tbl in table_list: + self.log.info('Checking %s' % tbl) + if not skytools.exists_table(src_curs, tbl): + self.log.error('Table %s missing from provider side' % tbl) + failed += 1 + elif not skytools.exists_table(dst_curs, tbl): + self.log.error('Table %s missing from subscriber side' % tbl) + failed += 1 + else: + failed += self.check_table_columns(src_curs, dst_curs, tbl) + failed += self.check_table_triggers(dst_curs, tbl) + + src_db.commit() + dst_db.commit() + + return failed + + def check_table_triggers(self, dst_curs, tbl): + oid = skytools.get_table_oid(dst_curs, tbl) + if not oid: + self.log.error('Table %s not found' % tbl) + return 1 + q = "select count(1) from pg_trigger where tgrelid = %s" + dst_curs.execute(q, [oid]) + got = dst_curs.fetchone()[0] + if got: + self.log.error('found trigger on table %s (%s)' % (tbl, str(oid))) + return 1 + else: + return 0 + + def check_table_columns(self, src_curs, dst_curs, tbl): + src_colrows = find_column_types(src_curs, tbl) + dst_colrows = find_column_types(dst_curs, tbl) + + src_cols = make_type_string(src_colrows) + dst_cols = make_type_string(dst_colrows) + if src_cols.find('k') < 0: + self.log.error('provider table %s has no primary key (%s)' % ( + tbl, src_cols)) + return 1 + if dst_cols.find('k') < 0: + self.log.error('subscriber table %s has no primary key (%s)' % ( + tbl, dst_cols)) + return 1 + + if src_cols != dst_cols: + self.log.warning('table %s structure is not same (%s/%s)'\ + ', trying to continue' % (tbl, src_cols, dst_cols)) + + err = 0 + for row in src_colrows: + found = 0 + for row2 in dst_colrows: + if row2['name'] == row['name']: + found = 1 + break + if not found: + err = 1 + self.log.error('%s: column %s on provider not on subscriber' + % (tbl, row['name'])) + elif row['type'] != row2['type']: + err = 1 + self.log.error('%s: pk different on column %s' + % (tbl, row['name'])) + + return err + + def subscriber_install(self): + dst_db = self.get_database('subscriber_db') + dst_curs = dst_db.cursor() + + install_subscriber(dst_curs, self.log) + + if self.fake: + self.log.debug('rollback') + dst_db.rollback() + else: + self.log.debug('commit') + dst_db.commit() + + def subscriber_register(self): + src_db = self.get_database('provider_db') + src_curs = src_db.cursor() + src_curs.execute("select pgq.register_consumer(%s, %s)", + [self.pgq_queue_name, self.consumer_id]) + src_db.commit() + + def subscriber_unregister(self): + q = "select londiste.subscriber_set_table_state(%s, %s, NULL, NULL)" + + dst_db = self.get_database('subscriber_db') + dst_curs = dst_db.cursor() + tbl_rows = self.fetch_subscriber_tables(dst_curs) + for row in tbl_rows: + dst_curs.execute(q, [self.pgq_queue_name, row['table_name']]) + dst_db.commit() + + src_db = self.get_database('provider_db') 
+ src_curs = src_db.cursor() + src_curs.execute("select pgq.unregister_consumer(%s, %s)", + [self.pgq_queue_name, self.consumer_id]) + src_db.commit() + + def subscriber_show_tables(self): + list = self.get_subscriber_table_list() + for tbl in list: + print tbl + + def subscriber_missing_tables(self): + provider_tables = self.get_provider_table_list() + subscriber_tables = self.get_subscriber_table_list() + for tbl in provider_tables: + if tbl not in subscriber_tables: + print tbl + + def subscriber_add_tables(self, table_list): + provider_tables = self.get_provider_table_list() + subscriber_tables = self.get_subscriber_table_list() + + err = 0 + for tbl in table_list: + tbl = skytools.fq_name(tbl) + if tbl not in provider_tables: + err = 1 + self.log.error("Table %s not attached to queue" % tbl) + if err: + if self.options.force: + self.log.warning('--force used, ignoring errors') + else: + sys.exit(1) + + err = self.check_tables(table_list) + if err: + if self.options.force: + self.log.warning('--force used, ignoring errors') + else: + sys.exit(1) + + for tbl in table_list: + tbl = skytools.fq_name(tbl) + if tbl in subscriber_tables: + self.log.info("Table %s already added" % tbl) + else: + self.log.info("Adding %s" % tbl) + self.subscriber_add_one_table(tbl) + + def subscriber_remove_tables(self, table_list): + subscriber_tables = self.get_subscriber_table_list() + for tbl in table_list: + tbl = skytools.fq_name(tbl) + if tbl in subscriber_tables: + self.subscriber_remove_one_table(tbl) + else: + self.log.info("Table %s already removed" % tbl) + + def subscriber_resync_tables(self, table_list): + dst_db = self.get_database('subscriber_db') + dst_curs = dst_db.cursor() + list = self.fetch_subscriber_tables(dst_curs) + for tbl in table_list: + tbl = skytools.fq_name(tbl) + tbl_row = None + for row in list: + if row['table_name'] == tbl: + tbl_row = row + break + if not tbl_row: + self.log.warning("Table %s not found" % tbl) + elif tbl_row['merge_state'] != 'ok': + self.log.warning("Table %s is not in stable state" % tbl) + else: + self.log.info("Resyncing %s" % tbl) + q = "select londiste.subscriber_set_table_state(%s, %s, NULL, NULL)" + dst_curs.execute(q, [self.pgq_queue_name, tbl]) + dst_db.commit() + + def subscriber_add_one_table(self, tbl): + q = "select londiste.subscriber_add_table(%s, %s)" + + dst_db = self.get_database('subscriber_db') + dst_curs = dst_db.cursor() + dst_curs.execute(q, [self.pgq_queue_name, tbl]) + if self.options.expect_sync: + q = "select londiste.subscriber_set_table_state(%s, %s, null, 'ok')" + dst_curs.execute(q, [self.pgq_queue_name, tbl]) + dst_db.commit() + + def subscriber_remove_one_table(self, tbl): + q = "select londiste.subscriber_remove_table(%s, %s)" + + dst_db = self.get_database('subscriber_db') + dst_curs = dst_db.cursor() + dst_curs.execute(q, [self.pgq_queue_name, tbl]) + dst_db.commit() + + def get_subscriber_seq_list(self): + dst_db = self.get_database('subscriber_db') + dst_curs = dst_db.cursor() + q = "SELECT * from londiste.subscriber_get_seq_list(%s)" + dst_curs.execute(q, [self.pgq_queue_name]) + list = dst_curs.fetchall() + dst_db.commit() + res = [] + for row in list: + res.append(row[0]) + return res + + def subscriber_list_seqs(self): + list = self.get_subscriber_seq_list() + for seq in list: + print seq + + def subscriber_add_seq(self, seq_list): + src_db = self.get_database('provider_db') + src_curs = src_db.cursor() + dst_db = self.get_database('subscriber_db') + dst_curs = dst_db.cursor() + + prov_list = 
self.get_provider_seqs(src_curs) + src_db.commit() + + full_list = self.get_all_seqs(dst_curs) + cur_list = self.get_subscriber_seq_list() + + for seq in seq_list: + seq = skytools.fq_name(seq) + if seq not in prov_list: + self.log.error('Seq %s does not exist on provider side' % seq) + continue + if seq not in full_list: + self.log.error('Seq %s does not exist on subscriber side' % seq) + continue + if seq in cur_list: + self.log.info('Seq %s already subscribed' % seq) + continue + + self.log.info('Adding sequence: %s' % seq) + q = "select londiste.subscriber_add_seq(%s, %s)" + dst_curs.execute(q, [self.pgq_queue_name, seq]) + + dst_db.commit() + + def subscriber_remove_seq(self, seq_list): + dst_db = self.get_database('subscriber_db') + dst_curs = dst_db.cursor() + cur_list = self.get_subscriber_seq_list() + + for seq in seq_list: + seq = skytools.fq_name(seq) + if seq not in cur_list: + self.log.warning('Seq %s not subscribed') + else: + self.log.info('Removing sequence: %s' % seq) + q = "select londiste.subscriber_remove_seq(%s, %s)" + dst_curs.execute(q, [self.pgq_queue_name, seq]) + dst_db.commit() + diff --git a/python/londiste/syncer.py b/python/londiste/syncer.py new file mode 100644 index 00000000..eaee3468 --- /dev/null +++ b/python/londiste/syncer.py @@ -0,0 +1,177 @@ + +"""Catch moment when tables are in sync on master and slave. +""" + +import sys, time, skytools + +class Syncer(skytools.DBScript): + """Walks tables in primary key order and checks if data matches.""" + + def __init__(self, args): + skytools.DBScript.__init__(self, 'londiste', args) + self.set_single_loop(1) + + self.pgq_queue_name = self.cf.get("pgq_queue_name") + self.pgq_consumer_id = self.cf.get('pgq_consumer_id', self.job_name) + + if self.pidfile: + self.pidfile += ".repair" + + def init_optparse(self, p=None): + p = skytools.DBScript.init_optparse(self, p) + p.add_option("--force", action="store_true", help="ignore lag") + return p + + def check_consumer(self, src_db): + src_curs = src_db.cursor() + + # before locking anything check if consumer is working ok + q = "select extract(epoch from ticker_lag) from pgq.get_queue_list()"\ + " where queue_name = %s" + src_curs.execute(q, [self.pgq_queue_name]) + ticker_lag = src_curs.fetchone()[0] + q = "select extract(epoch from lag)"\ + " from pgq.get_consumer_list()"\ + " where queue_name = %s"\ + " and consumer_name = %s" + src_curs.execute(q, [self.pgq_queue_name, self.pgq_consumer_id]) + res = src_curs.fetchall() + src_db.commit() + + if len(res) == 0: + self.log.error('No such consumer') + sys.exit(1) + consumer_lag = res[0][0] + + if consumer_lag > ticker_lag + 10 and not self.options.force: + self.log.error('Consumer lagging too much, cannot proceed') + sys.exit(1) + + def get_subscriber_table_state(self): + dst_db = self.get_database('subscriber_db') + dst_curs = dst_db.cursor() + q = "select * from londiste.subscriber_get_table_list(%s)" + dst_curs.execute(q, [self.pgq_queue_name]) + res = dst_curs.dictfetchall() + dst_db.commit() + return res + + def work(self): + src_loc = self.cf.get('provider_db') + lock_db = self.get_database('provider_db', cache='lock_db') + src_db = self.get_database('provider_db') + dst_db = self.get_database('subscriber_db') + + self.check_consumer(src_db) + + state_list = self.get_subscriber_table_state() + state_map = {} + full_list = [] + for ts in state_list: + name = ts['table_name'] + full_list.append(name) + state_map[name] = ts + + if len(self.args) > 2: + tlist = self.args[2:] + else: + tlist = full_list + + for tbl 
in tlist: + if not tbl in state_map: + self.log.warning('Table not subscribed: %s' % tbl) + continue + st = state_map[tbl] + if st['merge_state'] != 'ok': + self.log.info('Table %s not synced yet, no point' % tbl) + continue + self.check_table(tbl, lock_db, src_db, dst_db) + lock_db.commit() + src_db.commit() + dst_db.commit() + + def check_table(self, tbl, lock_db, src_db, dst_db): + """Get transaction to same state, then process.""" + + + lock_curs = lock_db.cursor() + src_curs = src_db.cursor() + dst_curs = dst_db.cursor() + + if not skytools.exists_table(src_curs, tbl): + self.log.warning("Table %s does not exist on provider side" % tbl) + return + if not skytools.exists_table(dst_curs, tbl): + self.log.warning("Table %s does not exist on subscriber side" % tbl) + return + + # lock table in separate connection + self.log.info('Locking %s' % tbl) + lock_db.commit() + lock_curs.execute("LOCK TABLE %s IN SHARE MODE" % tbl) + lock_time = time.time() + + # now wait until consumer has updated target table until locking + self.log.info('Syncing %s' % tbl) + + # consumer must get futher than this tick + src_curs.execute("select pgq.ticker(%s)", [self.pgq_queue_name]) + tick_id = src_curs.fetchone()[0] + src_db.commit() + # avoid depending on ticker by inserting second tick also + time.sleep(0.1) + src_curs.execute("select pgq.ticker(%s)", [self.pgq_queue_name]) + src_db.commit() + src_curs.execute("select to_char(now(), 'YYYY-MM-DD HH24:MI:SS.MS')") + tpos = src_curs.fetchone()[0] + src_db.commit() + # now wait + while 1: + time.sleep(0.2) + + q = """select now() - lag > %s, now(), lag + from pgq.get_consumer_list() + where consumer_name = %s + and queue_name = %s""" + src_curs.execute(q, [tpos, self.pgq_consumer_id, self.pgq_queue_name]) + res = src_curs.fetchall() + src_db.commit() + + if len(res) == 0: + raise Exception('No such consumer') + + row = res[0] + self.log.debug("tpos=%s now=%s lag=%s ok=%s" % (tpos, row[1], row[2], row[0])) + if row[0]: + break + + # loop max 10 secs + if time.time() > lock_time + 10 and not self.options.force: + self.log.error('Consumer lagging too much, exiting') + lock_db.rollback() + sys.exit(1) + + # take snapshot on provider side + src_curs.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE") + src_curs.execute("SELECT 1") + + # take snapshot on subscriber side + dst_db.commit() + dst_curs.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE") + dst_curs.execute("SELECT 1") + + # release lock + lock_db.commit() + + # do work + self.process_sync(tbl, src_db, dst_db) + + # done + src_db.commit() + dst_db.commit() + + def process_sync(self, tbl, src_db, dst_db): + """It gets 2 connections in state where tbl should be in same state. + """ + raise Exception('process_sync not implemented') + diff --git a/python/londiste/table_copy.py b/python/londiste/table_copy.py new file mode 100644 index 00000000..1754baaf --- /dev/null +++ b/python/londiste/table_copy.py @@ -0,0 +1,107 @@ +#! /usr/bin/env python + +"""Do a full table copy. + +For internal usage. 
+""" + +import sys, os, skytools + +from skytools.dbstruct import * +from playback import * + +__all__ = ['CopyTable'] + +class CopyTable(Replicator): + def __init__(self, args, copy_thread = 1): + Replicator.__init__(self, args) + + if copy_thread: + self.pidfile += ".copy" + self.consumer_id += "_copy" + self.copy_thread = 1 + + def init_optparse(self, parser=None): + p = Replicator.init_optparse(self, parser) + p.add_option("--skip-truncate", action="store_true", dest="skip_truncate", + help = "avoid truncate", default=False) + return p + + def do_copy(self, tbl_stat): + src_db = self.get_database('provider_db') + dst_db = self.get_database('subscriber_db') + + # it should not matter to pgq + src_db.commit() + dst_db.commit() + + # change to SERIALIZABLE isolation level + src_db.set_isolation_level(2) + src_db.commit() + + # initial sync copy + src_curs = src_db.cursor() + dst_curs = dst_db.cursor() + + self.log.info("Starting full copy of %s" % tbl_stat.name) + + # find dst struct + src_struct = TableStruct(src_curs, tbl_stat.name) + dst_struct = TableStruct(dst_curs, tbl_stat.name) + + # check if columns match + dlist = dst_struct.get_column_list() + for c in src_struct.get_column_list(): + if c not in dlist: + raise Exception('Column %s does not exist on dest side' % c) + + # drop unnecessary stuff + objs = T_CONSTRAINT | T_INDEX | T_TRIGGER | T_RULE + dst_struct.drop(dst_curs, objs, log = self.log) + + # do truncate & copy + self.real_copy(src_curs, dst_curs, tbl_stat.name) + + # get snapshot + src_curs.execute("select get_current_snapshot()") + snapshot = src_curs.fetchone()[0] + src_db.commit() + + # restore READ COMMITTED behaviour + src_db.set_isolation_level(1) + src_db.commit() + + # create previously dropped objects + dst_struct.create(dst_curs, objs, log = self.log) + + # set state + tbl_stat.change_snapshot(snapshot) + if self.copy_thread: + tbl_stat.change_state(TABLE_CATCHING_UP) + else: + tbl_stat.change_state(TABLE_OK) + self.save_table_state(dst_curs) + dst_db.commit() + + def real_copy(self, srccurs, dstcurs, tablename): + "Main copy logic." + + # drop data + if self.options.skip_truncate: + self.log.info("%s: skipping truncate" % tablename) + else: + self.log.info("%s: truncating" % tablename) + dstcurs.execute("truncate " + tablename) + + # do copy + self.log.info("%s: start copy" % tablename) + col_list = skytools.get_table_columns(srccurs, tablename) + stats = skytools.full_copy(tablename, srccurs, dstcurs, col_list) + if stats: + self.log.info("%s: copy finished: %d bytes, %d rows" % ( + tablename, stats[0], stats[1])) + +if __name__ == '__main__': + script = CopyTable(sys.argv[1:]) + script.start() + diff --git a/python/pgq/__init__.py b/python/pgq/__init__.py new file mode 100644 index 00000000..f0e9c1a6 --- /dev/null +++ b/python/pgq/__init__.py @@ -0,0 +1,6 @@ +"""PgQ framework for Python.""" + +from pgq.event import * +from pgq.consumer import * +from pgq.producer import * + diff --git a/python/pgq/consumer.py b/python/pgq/consumer.py new file mode 100644 index 00000000..bd49dccf --- /dev/null +++ b/python/pgq/consumer.py @@ -0,0 +1,410 @@ + +"""PgQ consumer framework for Python. + +API problems(?): + - process_event() and process_batch() should have db as argument. + - should ev.tag*() update db immidiately? + +""" + +import sys, time, skytools + +from pgq.event import * + +__all__ = ['Consumer', 'RemoteConsumer', 'SerialConsumer'] + +class Consumer(skytools.DBScript): + """Consumer base class. 
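+
+    Minimal subclass sketch (the job name 'my_job' and config key
+    'src_db' are made up for the example; the ini must provide
+    pgq_queue_name and the 'src_db' connect string):
+
+        import sys, pgq
+
+        class MyConsumer(pgq.Consumer):
+            def process_event(self, db, ev):
+                # handle ev.type / ev.data here, then tag the event
+                ev.tag_done()
+
+        if __name__ == '__main__':
+            script = MyConsumer('my_job', 'src_db', sys.argv[1:])
+            script.start()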
+ """ + + def __init__(self, service_name, db_name, args): + """Initialize new consumer. + + @param service_name: service_name for DBScript + @param db_name: name of database for get_database() + @param args: cmdline args for DBScript + """ + + skytools.DBScript.__init__(self, service_name, args) + + self.db_name = db_name + self.reg_list = [] + self.consumer_id = self.cf.get("pgq_consumer_id", self.job_name) + self.pgq_queue_name = self.cf.get("pgq_queue_name") + + def attach(self): + """Attach consumer to interesting queues.""" + res = self.register_consumer(self.pgq_queue_name) + return res + + def detach(self): + """Detach consumer from all queues.""" + tmp = self.reg_list[:] + for q in tmp: + self.unregister_consumer(q) + + def process_event(self, db, event): + """Process one event. + + Should be overrided by user code. + + Event should be tagged as done, retry or failed. + If not, it will be tagged as for retry. + """ + raise Exception("needs to be implemented") + + def process_batch(self, db, batch_id, event_list): + """Process all events in batch. + + By default calls process_event for each. + Can be overrided by user code. + + Events should be tagged as done, retry or failed. + If not, they will be tagged as for retry. + """ + for ev in event_list: + self.process_event(db, ev) + + def work(self): + """Do the work loop, once (internal).""" + + if len(self.reg_list) == 0: + self.log.debug("Attaching") + self.attach() + + db = self.get_database(self.db_name) + curs = db.cursor() + + data_avail = 0 + for queue in self.reg_list: + self.stat_start() + + # acquire batch + batch_id = self._load_next_batch(curs, queue) + db.commit() + if batch_id == None: + continue + data_avail = 1 + + # load events + list = self._load_batch_events(curs, batch_id, queue) + db.commit() + + # process events + self._launch_process_batch(db, batch_id, list) + + # done + self._finish_batch(curs, batch_id, list) + db.commit() + self.stat_end(len(list)) + + # if false, script sleeps + return data_avail + + def register_consumer(self, queue_name): + db = self.get_database(self.db_name) + cx = db.cursor() + cx.execute("select pgq.register_consumer(%s, %s)", + [queue_name, self.consumer_id]) + res = cx.fetchone()[0] + db.commit() + + self.reg_list.append(queue_name) + + return res + + def unregister_consumer(self, queue_name): + db = self.get_database(self.db_name) + cx = db.cursor() + cx.execute("select pgq.unregister_consumer(%s, %s)", + [queue_name, self.consumer_id]) + db.commit() + + self.reg_list.remove(queue_name) + + def _launch_process_batch(self, db, batch_id, list): + self.process_batch(db, batch_id, list) + + def _load_batch_events(self, curs, batch_id, queue_name): + """Fetch all events for this batch.""" + + # load events + sql = "select * from pgq.get_batch_events(%d)" % batch_id + curs.execute(sql) + rows = curs.dictfetchall() + + # map them to python objects + list = [] + for r in rows: + ev = Event(queue_name, r) + list.append(ev) + + return list + + def _load_next_batch(self, curs, queue_name): + """Allocate next batch. 
(internal)""" + + q = "select pgq.next_batch(%s, %s)" + curs.execute(q, [queue_name, self.consumer_id]) + return curs.fetchone()[0] + + def _finish_batch(self, curs, batch_id, list): + """Tag events and notify that the batch is done.""" + + retry = failed = 0 + for ev in list: + if ev.status == EV_FAILED: + self._tag_failed(curs, batch_id, ev) + failed += 1 + elif ev.status == EV_RETRY: + self._tag_retry(curs, batch_id, ev) + retry += 1 + curs.execute("select pgq.finish_batch(%s)", [batch_id]) + + def _tag_failed(self, curs, batch_id, ev): + """Tag event as failed. (internal)""" + curs.execute("select pgq.event_failed(%s, %s, %s)", + [batch_id, ev.id, ev.fail_reason]) + + def _tag_retry(self, cx, batch_id, ev): + """Tag event for retry. (internal)""" + cx.execute("select pgq.event_retry(%s, %s, %s)", + [batch_id, ev.id, ev.retry_time]) + + def get_batch_info(self, batch_id): + """Get info about batch. + + @return: Return value is a dict of: + + - queue_name: queue name + - consumer_name: consumers name + - batch_start: batch start time + - batch_end: batch end time + - tick_id: end tick id + - prev_tick_id: start tick id + - lag: how far is batch_end from current moment. + """ + db = self.get_database(self.db_name) + cx = db.cursor() + q = "select queue_name, consumer_name, batch_start, batch_end,"\ + " prev_tick_id, tick_id, lag"\ + " from pgq.get_batch_info(%s)" + cx.execute(q, [batch_id]) + row = cx.dictfetchone() + db.commit() + return row + + def stat_start(self): + self.stat_batch_start = time.time() + + def stat_end(self, count): + t = time.time() + self.stat_add('count', count) + self.stat_add('duration', t - self.stat_batch_start) + + +class RemoteConsumer(Consumer): + """Helper for doing event processing in another database. + + Requires that whole batch is processed in one TX. + """ + + def __init__(self, service_name, db_name, remote_db, args): + Consumer.__init__(self, service_name, db_name, args) + self.remote_db = remote_db + + def process_batch(self, db, batch_id, event_list): + """Process all events in batch. + + By default calls process_event for each. + """ + dst_db = self.get_database(self.remote_db) + curs = dst_db.cursor() + + if self.is_last_batch(curs, batch_id): + for ev in event_list: + ev.tag_done() + return + + self.process_remote_batch(db, batch_id, event_list, dst_db) + + self.set_last_batch(curs, batch_id) + dst_db.commit() + + def is_last_batch(self, dst_curs, batch_id): + """Helper function to keep track of last successful batch + in external database. + """ + q = "select pgq_ext.is_batch_done(%s, %s)" + dst_curs.execute(q, [ self.consumer_id, batch_id ]) + return dst_curs.fetchone()[0] + + def set_last_batch(self, dst_curs, batch_id): + """Helper function to set last successful batch + in external database. + """ + q = "select pgq_ext.set_batch_done(%s, %s)" + dst_curs.execute(q, [ self.consumer_id, batch_id ]) + + def process_remote_batch(self, db, batch_id, event_list, dst_db): + raise Exception('process_remote_batch not implemented') + +class SerialConsumer(Consumer): + """Consumer that applies batches sequentially in second database. + + Requirements: + - Whole batch in one TX. + - Must not use retry queue. + + Features: + - Can detect if several batches are already applied to dest db. + - If some ticks are lost. allows to seek back on queue. + Whether it succeeds, depends on pgq configuration. 
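+
+    Subclass sketch: override process_remote_batch() and apply the
+    events on the destination connection that SerialConsumer commits
+    (the payload handling below is purely illustrative):
+
+        class Applier(SerialConsumer):
+            def process_remote_batch(self, db, batch_id, ev_list, dst_db):
+                curs = dst_db.cursor()
+                for ev in ev_list:
+                    curs.execute(ev.data)    # assume events carry SQL
+                    ev.tag_done()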
+ """ + + def __init__(self, service_name, db_name, remote_db, args): + Consumer.__init__(self, service_name, db_name, args) + self.remote_db = remote_db + self.dst_completed_table = "pgq_ext.completed_tick" + self.cur_batch_info = None + + def startup(self): + if self.options.rewind: + self.rewind() + sys.exit(0) + if self.options.reset: + self.dst_reset() + sys.exit(0) + return Consumer.startup(self) + + def init_optparse(self, parser = None): + p = Consumer.init_optparse(self, parser) + p.add_option("--rewind", action = "store_true", + help = "change queue position according to destination") + p.add_option("--reset", action = "store_true", + help = "reset queue pos on destination side") + return p + + def process_batch(self, db, batch_id, event_list): + """Process all events in batch. + """ + + dst_db = self.get_database(self.remote_db) + curs = dst_db.cursor() + + self.cur_batch_info = self.get_batch_info(batch_id) + + # check if done + if self.is_batch_done(curs): + for ev in event_list: + ev.tag_done() + return + + # actual work + self.process_remote_batch(db, batch_id, event_list, dst_db) + + # make sure no retry events + for ev in event_list: + if ev.status == EV_RETRY: + raise Exception("SerialConsumer must not use retry queue") + + # finish work + self.set_batch_done(curs) + dst_db.commit() + + def is_batch_done(self, dst_curs): + """Helper function to keep track of last successful batch + in external database. + """ + + prev_tick = self.cur_batch_info['prev_tick_id'] + + q = "select last_tick_id from %s where consumer_id = %%s" % ( + self.dst_completed_table ,) + dst_curs.execute(q, [self.consumer_id]) + res = dst_curs.fetchone() + + if not res or not res[0]: + # seems this consumer has not run yet against dst_db + return False + dst_tick = res[0] + + if prev_tick == dst_tick: + # on track + return False + + if prev_tick < dst_tick: + self.log.warning('Got tick %d, dst has %d - skipping' % (prev_tick, dst_tick)) + return True + else: + self.log.error('Got tick %d, dst has %d - ticks lost' % (prev_tick, dst_tick)) + raise Exception('Lost ticks') + + def set_batch_done(self, dst_curs): + """Helper function to set last successful batch + in external database. 
+ """ + tick_id = self.cur_batch_info['tick_id'] + q = "delete from %s where consumer_id = %%s; "\ + "insert into %s (consumer_id, last_tick_id) values (%%s, %%s)" % ( + self.dst_completed_table, + self.dst_completed_table) + dst_curs.execute(q, [ self.consumer_id, + self.consumer_id, tick_id ]) + + def attach(self): + new = Consumer.attach(self) + if new: + self.clean_completed_tick() + + def detach(self): + """If detaching, also clean completed tick table on dest.""" + + Consumer.detach(self) + self.clean_completed_tick() + + def clean_completed_tick(self): + self.log.info("removing completed tick from dst") + dst_db = self.get_database(self.remote_db) + dst_curs = dst_db.cursor() + + q = "delete from %s where consumer_id = %%s" % ( + self.dst_completed_table,) + dst_curs.execute(q, [self.consumer_id]) + dst_db.commit() + + def process_remote_batch(self, db, batch_id, event_list, dst_db): + raise Exception('process_remote_batch not implemented') + + def rewind(self): + self.log.info("Rewinding queue") + src_db = self.get_database(self.db_name) + dst_db = self.get_database(self.remote_db) + src_curs = src_db.cursor() + dst_curs = dst_db.cursor() + + q = "select last_tick_id from %s where consumer_id = %%s" % ( + self.dst_completed_table,) + dst_curs.execute(q, [self.consumer_id]) + row = dst_curs.fetchone() + if row: + dst_tick = row[0] + q = "select pgq.register_consumer(%s, %s, %s)" + src_curs.execute(q, [self.pgq_queue_name, self.consumer_id, dst_tick]) + else: + self.log.warning('No tick found on dst side') + + dst_db.commit() + src_db.commit() + + def dst_reset(self): + self.log.info("Resetting queue tracking on dst side") + dst_db = self.get_database(self.remote_db) + dst_curs = dst_db.cursor() + + q = "delete from %s where consumer_id = %%s" % ( + self.dst_completed_table,) + dst_curs.execute(q, [self.consumer_id]) + dst_db.commit() + + diff --git a/python/pgq/event.py b/python/pgq/event.py new file mode 100644 index 00000000..d7b2d7ee --- /dev/null +++ b/python/pgq/event.py @@ -0,0 +1,60 @@ + +"""PgQ event container. +""" + +__all__ = ('EV_RETRY', 'EV_DONE', 'EV_FAILED', 'Event') + +# Event status codes +EV_RETRY = 0 +EV_DONE = 1 +EV_FAILED = 2 + +_fldmap = { + 'ev_id': 'ev_id', + 'ev_txid': 'ev_txid', + 'ev_time': 'ev_time', + 'ev_type': 'ev_type', + 'ev_data': 'ev_data', + 'ev_extra1': 'ev_extra1', + 'ev_extra2': 'ev_extra2', + 'ev_extra3': 'ev_extra3', + 'ev_extra4': 'ev_extra4', + + 'id': 'ev_id', + 'txid': 'ev_txid', + 'time': 'ev_time', + 'type': 'ev_type', + 'data': 'ev_data', + 'extra1': 'ev_extra1', + 'extra2': 'ev_extra2', + 'extra3': 'ev_extra3', + 'extra4': 'ev_extra4', +} + +class Event(object): + """Event data for consumers. + + Consumer is supposed to tag them after processing. + If not, events will stay in retry queue. 
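+
+    Tagging sketch, as it would look inside a consumer's
+    process_event() (the condition names are made up):
+
+        if handled_ok:
+            ev.tag_done()
+        elif worth_retry:
+            ev.tag_retry(120)            # retry delay, default is 60
+        else:
+            ev.tag_failed('bad payload')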
+ """ + def __init__(self, queue_name, row): + self._event_row = row + self.status = EV_RETRY + self.retry_time = 60 + self.fail_reason = "Buggy consumer" + self.queue_name = queue_name + + def __getattr__(self, key): + return self._event_row[_fldmap[key]] + + def tag_done(self): + self.status = EV_DONE + + def tag_retry(self, retry_time = 60): + self.status = EV_RETRY + self.retry_time = retry_time + + def tag_failed(self, reason): + self.status = EV_FAILED + self.fail_reason = reason + diff --git a/python/pgq/maint.py b/python/pgq/maint.py new file mode 100644 index 00000000..4636f74f --- /dev/null +++ b/python/pgq/maint.py @@ -0,0 +1,99 @@ +"""PgQ maintenance functions.""" + +import skytools, time + +def get_pgq_api_version(curs): + q = "select count(1) from pg_proc p, pg_namespace n"\ + " where n.oid = p.pronamespace and n.nspname='pgq'"\ + " and p.proname='version';" + curs.execute(q) + if not curs.fetchone()[0]: + return '1.0.0' + + curs.execute("select pgq.version()") + return curs.fetchone()[0] + +def version_ge(curs, want_ver): + """Check is db version of pgq is greater than want_ver.""" + db_ver = get_pgq_api_version(curs) + want_tuple = map(int, want_ver.split('.')) + db_tuple = map(int, db_ver.split('.')) + if db_tuple[0] != want_tuple[0]: + raise Exception('Wrong major version') + if db_tuple[1] >= want_tuple[1]: + return 1 + return 0 + +class MaintenanceJob(skytools.DBScript): + """Periodic maintenance.""" + def __init__(self, ticker, args): + skytools.DBScript.__init__(self, 'pgqadm', args) + self.ticker = ticker + self.last_time = 0 # start immidiately + self.last_ticks = 0 + self.clean_ticks = 1 + self.maint_delay = 5*60 + + def startup(self): + # disable regular DBScript startup() + pass + + def reload(self): + skytools.DBScript.reload(self) + + # force loop_delay + self.loop_delay = 5 + + self.maint_delay = 60 * self.cf.getfloat('maint_delay_min', 5) + self.maint_delay = self.cf.getfloat('maint_delay', self.maint_delay) + + def work(self): + t = time.time() + if self.last_time + self.maint_delay > t: + return + + self.do_maintenance() + + self.last_time = t + duration = time.time() - t + self.stat_add('maint_duration', duration) + + def do_maintenance(self): + """Helper function for running maintenance.""" + + db = self.get_database('db', autocommit=1) + cx = db.cursor() + + if skytools.exists_function(cx, "pgq.maint_rotate_tables_step1", 1): + # rotate each queue in own TX + q = "select queue_name from pgq.get_queue_info()" + cx.execute(q) + for row in cx.fetchall(): + cx.execute("select pgq.maint_rotate_tables_step1(%s)", [row[0]]) + res = cx.fetchone()[0] + if res: + self.log.info('Rotating %s' % row[0]) + else: + cx.execute("select pgq.maint_rotate_tables_step1();") + + # finish rotation + cx.execute("select pgq.maint_rotate_tables_step2();") + + # move retry events to main queue in small blocks + rcount = 0 + while 1: + cx.execute('select pgq.maint_retry_events();') + res = cx.fetchone()[0] + rcount += res + if res == 0: + break + if rcount: + self.log.info('Got %d events for retry' % rcount) + + # vacuum tables that are needed + cx.execute('set maintenance_work_mem = 32768') + cx.execute('select * from pgq.maint_tables_to_vacuum()') + for row in cx.fetchall(): + cx.execute('vacuum %s;' % row[0]) + + diff --git a/python/pgq/producer.py b/python/pgq/producer.py new file mode 100644 index 00000000..81e1ca4f --- /dev/null +++ b/python/pgq/producer.py @@ -0,0 +1,41 @@ + +"""PgQ producer helpers for Python. 
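+
+Sketch of pushing one event (queue name, event type and data are made
+up for the example; curs is an open cursor on the queue database):
+
+    ev_id = insert_event(curs, 'user_events', 'user.created', 'id=42')
+
+For bigger batches bulk_insert_events() writes rows directly into the
+current event table via skytools.magic_insert().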
+""" + +import skytools + +_fldmap = { + 'id': 'ev_id', + 'time': 'ev_time', + 'type': 'ev_type', + 'data': 'ev_data', + 'extra1': 'ev_extra1', + 'extra2': 'ev_extra2', + 'extra3': 'ev_extra3', + 'extra4': 'ev_extra4', + + 'ev_id': 'ev_id', + 'ev_time': 'ev_time', + 'ev_type': 'ev_type', + 'ev_data': 'ev_data', + 'ev_extra1': 'ev_extra1', + 'ev_extra2': 'ev_extra2', + 'ev_extra3': 'ev_extra3', + 'ev_extra4': 'ev_extra4', +} + +def bulk_insert_events(curs, rows, fields, queue_name): + q = "select pgq.current_event_table(%s)" + curs.execute(q, [queue_name]) + tbl = curs.fetchone()[0] + db_fields = map(_fldmap.get, fields) + skytools.magic_insert(curs, tbl, rows, db_fields) + +def insert_event(curs, queue, ev_type, ev_data, + extra1=None, extra2=None, + extra3=None, extra4=None): + q = "select pgq.insert_event(%s, %s, %s, %s, %s, %s, %s)" + curs.execute(q, [queue, ev_type, ev_data, + extra1, extra2, extra3, extra4]) + return curs.fetchone()[0] + diff --git a/python/pgq/status.py b/python/pgq/status.py new file mode 100644 index 00000000..2214045f --- /dev/null +++ b/python/pgq/status.py @@ -0,0 +1,93 @@ + +"""Status display. +""" + +import sys, os, skytools + +def ival(data, as = None): + "Format interval for output" + if not as: + as = data.split('.')[-1] + numfmt = 'FM9999999' + expr = "coalesce(to_char(extract(epoch from %s), '%s') || 's', 'NULL') as %s" + return expr % (data, numfmt, as) + +class PGQStatus(skytools.DBScript): + def __init__(self, args, check = 0): + skytools.DBScript.__init__(self, 'pgqadm', args) + + self.show_status() + + sys.exit(0) + + def show_status(self): + db = self.get_database("db", autocommit=1) + cx = db.cursor() + + cx.execute("show server_version") + pgver = cx.fetchone()[0] + cx.execute("select pgq.version()") + qver = cx.fetchone()[0] + print "Postgres version: %s PgQ version: %s" % (pgver, qver) + + q = """select f.queue_name, f.num_tables, %s, %s, %s, + q.queue_ticker_max_lag, q.queue_ticker_max_amount, + q.queue_ticker_idle_interval + from pgq.get_queue_info() f, pgq.queue q + where q.queue_name = f.queue_name""" % ( + ival('f.rotation_delay'), + ival('f.ticker_lag'), + ) + cx.execute(q) + event_rows = cx.dictfetchall() + + q = """select queue_name, consumer_name, %s, %s, %s + from pgq.get_consumer_info()""" % ( + ival('lag'), + ival('last_seen'), + ) + cx.execute(q) + consumer_rows = cx.dictfetchall() + + print "\n%-32s %s %9s %13s %6s" % ('Event queue', + 'Rotation', 'Ticker', 'TLag') + print '-' * 78 + for ev_row in event_rows: + tck = "%s/%ss/%ss" % (ev_row['queue_ticker_max_amount'], + ev_row['queue_ticker_max_lag'], + ev_row['queue_ticker_idle_interval']) + rot = "%s/%s" % (ev_row['queue_ntables'], ev_row['queue_rotation_period']) + print "%-39s%7s %9s %13s %6s" % ( + ev_row['queue_name'], + rot, + tck, + ev_row['ticker_lag'], + ) + print '-' * 78 + print "\n%-42s %9s %9s" % ( + 'Consumer', 'Lag', 'LastSeen') + print '-' * 78 + for ev_row in event_rows: + cons = self.pick_consumers(ev_row, consumer_rows) + self.show_queue(ev_row, cons) + print '-' * 78 + db.commit() + + def show_consumer(self, cons): + print " %-48s %9s %9s" % ( + cons['consumer_name'], + cons['lag'], cons['last_seen']) + def show_queue(self, ev_row, consumer_rows): + print "%(queue_name)s:" % ev_row + for cons in consumer_rows: + self.show_consumer(cons) + + + def pick_consumers(self, ev_row, consumer_rows): + res = [] + for con in consumer_rows: + if con['queue_name'] != ev_row['queue_name']: + continue + res.append(con) + return res + diff --git a/python/pgq/ticker.py 
b/python/pgq/ticker.py new file mode 100644 index 00000000..c218eaf1 --- /dev/null +++ b/python/pgq/ticker.py @@ -0,0 +1,172 @@ +"""PgQ ticker. + +It will also launch maintenance job. +""" + +import sys, os, time, threading +import skytools + +from maint import MaintenanceJob + +__all__ = ['SmartTicker'] + +def is_txid_sane(curs): + curs.execute("select get_current_txid()") + txid = curs.fetchone()[0] + + # on 8.2 theres no such table + if not skytools.exists_table(curs, 'txid.epoch'): + return 1 + + curs.execute("select epoch, last_value from txid.epoch") + epoch, last_val = curs.fetchone() + stored_val = (epoch << 32) | last_val + + if stored_val <= txid: + return 1 + else: + return 0 + +class QueueStatus(object): + def __init__(self, name): + self.queue_name = name + self.seq_name = None + self.idle_period = 60 + self.max_lag = 3 + self.max_count = 200 + self.last_tick_time = 0 + self.last_count = 0 + self.quiet_count = 0 + + def set_data(self, row): + self.seq_name = row['queue_event_seq'] + self.idle_period = row['queue_ticker_idle_period'] + self.max_lag = row['queue_ticker_max_lag'] + self.max_count = row['queue_ticker_max_count'] + + def need_tick(self, cur_count, cur_time): + # check if tick is needed + need_tick = 0 + lag = cur_time - self.last_tick_time + + if cur_count == self.last_count: + # totally idle database + + # don't go immidiately to big delays, as seq grows before commit + if self.quiet_count < 5: + if lag >= self.max_lag: + need_tick = 1 + self.quiet_count += 1 + else: + if lag >= self.idle_period: + need_tick = 1 + else: + self.quiet_count = 0 + # somewhat loaded machine + if cur_count - self.last_count >= self.max_count: + need_tick = 1 + elif lag >= self.max_lag: + need_tick = 1 + if need_tick: + self.last_tick_time = cur_time + self.last_count = cur_count + return need_tick + +class SmartTicker(skytools.DBScript): + last_tick_event = 0 + last_tick_time = 0 + quiet_count = 0 + tick_count = 0 + maint_thread = None + + def __init__(self, args): + skytools.DBScript.__init__(self, 'pgqadm', args) + + self.ticker_log_time = 0 + self.ticker_log_delay = 5*60 + self.queue_map = {} + self.refresh_time = 0 + + def reload(self): + skytools.DBScript.reload(self) + self.ticker_log_delay = self.cf.getfloat("ticker_log_delay", 5*60) + + def startup(self): + if self.maint_thread: + return + + db = self.get_database("db", autocommit = 1) + cx = db.cursor() + ok = is_txid_sane(cx) + if not ok: + self.log.error('txid in bad state') + sys.exit(1) + + self.maint_thread = MaintenanceJob(self, [self.cf.filename]) + t = threading.Thread(name = 'maint_thread', + target = self.maint_thread.run) + t.setDaemon(1) + t.start() + + def refresh_queues(self, cx): + q = "select queue_name, queue_event_seq, queue_ticker_idle_period,"\ + " queue_ticker_max_lag, queue_ticker_max_count"\ + " from pgq.queue"\ + " where not queue_external_ticker" + cx.execute(q) + new_map = {} + data_list = [] + from_list = [] + for row in cx.dictfetchall(): + queue_name = row['queue_name'] + try: + que = self.queue_map[queue_name] + except KeyError, x: + que = QueueStatus(queue_name) + que.set_data(row) + new_map[queue_name] = que + + p1 = "'%s', %s.last_value" % (queue_name, que.seq_name) + data_list.append(p1) + from_list.append(que.seq_name) + + self.queue_map = new_map + self.seq_query = "select %s from %s" % ( + ", ".join(data_list), + ", ".join(from_list)) + + if len(from_list) == 0: + self.seq_query = None + + self.refresh_time = time.time() + + def work(self): + db = self.get_database("db", autocommit = 1) + 
cx = db.cursor() + + cur_time = time.time() + + if cur_time >= self.refresh_time + 30: + self.refresh_queues(cx) + + if not self.seq_query: + return + + # now check seqs + cx.execute(self.seq_query) + res = cx.fetchone() + pos = 0 + while pos < len(res): + id = res[pos] + val = res[pos + 1] + pos += 2 + que = self.queue_map[id] + if que.need_tick(val, cur_time): + cx.execute("select pgq.ticker(%s)", [que.queue_name]) + self.tick_count += 1 + + if cur_time > self.ticker_log_time + self.ticker_log_delay: + self.ticker_log_time = cur_time + self.stat_add('ticks', self.tick_count) + self.tick_count = 0 + diff --git a/python/pgqadm.py b/python/pgqadm.py new file mode 100755 index 00000000..78f513dc --- /dev/null +++ b/python/pgqadm.py @@ -0,0 +1,162 @@ +#! /usr/bin/env python + +"""PgQ ticker and maintenance. +""" + +import sys +import skytools + +from pgq.ticker import SmartTicker +from pgq.status import PGQStatus +#from pgq.admin import PGQAdmin + +"""TODO: +pgqadm ini check +""" + +command_usage = """ +%prog [options] INI CMD [subcmd args] + +commands: + ticker start ticking & maintenance process + + status show overview of queue health + check show problematic consumers + + install install code into db + create QNAME create queue + drop QNAME drop queue + register QNAME CONS install code into db + unregister QNAME CONS install code into db + config QNAME [VAR=VAL] show or change queue config +""" + +config_allowed_list = [ + 'queue_ticker_max_lag', 'queue_ticker_max_amount', + 'queue_ticker_idle_interval', 'queue_rotation_period'] + +class PGQAdmin(skytools.DBScript): + def __init__(self, args): + skytools.DBScript.__init__(self, 'pgqadm', args) + self.set_single_loop(1) + + if len(self.args) < 2: + print "need command" + sys.exit(1) + + int_cmds = { + 'create': self.create_queue, + 'drop': self.drop_queue, + 'register': self.register, + 'unregister': self.unregister, + 'install': self.installer, + 'config': self.change_config, + } + + cmd = self.args[1] + if cmd == "ticker": + script = SmartTicker(args) + elif cmd == "status": + script = PGQStatus(args) + elif cmd == "check": + script = PGQStatus(args, check = 1) + elif cmd in int_cmds: + script = None + self.work = int_cmds[cmd] + else: + print "unknown command" + sys.exit(1) + + if self.pidfile: + self.pidfile += ".admin" + self.run_script = script + + def start(self): + if self.run_script: + self.run_script.start() + else: + skytools.DBScript.start(self) + + def init_optparse(self, parser=None): + p = skytools.DBScript.init_optparse(self, parser) + p.set_usage(command_usage.strip()) + return p + + def installer(self): + objs = [ + skytools.DBLanguage("plpgsql"), + skytools.DBLanguage("plpythonu"), + skytools.DBFunction("get_current_txid", 0, sql_file="txid.sql"), + skytools.DBSchema("pgq", sql_file="pgq.sql"), + ] + + db = self.get_database('db') + curs = db.cursor() + skytools.db_install(curs, objs, self.log) + db.commit() + + def create_queue(self): + qname = self.args[2] + self.log.info('Creating queue: %s' % qname) + self.exec_sql("select pgq.create_queue(%s)", [qname]) + + def drop_queue(self): + qname = self.args[2] + self.log.info('Dropping queue: %s' % qname) + self.exec_sql("select pgq.drop_queue(%s)", [qname]) + + def register(self): + qname = self.args[2] + cons = self.args[3] + self.log.info('Registering consumer %s on queue %s' % (cons, qname)) + self.exec_sql("select pgq.register_consumer(%s, %s)", [qname, cons]) + + def unregister(self): + qname = self.args[2] + cons = self.args[3] + self.log.info('Unregistering consumer 
%s from queue %s' % (cons, qname)) + self.exec_sql("select pgq.unregister_consumer(%s, %s)", [qname, cons]) + + def change_config(self): + qname = self.args[2] + if len(self.args) == 3: + self.show_config(qname) + return + alist = [] + for el in self.args[3:]: + k, v = el.split('=') + if k not in config_allowed_list: + raise Exception('unknown config var: '+k) + expr = "%s=%s" % (k, skytools.quote_literal(v)) + alist.append(expr) + self.log.info('Change queue %s config to: %s' % (qname, ", ".join(alist))) + sql = "update pgq.queue set %s where queue_name = %s" % ( + ", ".join(alist), skytools.quote_literal(qname)) + self.exec_sql(sql, []) + + def exec_sql(self, q, args): + self.log.debug(q) + db = self.get_database('db') + curs = db.cursor() + curs.execute(q, args) + db.commit() + + def show_config(self, qname): + klist = ",".join(config_allowed_list) + q = "select * from pgq.queue where queue_name = %s" + db = self.get_database('db') + curs = db.cursor() + curs.execute(q, [qname]) + res = curs.dictfetchone() + db.commit() + + print qname + for k in config_allowed_list: + print " %s=%s" % (k, res[k]) + +if __name__ == '__main__': + script = PGQAdmin(sys.argv[1:]) + script.start() + + + diff --git a/python/skytools/__init__.py b/python/skytools/__init__.py new file mode 100644 index 00000000..ed2b39bc --- /dev/null +++ b/python/skytools/__init__.py @@ -0,0 +1,10 @@ + +"""Tools for Python database scripts.""" + +from config import * +from dbstruct import * +from gzlog import * +from quoting import * +from scripting import * +from sqltools import * + diff --git a/python/skytools/config.py b/python/skytools/config.py new file mode 100644 index 00000000..de420322 --- /dev/null +++ b/python/skytools/config.py @@ -0,0 +1,139 @@ + +"""Nicer config class.""" + +import sys, os, ConfigParser, socket + +__all__ = ['Config'] + +class Config(object): + """Bit improved ConfigParser. + + Additional features: + - Remembers section. + - Acceps defaults in get() functions. + - List value support. + """ + def __init__(self, main_section, filename, sane_config = 1): + """Initialize Config and read from file. + + @param sane_config: chooses between ConfigParser/SafeConfigParser. 
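+
+        Usage sketch (section and file names follow the scripts in
+        this package but are otherwise arbitrary):
+
+            cf = Config('londiste', 'conf/londiste.ini')
+            queue = cf.get('pgq_queue_name')
+            delay = cf.getfloat('loop_delay', 1.0)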
+ """ + defs = { + 'job_name': main_section, + 'service_name': main_section, + 'host_name': socket.gethostname(), + } + if not os.path.isfile(filename): + raise Exception('Config file not found: '+filename) + + self.filename = filename + self.sane_config = sane_config + if sane_config: + self.cf = ConfigParser.SafeConfigParser(defs) + else: + self.cf = ConfigParser.ConfigParser(defs) + self.cf.read(filename) + self.main_section = main_section + if not self.cf.has_section(main_section): + raise Exception("Wrong config file, no section '%s'"%main_section) + + def reload(self): + """Re-reads config file.""" + self.cf.read(self.filename) + + def get(self, key, default=None): + """Reads string value, if not set then default.""" + try: + return self.cf.get(self.main_section, key) + except ConfigParser.NoOptionError, det: + if default == None: + raise Exception("Config value not set: " + key) + return default + + def getint(self, key, default=None): + """Reads int value, if not set then default.""" + try: + return self.cf.getint(self.main_section, key) + except ConfigParser.NoOptionError, det: + if default == None: + raise Exception("Config value not set: " + key) + return default + + def getboolean(self, key, default=None): + """Reads boolean value, if not set then default.""" + try: + return self.cf.getboolean(self.main_section, key) + except ConfigParser.NoOptionError, det: + if default == None: + raise Exception("Config value not set: " + key) + return default + + def getfloat(self, key, default=None): + """Reads float value, if not set then default.""" + try: + return self.cf.getfloat(self.main_section, key) + except ConfigParser.NoOptionError, det: + if default == None: + raise Exception("Config value not set: " + key) + return default + + def getlist(self, key, default=None): + """Reads comma-separated list from key.""" + try: + s = self.cf.get(self.main_section, key).strip() + res = [] + if not s: + return res + for v in s.split(","): + res.append(v.strip()) + return res + except ConfigParser.NoOptionError, det: + if default == None: + raise Exception("Config value not set: " + key) + return default + + def getfile(self, key, default=None): + """Reads filename from config. + + In addition to reading string value, expands ~ to user directory. 
+ """ + fn = self.get(key, default) + if fn == "" or fn == "-": + return fn + # simulate that the cwd is script location + #path = os.path.dirname(sys.argv[0]) + # seems bad idea, cwd should be cwd + + fn = os.path.expanduser(fn) + + return fn + + def get_wildcard(self, key, values=[], default=None): + """Reads a wildcard property from conf and returns its string value, if not set then default.""" + + orig_key = key + keys = [key] + + for wild in values: + key = key.replace('*', wild, 1) + keys.append(key) + keys.reverse() + + for key in keys: + try: + return self.cf.get(self.main_section, key) + except ConfigParser.NoOptionError, det: + pass + + if default == None: + raise Exception("Config value not set: " + orig_key) + return default + + def sections(self): + """Returns list of sections in config file, excluding DEFAULT.""" + return self.cf.sections() + + def clone(self, main_section): + """Return new Config() instance with new main section on same config file.""" + return Config(main_section, self.filename, self.sane_config) + diff --git a/python/skytools/dbstruct.py b/python/skytools/dbstruct.py new file mode 100644 index 00000000..22333429 --- /dev/null +++ b/python/skytools/dbstruct.py @@ -0,0 +1,380 @@ +"""Find table structure and allow CREATE/DROP elements from it. +""" + +import sys, re + +from sqltools import fq_name_parts, get_table_oid + +__all__ = ['TableStruct', + 'T_TABLE', 'T_CONSTRAINT', 'T_INDEX', 'T_TRIGGER', + 'T_RULE', 'T_GRANT', 'T_OWNER', 'T_PKEY', 'T_ALL'] + +T_TABLE = 1 << 0 +T_CONSTRAINT = 1 << 1 +T_INDEX = 1 << 2 +T_TRIGGER = 1 << 3 +T_RULE = 1 << 4 +T_GRANT = 1 << 5 +T_OWNER = 1 << 6 +T_PKEY = 1 << 20 # special, one of constraints +T_ALL = ( T_TABLE | T_CONSTRAINT | T_INDEX + | T_TRIGGER | T_RULE | T_GRANT | T_OWNER ) + +# +# Utility functions +# + +def find_new_name(curs, name): + """Create new object name for case the old exists. + + Needed when creating a new table besides old one. 
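+
+    For example, if 'data1_pkey' is taken, any old numeric suffix is
+    stripped first and 'data1_pkey_1', 'data1_pkey_2', ... are tried
+    until a free name is found.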
+ """ + # cut off previous numbers + m = re.search('_[0-9]+$', name) + if m: + name = name[:m.start()] + + # now loop + for i in range(1, 1000): + tname = "%s_%d" % (name, i) + q = "select count(1) from pg_class where relname = %s" + curs.execute(q, [tname]) + if curs.fetchone()[0] == 0: + return tname + + # failed + raise Exception('find_new_name failed') + +def rx_replace(rx, sql, new_part): + """Find a regex match and replace that part with new_part.""" + m = re.search(rx, sql, re.I) + if not m: + raise Exception('rx_replace failed') + p1 = sql[:m.start()] + p2 = sql[m.end():] + return p1 + new_part + p2 + +# +# Schema objects +# + +class TElem(object): + """Keeps info about one metadata object.""" + SQL = "" + type = 0 + def get_create_sql(self, curs): + """Return SQL statement for creating or None if not supported.""" + return None + def get_drop_sql(self, curs): + """Return SQL statement for dropping or None of not supported.""" + return None + +class TConstraint(TElem): + """Info about constraint.""" + type = T_CONSTRAINT + SQL = """ + SELECT conname as name, pg_get_constraintdef(oid) as def, contype + FROM pg_constraint WHERE conrelid = %(oid)s + """ + def __init__(self, table_name, row): + self.table_name = table_name + self.name = row['name'] + self.defn = row['def'] + self.contype = row['contype'] + + # tag pkeys + if self.contype == 'p': + self.type += T_PKEY + + def get_create_sql(self, curs, new_table_name=None): + fmt = "ALTER TABLE ONLY %s ADD CONSTRAINT %s %s;" + if new_table_name: + name = self.name + if self.contype in ('p', 'u'): + name = find_new_name(curs, self.name) + sql = fmt % (new_table_name, name, self.defn) + else: + sql = fmt % (self.table_name, self.name, self.defn) + return sql + + def get_drop_sql(self, curs): + fmt = "ALTER TABLE ONLY %s DROP CONSTRAINT %s;" + sql = fmt % (self.table_name, self.name) + return sql + +class TIndex(TElem): + """Info about index.""" + type = T_INDEX + SQL = """ + SELECT n.nspname || '.' 
|| c.relname as name, + pg_get_indexdef(i.indexrelid) as defn + FROM pg_index i, pg_class c, pg_namespace n + WHERE c.oid = i.indexrelid AND i.indrelid = %(oid)s + AND n.oid = c.relnamespace + AND NOT EXISTS + (select objid from pg_depend + where classid = %(pg_class_oid)s + and objid = c.oid + and deptype = 'i') + """ + def __init__(self, table_name, row): + self.name = row['name'] + self.defn = row['defn'] + ';' + + def get_create_sql(self, curs, new_table_name = None): + if not new_table_name: + return self.defn + name = find_new_name(curs, self.name) + pnew = "INDEX %s ON %s " % (name, new_table_name) + rx = r"\bINDEX[ ][a-z0-9._]+[ ]ON[ ][a-z0-9._]+[ ]" + sql = rx_replace(rx, self.defn, pnew) + return sql + def get_drop_sql(self, curs): + return 'DROP INDEX %s;' % self.name + +class TRule(TElem): + """Info about rule.""" + type = T_RULE + SQL = """ + SELECT rulename as name, pg_get_ruledef(oid) as def + FROM pg_rewrite + WHERE ev_class = %(oid)s AND rulename <> '_RETURN'::name + """ + def __init__(self, table_name, row, new_name = None): + self.table_name = table_name + self.name = row['name'] + self.defn = row['def'] + + def get_create_sql(self, curs, new_table_name = None): + if not new_table_name: + return self.defn + rx = r"\bTO[ ][a-z0-9._]+[ ]DO[ ]" + pnew = "TO %s DO " % new_table_name + return rx_replace(rx, self.defn, pnew) + + def get_drop_sql(self, curs): + return 'DROP RULE %s ON %s' % (self.name, self.table_name) + +class TTrigger(TElem): + """Info about trigger.""" + type = T_TRIGGER + SQL = """ + SELECT tgname as name, pg_get_triggerdef(oid) as def + FROM pg_trigger + WHERE tgrelid = %(oid)s AND NOT tgisconstraint + """ + def __init__(self, table_name, row): + self.table_name = table_name + self.name = row['name'] + self.defn = row['def'] + ';' + + def get_create_sql(self, curs, new_table_name = None): + if not new_table_name: + return self.defn + + rx = r"\bON[ ][a-z0-9._]+[ ]" + pnew = "ON %s " % new_table_name + return rx_replace(rx, self.defn, pnew) + + def get_drop_sql(self, curs): + return 'DROP TRIGGER %s ON %s' % (self.name, self.table_name) + +class TOwner(TElem): + """Info about table owner.""" + type = T_OWNER + SQL = """ + SELECT pg_get_userbyid(relowner) as owner FROM pg_class + WHERE oid = %(oid)s + """ + def __init__(self, table_name, row, new_name = None): + self.table_name = table_name + self.name = 'Owner' + self.owner = row['owner'] + + def get_create_sql(self, curs, new_name = None): + if not new_name: + new_name = self.table_name + return 'ALTER TABLE %s OWNER TO %s;' % (new_name, self.owner) + +class TGrant(TElem): + """Info about permissions.""" + type = T_GRANT + SQL = "SELECT relacl FROM pg_class where oid = %(oid)s" + acl_map = { + 'r': 'SELECT', 'w': 'UPDATE', 'a': 'INSERT', 'd': 'DELETE', + 'R': 'RULE', 'x': 'REFERENCES', 't': 'TRIGGER', 'X': 'EXECUTE', + 'U': 'USAGE', 'C': 'CREATE', 'T': 'TEMPORARY' + } + def acl_to_grants(self, acl): + if acl == "arwdRxt": # ALL for tables + return "ALL" + return ", ".join([ self.acl_map[c] for c in acl ]) + + def parse_relacl(self, relacl): + if relacl is None: + return [] + if len(relacl) > 0 and relacl[0] == '{' and relacl[-1] == '}': + relacl = relacl[1:-1] + list = [] + for f in relacl.split(','): + user, tmp = f.strip('"').split('=') + acl, who = tmp.split('/') + list.append((user, acl, who)) + return list + + def __init__(self, table_name, row, new_name = None): + self.name = table_name + self.acl_list = self.parse_relacl(row['relacl']) + + def get_create_sql(self, curs, new_name = None): + if not 
new_name: + new_name = self.name + + list = [] + for user, acl, who in self.acl_list: + astr = self.acl_to_grants(acl) + sql = "GRANT %s ON %s TO %s;" % (astr, new_name, user) + list.append(sql) + return "\n".join(list) + + def get_drop_sql(self, curs): + list = [] + for user, acl, who in self.acl_list: + sql = "REVOKE ALL FROM %s ON %s;" % (user, self.name) + list.append(sql) + return "\n".join(list) + +class TColumn(TElem): + """Info about table column.""" + SQL = """ + select a.attname as name, + a.attname || ' ' + || format_type(a.atttypid, a.atttypmod) + || case when a.attnotnull then ' not null' else '' end + || case when a.atthasdef then ' ' || d.adsrc else '' end + as def + from pg_attribute a left join pg_attrdef d + on (d.adrelid = a.attrelid and d.adnum = a.attnum) + where a.attrelid = %(oid)s + and not a.attisdropped + and a.attnum > 0 + order by a.attnum; + """ + def __init__(self, table_name, row): + self.name = row['name'] + self.column_def = row['def'] + +class TTable(TElem): + """Info about table only (columns).""" + type = T_TABLE + def __init__(self, table_name, col_list): + self.name = table_name + self.col_list = col_list + + def get_create_sql(self, curs, new_name = None): + if not new_name: + new_name = self.name + sql = "create table %s (" % new_name + sep = "\n\t" + for c in self.col_list: + sql += sep + c.column_def + sep = ",\n\t" + sql += "\n);" + return sql + + def get_drop_sql(self, curs): + return "DROP TABLE %s;" % self.name + +# +# Main table object, loads all the others +# + +class TableStruct(object): + """Collects and manages all info about table. + + Allow to issue CREATE/DROP statements about any + group of elements. + """ + def __init__(self, curs, table_name): + """Initializes class by loading info about table_name from database.""" + + self.table_name = table_name + + # fill args + schema, name = fq_name_parts(table_name) + args = { + 'schema': schema, + 'table': name, + 'oid': get_table_oid(curs, table_name), + 'pg_class_oid': get_table_oid(curs, 'pg_catalog.pg_class'), + } + + # load table struct + self.col_list = self._load_elem(curs, args, TColumn) + self.object_list = [ TTable(table_name, self.col_list) ] + + # load additional objects + to_load = [TConstraint, TIndex, TTrigger, TRule, TGrant, TOwner] + for eclass in to_load: + self.object_list += self._load_elem(curs, args, eclass) + + def _load_elem(self, curs, args, eclass): + list = [] + curs.execute(eclass.SQL % args) + for row in curs.dictfetchall(): + list.append(eclass(self.table_name, row)) + return list + + def create(self, curs, objs, new_table_name = None, log = None): + """Issues CREATE statements for requested set of objects. + + If new_table_name is giver, creates table under that name + and also tries to rename all indexes/constraints that conflict + with existing table. 
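+
+        Usage sketch (the table name is made up), mirroring how the
+        copy code drops secondary objects before bulk load and
+        recreates them afterwards:
+
+            s = TableStruct(curs, 'public.mytable')
+            objs = T_CONSTRAINT | T_INDEX | T_TRIGGER | T_RULE
+            s.drop(curs, objs)
+            # ... bulk-load data ...
+            s.create(curs, objs)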
+ """ + + for o in self.object_list: + if o.type & objs: + sql = o.get_create_sql(curs, new_table_name) + if not sql: + continue + if log: + log.info('Creating %s' % o.name) + log.debug(sql) + curs.execute(sql) + + def drop(self, curs, objs, log = None): + """Issues DROP statements for requested set of objects.""" + for o in self.object_list: + if o.type & objs: + sql = o.get_drop_sql(curs) + if not sql: + continue + if log: + log.info('Dropping %s' % o.name) + log.debug(sql) + curs.execute(sql) + + def get_column_list(self): + """Returns list of column names the table has.""" + + res = [] + for c in self.col_list: + res.append(c.name) + return res + +def test(): + import psycopg + db = psycopg.connect("dbname=fooz") + curs = db.cursor() + + s = TableStruct(curs, "public.data1") + + s.drop(curs, T_ALL) + s.create(curs, T_ALL) + s.create(curs, T_ALL, "data1_new") + s.create(curs, T_PKEY) + +if __name__ == '__main__': + test() + diff --git a/python/skytools/gzlog.py b/python/skytools/gzlog.py new file mode 100644 index 00000000..558e2813 --- /dev/null +++ b/python/skytools/gzlog.py @@ -0,0 +1,39 @@ + +"""Atomic append of gzipped data. + +The point is - if several gzip streams are concated, they +are read back as one whose stream. +""" + +import gzip +from cStringIO import StringIO + +__all__ = ['gzip_append'] + +# +# gzip storage +# +def gzip_append(filename, data, level = 6): + """Append a block of data to file with safety checks.""" + + # compress data + buf = StringIO() + g = gzip.GzipFile(fileobj = buf, compresslevel = level, mode = "w") + g.write(data) + g.close() + zdata = buf.getvalue() + + # append, safely + f = open(filename, "a+", 0) + f.seek(0, 2) + pos = f.tell() + try: + f.write(zdata) + f.close() + except Exception, ex: + # rollback on error + f.seek(pos, 0) + f.truncate() + f.close() + raise ex + diff --git a/python/skytools/quoting.py b/python/skytools/quoting.py new file mode 100644 index 00000000..96b0b022 --- /dev/null +++ b/python/skytools/quoting.py @@ -0,0 +1,156 @@ +# quoting.py + +"""Various helpers for string quoting/unquoting.""" + +import psycopg, urllib, re + +# +# SQL quoting +# + +def quote_literal(s): + """Quote a literal value for SQL. + + Surronds it with single-quotes. + """ + + if s == None: + return "null" + s = psycopg.QuotedString(str(s)) + return str(s) + +def quote_copy(s): + """Quoting for copy command.""" + + if s == None: + return "\\N" + s = str(s) + s = s.replace("\\", "\\\\") + s = s.replace("\t", "\\t") + s = s.replace("\n", "\\n") + s = s.replace("\r", "\\r") + return s + +def quote_bytea_raw(s): + """Quoting for bytea parser.""" + + if s == None: + return None + return s.replace("\\", "\\\\").replace("\0", "\\000") + +def quote_bytea_literal(s): + """Quote bytea for regular SQL.""" + + return quote_literal(quote_bytea_raw(s)) + +def quote_bytea_copy(s): + """Quote bytea for COPY.""" + + return quote_copy(quote_bytea_raw(s)) + +def quote_statement(sql, dict): + """Quote whose statement. + + Data values are taken from dict. 
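+
+    Example (illustrative values): placeholders use %(name)s syntax
+    and each value is passed through quote_literal() before
+    substitution:
+
+        quote_statement("insert into t (a, b) values (%(a)s, %(b)s)",
+                        {'a': 1, 'b': "o'x"})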
+ """ + xdict = {} + for k, v in dict.items(): + xdict[k] = quote_literal(v) + return sql % xdict + +# +# quoting for JSON strings +# + +_jsre = re.compile(r'[\x00-\x1F\\/"]') +_jsmap = { "\b": "\\b", "\f": "\\f", "\n": "\\n", "\r": "\\r", + "\t": "\\t", "\\": "\\\\", '"': '\\"', + "/": "\\/", # to avoid html attacks +} + +def _json_quote_char(m): + c = m.group(0) + try: + return _jsmap[c] + except KeyError: + return r"\u%04x" % ord(c) + +def quote_json(s): + """JSON style quoting.""" + if s is None: + return "null" + return '"%s"' % _jsre.sub(_json_quote_char, s) + +# +# Database specific urlencode and urldecode. +# + +def db_urlencode(dict): + """Database specific urlencode. + + Encode None as key without '='. That means that in "foo&bar=", + foo is NULL and bar is empty string. + """ + + elem_list = [] + for k, v in dict.items(): + if v is None: + elem = urllib.quote_plus(str(k)) + else: + elem = urllib.quote_plus(str(k)) + '=' + urllib.quote_plus(str(v)) + elem_list.append(elem) + return '&'.join(elem_list) + +def db_urldecode(qs): + """Database specific urldecode. + + Decode key without '=' as None. + This also does not support one key several times. + """ + + res = {} + for elem in qs.split('&'): + if not elem: + continue + pair = elem.split('=', 1) + name = urllib.unquote_plus(pair[0]) + if len(pair) == 1: + res[name] = None + else: + res[name] = urllib.unquote_plus(pair[1]) + return res + +# +# Remove C-like backslash escapes +# + +_esc_re = r"\\([0-7][0-7][0-7]|.)" +_esc_rc = re.compile(_esc_re) +_esc_map = { + 't': '\t', + 'n': '\n', + 'r': '\r', + 'a': '\a', + 'b': '\b', + "'": "'", + '"': '"', + '\\': '\\', +} + +def _sub_unescape(m): + v = m.group(1) + if len(v) == 1: + return _esc_map[v] + else: + return chr(int(v, 8)) + +def unescape(val): + """Removes C-style escapes from string.""" + return _esc_rc.sub(_sub_unescape, val) + +def unescape_copy(val): + """Removes C-style escapes, also converts "\N" to None.""" + if val == r"\N": + return None + return unescape(val) + diff --git a/python/skytools/scripting.py b/python/skytools/scripting.py new file mode 100644 index 00000000..cf976801 --- /dev/null +++ b/python/skytools/scripting.py @@ -0,0 +1,523 @@ + +"""Useful functions and classes for database scripts.""" + +import sys, os, signal, psycopg, optparse, traceback, time +import logging, logging.handlers, logging.config + +from skytools.config import * +import skytools.skylog + +__all__ = ['daemonize', 'run_single_process', 'DBScript', + 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_SERIALIZABLE'] + +# +# daemon mode +# + +def daemonize(): + """Turn the process into daemon. + + Goes background and disables all i/o. + """ + + # launch new process, kill parent + pid = os.fork() + if pid != 0: + os._exit(0) + + # start new session + os.setsid() + + # stop i/o + fd = os.open("/dev/null", os.O_RDWR) + os.dup2(fd, 0) + os.dup2(fd, 1) + os.dup2(fd, 2) + if fd > 2: + os.close(fd) + +# +# Pidfile locking+cleanup & daemonization combined +# + +def _write_pidfile(pidfile): + pid = os.getpid() + f = open(pidfile, 'w') + f.write(str(pid)) + f.close() + +def run_single_process(runnable, daemon, pidfile): + """Run runnable class, possibly daemonized, locked on pidfile.""" + + # check if another process is running + if pidfile and os.path.isfile(pidfile): + print "Pidfile exists, another process running?" 
+ sys.exit(1) + + # daemonize if needed and write pidfile + if daemon: + daemonize() + if pidfile: + _write_pidfile(pidfile) + + # Catch SIGTERM to cleanup pidfile + def sigterm_hook(signum, frame): + try: + os.remove(pidfile) + except: pass + sys.exit(0) + # attach it to signal + if pidfile: + signal.signal(signal.SIGTERM, sigterm_hook) + + # run + try: + runnable.run() + finally: + # another try of cleaning up + if pidfile: + try: + os.remove(pidfile) + except: pass + +# +# logging setup +# + +_log_config_done = 0 +_log_init_done = {} + +def _init_log(job_name, cf, log_level): + """Logging setup happens here.""" + global _log_init_done, _log_config_done + + got_skylog = 0 + use_skylog = cf.getint("use_skylog", 0) + + # load logging config if needed + if use_skylog and not _log_config_done: + # python logging.config braindamage: + # cannot specify external classess without such hack + logging.skylog = skytools.skylog + + # load general config + list = ['skylog.ini', '~/.skylog.ini', '/etc/skylog.ini'] + for fn in list: + fn = os.path.expanduser(fn) + if os.path.isfile(fn): + defs = {'job_name': job_name} + logging.config.fileConfig(fn, defs) + got_skylog = 1 + break + _log_config_done = 1 + if not got_skylog: + sys.stderr.write("skylog.ini not found!\n") + sys.exit(1) + + # avoid duplicate logging init for job_name + log = logging.getLogger(job_name) + if job_name in _log_init_done: + return log + _log_init_done[job_name] = 1 + + # compatibility: specify ini file in script config + logfile = cf.getfile("logfile", "") + if logfile: + fmt = logging.Formatter('%(asctime)s %(process)s %(levelname)s %(message)s') + size = cf.getint('log_size', 10*1024*1024) + num = cf.getint('log_count', 3) + hdlr = logging.handlers.RotatingFileHandler( + logfile, 'a', size, num) + hdlr.setFormatter(fmt) + log.addHandler(hdlr) + + # if skylog.ini is disabled or not available, log at least to stderr + if not got_skylog: + hdlr = logging.StreamHandler() + fmt = logging.Formatter('%(asctime)s %(process)s %(levelname)s %(message)s') + hdlr.setFormatter(fmt) + log.addHandler(hdlr) + + log.setLevel(log_level) + + return log + +#: how old connections need to be closed +DEF_CONN_AGE = 20*60 # 20 min + +#: isolation level not set +I_DEFAULT = -1 + +#: isolation level constant for AUTOCOMMIT +I_AUTOCOMMIT = 0 +#: isolation level constant for READ COMMITTED +I_READ_COMMITTED = 1 +#: isolation level constant for SERIALIZABLE +I_SERIALIZABLE = 2 + +class DBCachedConn(object): + """Cache a db connection.""" + def __init__(self, name, loc, max_age = DEF_CONN_AGE): + self.name = name + self.loc = loc + self.conn = None + self.conn_time = 0 + self.max_age = max_age + self.autocommit = -1 + self.isolation_level = -1 + + def get_connection(self, autocommit = 0, isolation_level = -1): + # autocommit overrider isolation_level + if autocommit: + isolation_level = I_AUTOCOMMIT + + # default isolation_level is READ COMMITTED + if isolation_level < 0: + isolation_level = I_READ_COMMITTED + + # new conn? 
+ if not self.conn: + self.isolation_level = isolation_level + self.conn = psycopg.connect(self.loc) + + self.conn.set_isolation_level(isolation_level) + self.conn_time = time.time() + else: + if self.isolation_level != isolation_level: + raise Exception("Conflict in isolation_level") + + # done + return self.conn + + def refresh(self): + if not self.conn: + return + #for row in self.conn.notifies(): + # if row[0].lower() == "reload": + # self.reset() + # return + if not self.max_age: + return + if time.time() - self.conn_time >= self.max_age: + self.reset() + + def reset(self): + if not self.conn: + return + + # drop reference + conn = self.conn + self.conn = None + + if self.isolation_level == I_AUTOCOMMIT: + return + + # rollback & close + try: + conn.rollback() + except: pass + try: + conn.close() + except: pass + +class DBScript(object): + """Base class for database scripts. + + Handles logging, daemonizing, config, errors. + """ + service_name = None + job_name = None + cf = None + log = None + + def __init__(self, service_name, args): + """Script setup. + + User class should override work() and optionally __init__(), startup(), + reload(), reset() and init_optparse(). + + NB: in case of daemon, the __init__() and startup()/work() will be + run in different processes. So nothing fancy should be done in __init__(). + + @param service_name: unique name for script. + It will be also default job_name, if not specified in config. + @param args: cmdline args (sys.argv[1:]), but can be overrided + """ + self.service_name = service_name + self.db_cache = {} + self.go_daemon = 0 + self.do_single_loop = 0 + self.looping = 1 + self.need_reload = 1 + self.stat_dict = {} + self.log_level = logging.INFO + self.work_state = 1 + + # parse command line + parser = self.init_optparse() + self.options, self.args = parser.parse_args(args) + + # check args + if self.options.daemon: + self.go_daemon = 1 + if self.options.quiet: + self.log_level = logging.WARNING + if self.options.verbose: + self.log_level = logging.DEBUG + if len(self.args) < 1: + print "need config file" + sys.exit(1) + conf_file = self.args[0] + + # load config + self.cf = Config(self.service_name, conf_file) + self.job_name = self.cf.get("job_name", self.service_name) + self.pidfile = self.cf.getfile("pidfile", '') + + self.reload() + + # init logging + self.log = _init_log(self.job_name, self.cf, self.log_level) + + # send signal, if needed + if self.options.cmd == "kill": + self.send_signal(signal.SIGTERM) + elif self.options.cmd == "stop": + self.send_signal(signal.SIGINT) + elif self.options.cmd == "reload": + self.send_signal(signal.SIGHUP) + + def init_optparse(self, parser = None): + """Initialize a OptionParser() instance that will be used to + parse command line arguments. + + Note that it can be overrided both directions - either DBScript + will initialize a instance and passes to user code or user can + initialize and then pass to DBScript.init_optparse(). + + @param parser: optional OptionParser() instance, + where DBScript should attachs its own arguments. + @return: initialized OptionParser() instance. 
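+
+        Override sketch, in the style used by scripts in this package:
+
+            def init_optparse(self, parser = None):
+                p = skytools.DBScript.init_optparse(self, parser)
+                p.add_option("--force", action="store_true",
+                             help = "ignore lag")
+                return p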
+ """ + if parser: + p = parser + else: + p = optparse.OptionParser() + p.set_usage("%prog [options] INI") + # generic options + p.add_option("-q", "--quiet", action="store_true", + help = "make program silent") + p.add_option("-v", "--verbose", action="store_true", + help = "make program verbose") + p.add_option("-d", "--daemon", action="store_true", + help = "go background") + + # control options + g = optparse.OptionGroup(p, 'control running process') + g.add_option("-r", "--reload", + action="store_const", const="reload", dest="cmd", + help = "reload config (send SIGHUP)") + g.add_option("-s", "--stop", + action="store_const", const="stop", dest="cmd", + help = "stop program safely (send SIGINT)") + g.add_option("-k", "--kill", + action="store_const", const="kill", dest="cmd", + help = "kill program immidiately (send SIGTERM)") + p.add_option_group(g) + + return p + + def send_signal(self, sig): + if not self.pidfile: + self.log.warning("No pidfile in config, nothing todo") + sys.exit(0) + if not os.path.isfile(self.pidfile): + self.log.warning("No pidfile, process not running") + sys.exit(0) + pid = int(open(self.pidfile, "r").read()) + os.kill(pid, sig) + sys.exit(0) + + def set_single_loop(self, do_single_loop): + """Changes whether the script will loop or not.""" + self.do_single_loop = do_single_loop + + def start(self): + """This will launch main processing thread.""" + if self.go_daemon: + if not self.pidfile: + self.log.error("Daemon needs pidfile") + sys.exit(1) + run_single_process(self, self.go_daemon, self.pidfile) + + def stop(self): + """Safely stops processing loop.""" + self.looping = 0 + + def reload(self): + "Reload config." + self.cf.reload() + self.loop_delay = self.cf.getfloat("loop_delay", 1.0) + + def hook_sighup(self, sig, frame): + "Internal SIGHUP handler. Minimal code here." + self.need_reload = 1 + + def hook_sigint(self, sig, frame): + "Internal SIGINT handler. Minimal code here." + self.stop() + + def stat_add(self, key, value): + self.stat_put(key, value) + + def stat_put(self, key, value): + """Sets a stat value.""" + self.stat_dict[key] = value + + def stat_increase(self, key, increase = 1): + """Increases a stat value.""" + if key in self.stat_dict: + self.stat_dict[key] += increase + else: + self.stat_dict[key] = increase + + def send_stats(self): + "Send statistics to log." + + res = [] + for k, v in self.stat_dict.items(): + res.append("%s: %s" % (k, str(v))) + + if len(res) == 0: + return + + logmsg = "{%s}" % ", ".join(res) + self.log.info(logmsg) + self.stat_dict = {} + + def get_database(self, dbname, autocommit = 0, isolation_level = -1, + cache = None, max_age = DEF_CONN_AGE): + """Load cached database connection. + + User must not store it permanently somewhere, + as all connections will be invalidated on reset. + """ + + if not cache: + cache = dbname + if cache in self.db_cache: + dbc = self.db_cache[cache] + else: + loc = self.cf.get(dbname) + dbc = DBCachedConn(cache, loc, max_age) + self.db_cache[cache] = dbc + + return dbc.get_connection(autocommit, isolation_level) + + def close_database(self, dbname): + """Explicitly close a cached connection. + + Next call to get_database() will reconnect. + """ + if dbname in self.db_cache: + dbc = self.db_cache[dbname] + dbc.reset() + + def reset(self): + "Something bad happened, reset all connections." + for dbc in self.db_cache.values(): + dbc.reset() + self.db_cache = {} + + def run(self): + "Thread main loop." 
+ + # run startup, safely + try: + self.startup() + except KeyboardInterrupt, det: + raise + except SystemExit, det: + raise + except Exception, det: + exc, msg, tb = sys.exc_info() + self.log.fatal("Job %s crashed: %s: '%s' (%s: %s)" % ( + self.job_name, str(exc), str(msg).rstrip(), + str(tb), repr(traceback.format_tb(tb)))) + del tb + self.reset() + sys.exit(1) + + while self.looping: + # reload config, if needed + if self.need_reload: + self.reload() + self.need_reload = 0 + + # do some work + work = self.run_once() + + # send stats that was added + self.send_stats() + + # reconnect if needed + for dbc in self.db_cache.values(): + dbc.refresh() + + # exit if needed + if self.do_single_loop: + self.log.debug("Only single loop requested, exiting") + break + + # remember work state + self.work_state = work + # should sleep? + if not work: + try: + time.sleep(self.loop_delay) + except Exception, d: + self.log.debug("sleep failed: "+str(d)) + sys.exit(0) + + def run_once(self): + "Run users work function, safely." + try: + return self.work() + except SystemExit, d: + self.send_stats() + self.log.info("got SystemExit(%s), exiting" % str(d)) + self.reset() + raise d + except KeyboardInterrupt, d: + self.send_stats() + self.log.info("got KeyboardInterrupt, exiting") + self.reset() + sys.exit(1) + except Exception, d: + self.send_stats() + exc, msg, tb = sys.exc_info() + self.log.fatal("Job %s crashed: %s: '%s' (%s: %s)" % ( + self.job_name, str(exc), str(msg).rstrip(), + str(tb), repr(traceback.format_tb(tb)))) + del tb + self.reset() + if self.looping: + time.sleep(20) + return 1 + + def work(self): + "Here should user's processing happen." + raise Exception("Nothing implemented?") + + def startup(self): + """Will be called just before entering main loop. + + In case of daemon, if will be called in same process as work(), + unlike __init__(). + """ + + # set signals + signal.signal(signal.SIGHUP, self.hook_sighup) + signal.signal(signal.SIGINT, self.hook_sigint) + + diff --git a/python/skytools/skylog.py b/python/skytools/skylog.py new file mode 100644 index 00000000..2f6344ae --- /dev/null +++ b/python/skytools/skylog.py @@ -0,0 +1,173 @@ +"""Our log handlers for Python's logging package. 
+""" + +import sys, os, time, socket, psycopg +import logging, logging.handlers + +from quoting import quote_json + +# configurable file logger +class EasyRotatingFileHandler(logging.handlers.RotatingFileHandler): + """Easier setup for RotatingFileHandler.""" + def __init__(self, filename, maxBytes = 10*1024*1024, backupCount = 3): + """Args same as for RotatingFileHandler, but in filename '~' is expanded.""" + fn = os.path.expanduser(filename) + logging.handlers.RotatingFileHandler.__init__(self, fn, maxBytes=maxBytes, backupCount=backupCount) + +# send JSON message over UDP +class UdpLogServerHandler(logging.handlers.DatagramHandler): + """Sends log records over UDP to logserver in JSON format.""" + + # map logging levels to logserver levels + _level_map = { + logging.DEBUG : 'DEBUG', + logging.INFO : 'INFO', + logging.WARNING : 'WARN', + logging.ERROR : 'ERROR', + logging.CRITICAL: 'FATAL', + } + + # JSON message template + _log_template = '{\n\t'\ + '"logger": "skytools.UdpLogServer",\n\t'\ + '"timestamp": %.0f,\n\t'\ + '"level": "%s",\n\t'\ + '"thread": null,\n\t'\ + '"message": %s,\n\t'\ + '"properties": {"application":"%s", "hostname":"%s"}\n'\ + '}' + + # cut longer msgs + MAXMSG = 1024 + + def makePickle(self, record): + """Create message in JSON format.""" + # get & cut msg + msg = self.format(record) + if len(msg) > self.MAXMSG: + msg = msg[:self.MAXMSG] + txt_level = self._level_map.get(record.levelno, "ERROR") + pkt = self._log_template % (time.time()*1000, txt_level, + quote_json(msg), record.name, socket.gethostname()) + return pkt + +class LogDBHandler(logging.handlers.SocketHandler): + """Sends log records into PostgreSQL server. + + Additionally, does some statistics aggregating, + to avoid overloading log server. + + It subclasses SocketHandler to get throtthling for + failed connections. + """ + + # map codes to string + _level_map = { + logging.DEBUG : 'DEBUG', + logging.INFO : 'INFO', + logging.WARNING : 'WARNING', + logging.ERROR : 'ERROR', + logging.CRITICAL: 'FATAL', + } + + def __init__(self, connect_string): + """ + Initializes the handler with a specific connection string. + """ + + logging.handlers.SocketHandler.__init__(self, None, None) + self.closeOnError = 1 + + self.connect_string = connect_string + + self.stat_cache = {} + self.stat_flush_period = 60 + # send first stat line immidiately + self.last_stat_flush = 0 + + def createSocket(self): + try: + logging.handlers.SocketHandler.createSocket(self) + except: + self.sock = self.makeSocket() + + def makeSocket(self): + """Create server connection. 
+ In this case its not socket but psycopg conection.""" + + db = psycopg.connect(self.connect_string) + db.autocommit(1) + return db + + def emit(self, record): + """Process log record.""" + + # we do not want log debug messages + if record.levelno < logging.INFO: + return + + try: + self.process_rec(record) + except (SystemExit, KeyboardInterrupt): + raise + except: + self.handleError(record) + + def process_rec(self, record): + """Aggregate stats if needed, and send to logdb.""" + # render msg + msg = self.format(record) + + # dont want to send stats too ofter + if record.levelno == logging.INFO and msg and msg[0] == "{": + self.aggregate_stats(msg) + if time.time() - self.last_stat_flush >= self.stat_flush_period: + self.flush_stats(record.name) + return + + if record.levelno < logging.INFO: + self.flush_stats(record.name) + + # dont send more than one line + ln = msg.find('\n') + if ln > 0: + msg = msg[:ln] + + txt_level = self._level_map.get(record.levelno, "ERROR") + self.send_to_logdb(record.name, txt_level, msg) + + def aggregate_stats(self, msg): + """Sum stats together, to lessen load on logdb.""" + + msg = msg[1:-1] + for rec in msg.split(", "): + k, v = rec.split(": ") + agg = self.stat_cache.get(k, 0) + if v.find('.') >= 0: + agg += float(v) + else: + agg += int(v) + self.stat_cache[k] = agg + + def flush_stats(self, service): + """Send awuired stats to logdb.""" + res = [] + for k, v in self.stat_cache.items(): + res.append("%s: %s" % (k, str(v))) + if len(res) > 0: + logmsg = "{%s}" % ", ".join(res) + self.send_to_logdb(service, "INFO", logmsg) + self.stat_cache = {} + self.last_stat_flush = time.time() + + def send_to_logdb(self, service, type, msg): + """Actual sending is done here.""" + + if self.sock is None: + self.createSocket() + + if self.sock: + logcur = self.sock.cursor() + query = "select * from log.add(%s, %s, %s)" + logcur.execute(query, [type, service, msg]) + diff --git a/python/skytools/sqltools.py b/python/skytools/sqltools.py new file mode 100644 index 00000000..75e209f1 --- /dev/null +++ b/python/skytools/sqltools.py @@ -0,0 +1,398 @@ + +"""Database tools.""" + +import os +from cStringIO import StringIO +from quoting import quote_copy, quote_literal + +# +# Fully qualified table name +# + +def fq_name_parts(tbl): + "Return fully qualified name parts." + + tmp = tbl.split('.') + if len(tmp) == 1: + return ('public', tbl) + elif len(tmp) == 2: + return tmp + else: + raise Exception('Syntax error in table name:'+tbl) + +def fq_name(tbl): + "Return fully qualified name." 
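A sketch of attaching the skylog handlers by hand instead of through skylog.ini; the job name, logfile path and connect string are invented, and the import assumes the skytools package is on the Python path.

import logging
from skytools import skylog

log = logging.getLogger('my_job')
log.setLevel(logging.INFO)

# rotating logfile, 100 MB x 3 files; '~' is expanded by the handler itself
log.addHandler(skylog.EasyRotatingFileHandler('~/log/my_job.log', 100*1024*1024, 3))

# INFO and above also go to the log database (LogDBHandler drops DEBUG records)
log.addHandler(skylog.LogDBHandler('host=127.0.0.1 dbname=logdb user=logger'))

log.info('handlers attached')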
+ return '.'.join(fq_name_parts(tbl)) + +# +# info about table +# +def get_table_oid(curs, table_name): + schema, name = fq_name_parts(table_name) + q = """select c.oid from pg_namespace n, pg_class c + where c.relnamespace = n.oid + and n.nspname = %s and c.relname = %s""" + curs.execute(q, [schema, name]) + res = curs.fetchall() + if len(res) == 0: + raise Exception('Table not found: '+table_name) + return res[0][0] + +def get_table_pkeys(curs, tbl): + oid = get_table_oid(curs, tbl) + q = "SELECT k.attname FROM pg_index i, pg_attribute k"\ + " WHERE i.indrelid = %s AND k.attrelid = i.indexrelid"\ + " AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped"\ + " ORDER BY k.attnum" + curs.execute(q, [oid]) + return map(lambda x: x[0], curs.fetchall()) + +def get_table_columns(curs, tbl): + oid = get_table_oid(curs, tbl) + q = "SELECT k.attname FROM pg_attribute k"\ + " WHERE k.attrelid = %s"\ + " AND k.attnum > 0 AND NOT k.attisdropped"\ + " ORDER BY k.attnum" + curs.execute(q, [oid]) + return map(lambda x: x[0], curs.fetchall()) + +# +# exist checks +# +def exists_schema(curs, schema): + q = "select count(1) from pg_namespace where nspname = %s" + curs.execute(q, [schema]) + res = curs.fetchone() + return res[0] + +def exists_table(curs, table_name): + schema, name = fq_name_parts(table_name) + q = """select count(1) from pg_namespace n, pg_class c + where c.relnamespace = n.oid and c.relkind = 'r' + and n.nspname = %s and c.relname = %s""" + curs.execute(q, [schema, name]) + res = curs.fetchone() + return res[0] + +def exists_type(curs, type_name): + schema, name = fq_name_parts(type_name) + q = """select count(1) from pg_namespace n, pg_type t + where t.typnamespace = n.oid + and n.nspname = %s and t.typname = %s""" + curs.execute(q, [schema, name]) + res = curs.fetchone() + return res[0] + +def exists_function(curs, function_name, nargs): + # this does not check arg types, so may match several functions + schema, name = fq_name_parts(function_name) + q = """select count(1) from pg_namespace n, pg_proc p + where p.pronamespace = n.oid and p.pronargs = %s + and n.nspname = %s and p.proname = %s""" + curs.execute(q, [nargs, schema, name]) + res = curs.fetchone() + return res[0] + +def exists_language(curs, lang_name): + q = """select count(1) from pg_language + where lanname = %s""" + curs.execute(q, [lang_name]) + res = curs.fetchone() + return res[0] + +# +# Support for PostgreSQL snapshot +# + +class Snapshot(object): + "Represents a PostgreSQL snapshot." + + def __init__(self, str): + "Create snapshot from string." + + self.sn_str = str + tmp = str.split(':') + if len(tmp) != 3: + raise Exception('Unknown format for snapshot') + self.xmin = int(tmp[0]) + self.xmax = int(tmp[1]) + self.txid_list = [] + if tmp[2] != "": + for s in tmp[2].split(','): + self.txid_list.append(int(s)) + + def contains(self, txid): + "Is txid visible in snapshot." 
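A short sketch of the Snapshot class in use, given a snapshot string in the xmin:xmax:xip-list format that __init__() parses; the import path is assumed from the file location.

from skytools.sqltools import Snapshot

snap = Snapshot('11:20:11,12')      # xmin=11, xmax=20, in-progress txids: 11,12
assert snap.contains(9)             # below xmin -> visible
assert not snap.contains(11)        # listed as in-progress -> not visible
assert snap.contains(15)            # between xmin and xmax, not in list -> visible
assert not snap.contains(25)        # at or above xmax -> not visible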
+ + txid = int(txid) + + if txid < self.xmin: + return True + if txid >= self.xmax: + return False + if txid in self.txid_list: + return False + return True + +# +# Copy helpers +# + +def _gen_dict_copy(tbl, row, fields): + tmp = [] + for f in fields: + v = row[f] + tmp.append(quote_copy(v)) + return "\t".join(tmp) + +def _gen_dict_insert(tbl, row, fields): + tmp = [] + for f in fields: + v = row[f] + tmp.append(quote_literal(v)) + fmt = "insert into %s (%s) values (%s);" + return fmt % (tbl, ",".join(fields), ",".join(tmp)) + +def _gen_list_copy(tbl, row, fields): + tmp = [] + for i in range(len(fields)): + v = row[i] + tmp.append(quote_copy(v)) + return "\t".join(tmp) + +def _gen_list_insert(tbl, row, fields): + tmp = [] + for i in range(len(fields)): + v = row[i] + tmp.append(quote_literal(v)) + fmt = "insert into %s (%s) values (%s);" + return fmt % (tbl, ",".join(fields), ",".join(tmp)) + +def magic_insert(curs, tablename, data, fields = None, use_insert = 0): + """Copy/insert a list of dict/list data to database. + + If curs == None, then the copy or insert statements are returned + as string. For list of dict the field list is optional, as its + possible to guess them from dict keys. + """ + if len(data) == 0: + return + + # decide how to process + if type(data[0]) == type({}): + if fields == None: + fields = data[0].keys() + if use_insert: + row_func = _gen_dict_insert + else: + row_func = _gen_dict_copy + else: + if fields == None: + raise Exception("Non-dict data needs field list") + if use_insert: + row_func = _gen_list_insert + else: + row_func = _gen_list_copy + + # init processing + buf = StringIO() + if curs == None and use_insert == 0: + fmt = "COPY %s (%s) FROM STDIN;\n" + buf.write(fmt % (tablename, ",".join(fields))) + + # process data + for row in data: + buf.write(row_func(tablename, row, fields)) + buf.write("\n") + + # if user needs only string, return it + if curs == None: + if use_insert == 0: + buf.write("\\.\n") + return buf.getvalue() + + # do the actual copy/inserts + if use_insert: + curs.execute(buf.getvalue()) + else: + buf.seek(0) + hdr = "%s (%s)" % (tablename, ",".join(fields)) + curs.copy_from(buf, hdr) + +def db_copy_from_dict(curs, tablename, dict_list, fields = None): + """Do a COPY FROM STDIN using list of dicts as source.""" + + if len(dict_list) == 0: + return + + if fields == None: + fields = dict_list[0].keys() + + buf = StringIO() + for dat in dict_list: + row = [] + for k in fields: + row.append(quote_copy(dat[k])) + buf.write("\t".join(row)) + buf.write("\n") + + buf.seek(0) + hdr = "%s (%s)" % (tablename, ",".join(fields)) + + curs.copy_from(buf, hdr) + +def db_copy_from_list(curs, tablename, row_list, fields): + """Do a COPY FROM STDIN using list of lists as source.""" + + if len(row_list) == 0: + return + + if fields == None or len(fields) == 0: + raise Exception('Need field list') + + buf = StringIO() + for dat in row_list: + row = [] + for i in range(len(fields)): + row.append(quote_copy(dat[i])) + buf.write("\t".join(row)) + buf.write("\n") + + buf.seek(0) + hdr = "%s (%s)" % (tablename, ",".join(fields)) + + curs.copy_from(buf, hdr) + +# +# Full COPY of table from one db to another +# + +class CopyPipe(object): + "Splits one big COPY to chunks." 
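A sketch of magic_insert() from above, called with curs=None so it only renders SQL text instead of touching a database; the table name and rows are invented, and the import path is assumed from the file location.

from skytools.sqltools import magic_insert

rows = [
    {'id': 1, 'username': 'bob'},
    {'id': 2, 'username': 'alice'},
]

# COPY format (default): "COPY ... FROM STDIN;", data lines, terminating "\."
print magic_insert(None, 'public.users', rows, ['id', 'username'])

# INSERT format: one quoted INSERT statement per row
print magic_insert(None, 'public.users', rows, ['id', 'username'], use_insert = 1)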
+ + def __init__(self, dstcurs, tablename, limit = 512*1024, cancel_func=None): + self.tablename = tablename + self.dstcurs = dstcurs + self.buf = StringIO() + self.limit = limit + self.cancel_func = None + self.total_rows = 0 + self.total_bytes = 0 + + def write(self, data): + "New data from psycopg" + + self.total_bytes += len(data) + self.total_rows += data.count("\n") + + if self.buf.tell() >= self.limit: + pos = data.find('\n') + if pos >= 0: + # split at newline + p1 = data[:pos + 1] + p2 = data[pos + 1:] + self.buf.write(p1) + self.flush() + + data = p2 + + self.buf.write(data) + + def flush(self): + "Send data out." + + if self.cancel_func: + self.cancel_func() + + if self.buf.tell() > 0: + self.buf.seek(0) + self.dstcurs.copy_from(self.buf, self.tablename) + self.buf.seek(0) + self.buf.truncate() + +def full_copy(tablename, src_curs, dst_curs, column_list = []): + """COPY table from one db to another.""" + + if column_list: + hdr = "%s (%s)" % (tablename, ",".join(column_list)) + else: + hdr = tablename + buf = CopyPipe(dst_curs, hdr) + src_curs.copy_to(buf, hdr) + buf.flush() + + return (buf.total_bytes, buf.total_rows) + + +# +# SQL installer +# + +class DBObject(object): + """Base class for installable DB objects.""" + name = None + sql = None + sql_file = None + def __init__(self, name, sql = None, sql_file = None): + self.name = name + self.sql = sql + self.sql_file = sql_file + def get_sql(self): + if self.sql: + return self.sql + if self.sql_file: + if self.sql_file[0] == "/": + fn = self.sql_file + else: + contrib_list = [ + "/opt/pgsql/share/contrib", + "/usr/share/postgresql/8.0/contrib", + "/usr/share/postgresql/8.0/contrib", + "/usr/share/postgresql/8.1/contrib", + "/usr/share/postgresql/8.2/contrib", + ] + for dir in contrib_list: + fn = os.path.join(dir, self.sql_file) + if os.path.isfile(fn): + return open(fn, "r").read() + raise Exception('File not found: '+self.sql_file) + raise Exception('object not defined') + def create(self, curs): + curs.execute(self.get_sql()) + +class DBSchema(DBObject): + """Handles db schema.""" + def exists(self, curs): + return exists_schema(curs, self.name) + +class DBTable(DBObject): + """Handles db table.""" + def exists(self, curs): + return exists_table(curs, self.name) + +class DBFunction(DBObject): + """Handles db function.""" + def __init__(self, name, nargs, sql = None, sql_file = None): + DBObject.__init__(self, name, sql, sql_file) + self.nargs = nargs + def exists(self, curs): + return exists_function(curs, self.name, self.nargs) + +class DBLanguage(DBObject): + """Handles db language.""" + def __init__(self, name): + DBObject.__init__(self, name, sql = "create language %s" % name) + def exists(self, curs): + return exists_language(curs, self.name) + +def db_install(curs, list, log = None): + """Installs list of objects into db.""" + for obj in list: + if not obj.exists(curs): + if log: + log.info('Installing %s' % obj.name) + obj.create(curs) + else: + if log: + log.info('%s is installed' % obj.name) + diff --git a/python/walmgr.py b/python/walmgr.py new file mode 100755 index 00000000..8f43fd6d --- /dev/null +++ b/python/walmgr.py @@ -0,0 +1,648 @@ +#! /usr/bin/env python + +"""WALShipping manager. 
+ +walmgr [-n] COMMAND + +Master commands: + setup Configure PostgreSQL for WAL archiving + backup Copies all master data to slave + sync Copies in-progress WALs to slave + syncdaemon Daemon mode for regular syncing + stop Stop archiving - de-configure PostgreSQL + +Slave commands: + restore Stop postmaster, move new data dir to right + location and start postmaster in playback mode. + boot Stop playback, accept queries. + pause Just wait, don't play WAL-s + continue Start playing WAL-s again + +Internal commands: + xarchive archive one WAL file (master) + xrestore restore one WAL file (slave) + +Switches: + -n no action, just print commands +""" + +import os, sys, skytools, getopt, re, signal, time, traceback + +MASTER = 1 +SLAVE = 0 + +def usage(err): + if err > 0: + print >>sys.stderr, __doc__ + else: + print __doc__ + sys.exit(err) + +class WalMgr(skytools.DBScript): + def __init__(self, wtype, cf_file, not_really, internal = 0, go_daemon = 0): + self.not_really = not_really + self.pg_backup = 0 + + if wtype == MASTER: + service_name = "wal-master" + else: + service_name = "wal-slave" + + if not os.path.isfile(cf_file): + print "Config not found:", cf_file + sys.exit(1) + + if go_daemon: + s_args = ["-d", cf_file] + else: + s_args = [cf_file] + + skytools.DBScript.__init__(self, service_name, s_args, + force_logfile = internal) + + def pg_start_backup(self, code): + q = "select pg_start_backup('FullBackup')" + self.log.info("Execute SQL: %s; [%s]" % (q, self.cf.get("master_db"))) + if self.not_really: + self.pg_backup = 1 + return + db = self.get_database("master_db") + db.cursor().execute(q) + db.commit() + self.close_database("master_db") + self.pg_backup = 1 + + def pg_stop_backup(self): + if not self.pg_backup: + return + + q = "select pg_stop_backup()" + self.log.debug("Execute SQL: %s; [%s]" % (q, self.cf.get("master_db"))) + if self.not_really: + return + db = self.get_database("master_db") + db.cursor().execute(q) + db.commit() + self.close_database("master_db") + + def signal_postmaster(self, data_dir, sgn): + pidfile = os.path.join(data_dir, "postmaster.pid") + if not os.path.isfile(pidfile): + self.log.info("postmaster is not running") + return + buf = open(pidfile, "r").readline() + pid = int(buf.strip()) + self.log.debug("Signal %d to process %d" % (sgn, pid)) + if not self.not_really: + os.kill(pid, sgn) + + def exec_big_rsync(self, cmdline): + cmd = "' '".join(cmdline) + self.log.debug("Execute big rsync cmd: '%s'" % (cmd)) + if self.not_really: + return + res = os.spawnvp(os.P_WAIT, cmdline[0], cmdline) + if res == 24: + self.log.info("Some files vanished, but thats OK") + elif res != 0: + self.log.fatal("exec failed, res=%d" % res) + self.pg_stop_backup() + sys.exit(1) + + def exec_cmd(self, cmdline): + cmd = "' '".join(cmdline) + self.log.debug("Execute cmd: '%s'" % (cmd)) + if self.not_really: + return + res = os.spawnvp(os.P_WAIT, cmdline[0], cmdline) + if res != 0: + self.log.fatal("exec failed, res=%d" % res) + sys.exit(1) + + def chdir(self, loc): + self.log.debug("chdir: '%s'" % (loc)) + if self.not_really: + return + try: + os.chdir(loc) + except os.error: + self.log.fatal("CHDir failed") + self.pg_stop_backup() + sys.exit(1) + + def get_last_complete(self): + """Get the name of last xarchived segment.""" + + data_dir = self.cf.get("master_data") + fn = os.path.join(data_dir, ".walshipping.last") + try: + last = open(fn, "r").read().strip() + return last + except: + self.log.info("Failed to read %s" % fn) + return None + + def set_last_complete(self, last): + 
"""Set the name of last xarchived segment.""" + + data_dir = self.cf.get("master_data") + fn = os.path.join(data_dir, ".walshipping.last") + fn_tmp = fn + ".new" + try: + f = open(fn_tmp, "w") + f.write(last) + f.close() + os.rename(fn_tmp, fn) + except: + self.log.fatal("Cannot write to %s" % fn) + + def master_setup(self): + self.log.info("Configuring WAL archiving") + + script = os.path.abspath(sys.argv[0]) + cf_file = os.path.abspath(self.cf.filename) + cf_val = "%s %s %s" % (script, cf_file, "xarchive %p %f") + + self.master_configure_archiving(cf_val) + + def master_stop(self): + self.log.info("Disabling WAL archiving") + + self.master_configure_archiving('') + + def master_configure_archiving(self, cf_val): + cf_file = self.cf.get("master_config") + data_dir = self.cf.get("master_data") + r_active = re.compile("^[ ]*archive_command[ ]*=[ ]*'(.*)'.*$", re.M) + r_disabled = re.compile("^.*archive_command.*$", re.M) + + cf_full = "archive_command = '%s'" % cf_val + + if not os.path.isfile(cf_file): + self.log.fatal("Config file not found: %s" % cf_file) + self.log.info("Using config file: %s", cf_file) + + buf = open(cf_file, "r").read() + m = r_active.search(buf) + if m: + old_val = m.group(1) + if old_val == cf_val: + self.log.debug("postmaster already configured") + else: + self.log.debug("found active but different conf") + newbuf = "%s%s%s" % (buf[:m.start()], cf_full, buf[m.end():]) + self.change_config(cf_file, newbuf) + else: + m = r_disabled.search(buf) + if m: + self.log.debug("found disabled value") + newbuf = "%s\n%s%s" % (buf[:m.end()], cf_full, buf[m.end():]) + self.change_config(cf_file, newbuf) + else: + self.log.debug("found no value") + newbuf = "%s\n%s\n\n" % (buf, cf_full) + self.change_config(cf_file, newbuf) + + self.log.info("Sending SIGHUP to postmaster") + self.signal_postmaster(data_dir, signal.SIGHUP) + self.log.info("Done") + + def change_config(self, cf_file, buf): + cf_old = cf_file + ".old" + cf_new = cf_file + ".new" + + if self.not_really: + cf_new = "/tmp/postgresql.conf.new" + open(cf_new, "w").write(buf) + self.log.info("Showing diff") + os.system("diff -u %s %s" % (cf_file, cf_new)) + self.log.info("Done diff") + os.remove(cf_new) + return + + # polite method does not work, as usually not enough perms for it + if 0: + open(cf_new, "w").write(buf) + bak = open(cf_file, "r").read() + open(cf_old, "w").write(bak) + os.rename(cf_new, cf_file) + else: + open(cf_file, "w").write(buf) + + def remote_mkdir(self, remdir): + tmp = remdir.split(":", 1) + if len(tmp) != 2: + raise Exception("cannot find hostname") + host, path = tmp + cmdline = ["ssh", host, "mkdir", "-p", path] + self.exec_cmd(cmdline) + + def master_backup(self): + """Copy master data directory to slave.""" + + data_dir = self.cf.get("master_data") + dst_loc = self.cf.get("full_backup") + if dst_loc[-1] != "/": + dst_loc += "/" + + self.pg_start_backup("FullBackup") + + master_spc_dir = os.path.join(data_dir, "pg_tblspc") + slave_spc_dir = dst_loc + "tmpspc" + + # copy data + self.chdir(data_dir) + cmdline = ["rsync", "-a", "--delete", + "--exclude", ".*", + "--exclude", "*.pid", + "--exclude", "*.opts", + "--exclude", "*.conf", + "--exclude", "*.conf.*", + "--exclude", "pg_xlog", + ".", dst_loc] + self.exec_big_rsync(cmdline) + + # copy tblspc first, to test + if os.path.isdir(master_spc_dir): + self.log.info("Checking tablespaces") + list = os.listdir(master_spc_dir) + if len(list) > 0: + self.remote_mkdir(slave_spc_dir) + for tblspc in list: + if tblspc[0] == ".": + continue + tfn = 
os.path.join(master_spc_dir, tblspc) + if not os.path.islink(tfn): + self.log.info("Suspicious pg_tblspc entry: "+tblspc) + continue + spc_path = os.path.realpath(tfn) + self.log.info("Got tablespace %s: %s" % (tblspc, spc_path)) + dstfn = slave_spc_dir + "/" + tblspc + + try: + os.chdir(spc_path) + except Exception, det: + self.log.warning("Broken link:" + str(det)) + continue + cmdline = ["rsync", "-a", "--delete", + "--exclude", ".*", + ".", dstfn] + self.exec_big_rsync(cmdline) + + # copy pg_xlog + self.chdir(data_dir) + cmdline = ["rsync", "-a", + "--exclude", "*.done", + "--exclude", "*.backup", + "--delete", "pg_xlog", dst_loc] + self.exec_big_rsync(cmdline) + + self.pg_stop_backup() + + self.log.info("Full backup successful") + + def master_xarchive(self, srcpath, srcname): + """Copy a complete WAL segment to slave.""" + + start_time = time.time() + self.log.debug("%s: start copy", srcname) + + self.set_last_complete(srcname) + + dst_loc = self.cf.get("completed_wals") + if dst_loc[-1] != "/": + dst_loc += "/" + + # copy data + cmdline = ["rsync", "-t", srcpath, dst_loc] + self.exec_cmd(cmdline) + + self.log.debug("%s: done", srcname) + end_time = time.time() + self.stat_add('count', 1) + self.stat_add('duration', end_time - start_time) + + def master_sync(self): + """Copy partial WAL segments.""" + + data_dir = self.cf.get("master_data") + xlog_dir = os.path.join(data_dir, "pg_xlog") + dst_loc = self.cf.get("partial_wals") + if dst_loc[-1] != "/": + dst_loc += "/" + + files = os.listdir(xlog_dir) + files.sort() + + last = self.get_last_complete() + if last: + self.log.info("%s: last complete" % last) + else: + self.log.info("last complete not found, copying all") + + for fn in files: + # check if interesting file + if len(fn) < 10: + continue + if fn[0] < "0" or fn[0] > '9': + continue + if fn.find(".") > 0: + continue + # check if to old + if last: + dot = last.find(".") + if dot > 0: + xlast = last[:dot] + if fn < xlast: + continue + else: + if fn <= last: + continue + + # got interesting WAL + xlog = os.path.join(xlog_dir, fn) + # copy data + cmdline = ["rsync", "-t", xlog, dst_loc] + self.exec_cmd(cmdline) + + self.log.info("Partial copy done") + + def slave_xrestore(self, srcname, dstpath): + loop = 1 + while loop: + try: + self.slave_xrestore_unsafe(srcname, dstpath) + loop = 0 + except SystemExit, d: + sys.exit(1) + except Exception, d: + exc, msg, tb = sys.exc_info() + self.log.fatal("xrestore %s crashed: %s: '%s' (%s: %s)" % ( + srcname, str(exc), str(msg).rstrip(), + str(tb), repr(traceback.format_tb(tb)))) + time.sleep(10) + self.log.info("Re-exec: %s", repr(sys.argv)) + os.execv(sys.argv[0], sys.argv) + + def slave_xrestore_unsafe(self, srcname, dstpath): + srcdir = self.cf.get("completed_wals") + partdir = self.cf.get("partial_wals") + keep_old_logs = self.cf.getint("keep_old_logs", 0) + pausefile = os.path.join(srcdir, "PAUSE") + stopfile = os.path.join(srcdir, "STOP") + srcfile = os.path.join(srcdir, srcname) + partfile = os.path.join(partdir, srcname) + + # loop until srcfile or stopfile appears + while 1: + if os.path.isfile(pausefile): + self.log.info("pause requested, sleeping") + time.sleep(20) + continue + + if os.path.isfile(srcfile): + self.log.info("%s: Found" % srcname) + break + + # ignore .history files + unused, ext = os.path.splitext(srcname) + if ext == ".history": + self.log.info("%s: not found, ignoring" % srcname) + sys.exit(1) + + # if stopping, include also partial wals + if os.path.isfile(stopfile): + if os.path.isfile(partfile): + 
self.log.info("%s: found partial" % srcname) + srcfile = partfile + break + else: + self.log.info("%s: not found, stopping" % srcname) + sys.exit(1) + + # nothing to do, sleep + self.log.debug("%s: not found, sleeping" % srcname) + time.sleep(20) + + # got one, copy it + cmdline = ["cp", srcfile, dstpath] + self.exec_cmd(cmdline) + + self.log.debug("%s: copy done, cleanup" % srcname) + self.slave_cleanup(srcname) + + # it would be nice to have apply time too + self.stat_add('count', 1) + + def slave_startup(self): + data_dir = self.cf.get("slave_data") + full_dir = self.cf.get("full_backup") + stop_cmd = self.cf.get("slave_stop_cmd", "") + start_cmd = self.cf.get("slave_start_cmd") + pidfile = os.path.join(data_dir, "postmaster.pid") + + # stop postmaster if ordered + if stop_cmd and os.path.isfile(pidfile): + self.log.info("Stopping postmaster: " + stop_cmd) + if not self.not_really: + os.system(stop_cmd) + time.sleep(3) + + # is it dead? + if os.path.isfile(pidfile): + self.log.fatal("Postmaster still running. Cannot continue.") + sys.exit(1) + + # find name for data backup + i = 0 + while 1: + bak = "%s.%d" % (data_dir, i) + if not os.path.isdir(bak): + break + i += 1 + + # move old data away + if os.path.isdir(data_dir): + self.log.info("Move %s to %s" % (data_dir, bak)) + if not self.not_really: + os.rename(data_dir, bak) + + # move new data + self.log.info("Move %s to %s" % (full_dir, data_dir)) + if not self.not_really: + os.rename(full_dir, data_dir) + else: + data_dir = full_dir + + # re-link tablespaces + spc_dir = os.path.join(data_dir, "pg_tblspc") + tmp_dir = os.path.join(data_dir, "tmpspc") + if os.path.isdir(spc_dir) and os.path.isdir(tmp_dir): + self.log.info("Linking tablespaces to temporary location") + + # don't look into spc_dir, thus allowing + # user to move them before. 
re-link only those + # that are still in tmp_dir + list = os.listdir(tmp_dir) + list.sort() + + for d in list: + if d[0] == ".": + continue + link_loc = os.path.join(spc_dir, d) + link_dst = os.path.join(tmp_dir, d) + self.log.info("Linking tablespace %s to %s" % (d, link_dst)) + if not self.not_really: + if os.path.islink(link_loc): + os.remove(link_loc) + os.symlink(link_dst, link_loc) + + # write recovery.conf + rconf = os.path.join(data_dir, "recovery.conf") + script = os.path.abspath(sys.argv[0]) + cf_file = os.path.abspath(self.cf.filename) + conf = "\nrestore_command = '%s %s %s'\n" % ( + script, cf_file, 'xrestore %f "%p"') + self.log.info("Write %s" % rconf) + if self.not_really: + print conf + else: + f = open(rconf, "w") + f.write(conf) + f.close() + + # remove stopfile + srcdir = self.cf.get("completed_wals") + stopfile = os.path.join(srcdir, "STOP") + if os.path.isfile(stopfile): + self.log.info("Removing stopfile: "+stopfile) + if not self.not_really: + os.remove(stopfile) + + # run database in recovery mode + self.log.info("Starting postmaster: " + start_cmd) + if not self.not_really: + os.system(start_cmd) + + def slave_boot(self): + srcdir = self.cf.get("completed_wals") + stopfile = os.path.join(srcdir, "STOP") + open(stopfile, "w").write("1") + self.log.info("Stopping recovery mode") + + def slave_pause(self): + srcdir = self.cf.get("completed_wals") + pausefile = os.path.join(srcdir, "PAUSE") + open(pausefile, "w").write("1") + self.log.info("Pausing recovery mode") + + def slave_continue(self): + srcdir = self.cf.get("completed_wals") + pausefile = os.path.join(srcdir, "PAUSE") + if os.path.isfile(pausefile): + os.remove(pausefile) + self.log.info("Continuing with recovery") + else: + self.log.info("Recovery not paused?") + + def slave_cleanup(self, last_applied): + completed_wals = self.cf.get("completed_wals") + partial_wals = self.cf.get("partial_wals") + + self.log.debug("cleaning completed wals since %s" % last_applied) + last = self.del_wals(completed_wals, last_applied) + if last: + if os.path.isdir(partial_wals): + self.log.debug("cleaning partial wals since %s" % last) + self.del_wals(partial_wals, last) + else: + self.log.warning("partial_wals dir does not exist: %s" + % partial_wals) + self.log.debug("cleaning done") + + def del_wals(self, path, last): + dot = last.find(".") + if dot > 0: + last = last[:dot] + list = os.listdir(path) + list.sort() + cur_last = None + n = len(list) + for i in range(n): + fname = list[i] + full = os.path.join(path, fname) + if fname[0] < "0" or fname[0] > "9": + continue + + ok_del = 0 + if fname < last: + self.log.debug("deleting %s" % full) + os.remove(full) + cur_last = fname + return cur_last + + def work(self): + self.master_sync() + +def main(): + try: + opts, args = getopt.getopt(sys.argv[1:], "nh") + except getopt.error, det: + print det + usage(1) + not_really = 0 + for o, v in opts: + if o == "-n": + not_really = 1 + elif o == "-h": + usage(0) + if len(args) < 2: + usage(1) + ini = args[0] + cmd = args[1] + + if cmd == "setup": + script = WalMgr(MASTER, ini, not_really) + script.master_setup() + elif cmd == "stop": + script = WalMgr(MASTER, ini, not_really) + script.master_stop() + elif cmd == "backup": + script = WalMgr(MASTER, ini, not_really) + script.master_backup() + elif cmd == "xarchive": + if len(args) != 4: + print >> sys.stderr, "usage: walmgr INI xarchive %p %f" + sys.exit(1) + script = WalMgr(MASTER, ini, not_really, 1) + script.master_xarchive(args[2], args[3]) + elif cmd == "sync": + script = 
WalMgr(MASTER, ini, not_really) + script.master_sync() + elif cmd == "syncdaemon": + script = WalMgr(MASTER, ini, not_really, go_daemon=1) + script.start() + elif cmd == "xrestore": + if len(args) != 4: + print >> sys.stderr, "usage: walmgr INI xrestore %p %f" + sys.exit(1) + script = WalMgr(SLAVE, ini, not_really, 1) + script.slave_xrestore(args[2], args[3]) + elif cmd == "restore": + script = WalMgr(SLAVE, ini, not_really) + script.slave_startup() + elif cmd == "boot": + script = WalMgr(SLAVE, ini, not_really) + script.slave_boot() + elif cmd == "pause": + script = WalMgr(SLAVE, ini, not_really) + script.slave_pause() + elif cmd == "continue": + script = WalMgr(SLAVE, ini, not_really) + script.slave_continue() + else: + usage(1) + +if __name__ == '__main__': + main() +
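A sketch of driving WalMgr from Python, mirroring the dispatch main() performs for the master-side commands; it assumes walmgr.py and the sample wal-master.ini are reachable on the path, and passes not_really=1 so the actual rsync/SQL actions are skipped, like the -n switch.

import walmgr

mgr = walmgr.WalMgr(walmgr.MASTER, 'wal-master.ini', 1)   # 1 = not_really (dry run)
mgr.master_setup()      # same as: walmgr.py wal-master.ini setup
mgr.master_backup()     # same as: walmgr.py wal-master.ini backup
mgr.master_sync()       # same as: walmgr.py wal-master.ini sync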