Diffstat (limited to 'python')
-rw-r--r--  python/conf/londiste.ini       |  16
-rw-r--r--  python/conf/pgqadm.ini         |  18
-rw-r--r--  python/conf/skylog.ini         |  76
-rw-r--r--  python/conf/wal-master.ini     |  18
-rw-r--r--  python/conf/wal-slave.ini      |  15
-rwxr-xr-x  python/londiste.py             | 130
-rw-r--r--  python/londiste/__init__.py    |  12
-rw-r--r--  python/londiste/compare.py     |  45
-rw-r--r--  python/londiste/file_read.py   |  52
-rw-r--r--  python/londiste/file_write.py  |  67
-rw-r--r--  python/londiste/installer.py   |  26
-rw-r--r--  python/londiste/playback.py    | 558
-rw-r--r--  python/londiste/repair.py      | 284
-rw-r--r--  python/londiste/setup.py       | 580
-rw-r--r--  python/londiste/syncer.py      | 177
-rw-r--r--  python/londiste/table_copy.py  | 107
-rw-r--r--  python/pgq/__init__.py         |   6
-rw-r--r--  python/pgq/consumer.py         | 410
-rw-r--r--  python/pgq/event.py            |  60
-rw-r--r--  python/pgq/maint.py            |  99
-rw-r--r--  python/pgq/producer.py         |  41
-rw-r--r--  python/pgq/status.py           |  93
-rw-r--r--  python/pgq/ticker.py           | 172
-rwxr-xr-x  python/pgqadm.py               | 162
-rw-r--r--  python/skytools/__init__.py    |  10
-rw-r--r--  python/skytools/config.py      | 139
-rw-r--r--  python/skytools/dbstruct.py    | 380
-rw-r--r--  python/skytools/gzlog.py       |  39
-rw-r--r--  python/skytools/quoting.py     | 156
-rw-r--r--  python/skytools/scripting.py   | 523
-rw-r--r--  python/skytools/skylog.py      | 173
-rw-r--r--  python/skytools/sqltools.py    | 398
-rwxr-xr-x  python/walmgr.py               | 648
33 files changed, 5690 insertions, 0 deletions
diff --git a/python/conf/londiste.ini b/python/conf/londiste.ini
new file mode 100644
index 00000000..a1506a32
--- /dev/null
+++ b/python/conf/londiste.ini
@@ -0,0 +1,16 @@
+
+[londiste]
+job_name = test_to_subscriber
+
+provider_db = dbname=provider port=6000 host=127.0.0.1
+subscriber_db = dbname=subscriber port=6000 host=127.0.0.1
+
+# it will be used as an sql ident, so no dots/spaces
+pgq_queue_name = londiste.replika
+
+logfile = ~/log/%(job_name)s.log
+pidfile = ~/pid/%(job_name)s.pid
+
+# both events and ticks will be copied there
+#mirror_queue = replika_mirror
+
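The %(job_name)s placeholders above use ini-style value interpolation, so logfile and pidfile pick up the job name automatically. A minimal sketch of how that resolves, using the stdlib ConfigParser (skytools ships its own loader in python/skytools/config.py, presumably built on the same idea):

    # illustration only: resolve %(job_name)s the way ConfigParser does
    from ConfigParser import SafeConfigParser

    cp = SafeConfigParser()
    cp.read('londiste.ini')
    print cp.get('londiste', 'logfile')   # -> ~/log/test_to_subscriber.log
    print cp.get('londiste', 'pidfile')   # -> ~/pid/test_to_subscriber.pid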
diff --git a/python/conf/pgqadm.ini b/python/conf/pgqadm.ini
new file mode 100644
index 00000000..a2e92f6b
--- /dev/null
+++ b/python/conf/pgqadm.ini
@@ -0,0 +1,18 @@
+
+[pgqadm]
+
+job_name = pgqadm_somedb
+
+db = dbname=provider port=6000 host=127.0.0.1
+
+# how often to run maintenance [minutes]
+maint_delay_min = 5
+
+# how often to check for activity [secs]
+loop_delay = 0.1
+
+logfile = ~/log/%(job_name)s.log
+pidfile = ~/pid/%(job_name)s.pid
+
+use_skylog = 0
+
diff --git a/python/conf/skylog.ini b/python/conf/skylog.ini
new file mode 100644
index 00000000..150ef934
--- /dev/null
+++ b/python/conf/skylog.ini
@@ -0,0 +1,76 @@
+; notes:
+; - 'args' is mandatory in [handler_*] sections
+; - lists must not contain spaces
+
+;
+; top-level config
+;
+
+; list of all loggers
+[loggers]
+keys=root
+; root logger sees everything. per-job configs are possible by
+; specifying loggers named after the job_name of the script
+
+; list of all handlers
+[handlers]
+;; the logging module seems to initialize all listed handlers immediately,
+;; whether they are actually used or not, so better
+;; keep this list in sync with the actual handler list
+;keys=stderr,logdb,logsrv,logfile
+keys=stderr
+
+; list of all formatters
+[formatters]
+keys=short,long,none
+
+;
+; map specific loggers to specific handlers
+;
+[logger_root]
+level=DEBUG
+;handlers=stderr,logdb,logsrv,logfile
+handlers=stderr
+
+;
+; configure formatters
+;
+[formatter_short]
+format=%(asctime)s %(levelname)s %(message)s
+datefmt=%H:%M
+
+[formatter_long]
+format=%(asctime)s %(process)s %(levelname)s %(message)s
+
+[formatter_none]
+format=%(message)s
+
+;
+; configure handlers
+;
+
+; log to stream (stderr). args: stream
+[handler_stderr]
+class=StreamHandler
+args=(sys.stderr,)
+formatter=short
+
+; log into db. args: conn_string
+[handler_logdb]
+class=skylog.LogDBHandler
+args=("host=127.0.0.1 port=5432 user=logger dbname=logdb",)
+formatter=none
+level=INFO
+
+; JSON messages over UDP. args: host, port
+[handler_logsrv]
+class=skylog.UdpLogServerHandler
+args=('127.0.0.1', 6666)
+formatter=none
+
+; rotating logfile. args: filename, maxsize, maxcount
+[handler_logfile]
+class=skylog.EasyRotatingFileHandler
+args=('~/log/%(job_name)s.log', 100*1024*1024, 3)
+formatter=long
+
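skylog.ini follows the stdlib logging.config fileConfig() format, so a script running with use_skylog enabled could load it roughly like this (a sketch; only handlers listed under [handlers] keys= are instantiated, so the skylog.* classes from python/skytools/skylog.py are not needed until they are added to that list):

    import logging, logging.config

    logging.config.fileConfig('skylog.ini')   # builds formatters + handlers from keys=
    log = logging.getLogger()                 # root logger; per-job loggers use job_name
    log.info('short-formatted line on stderr')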
diff --git a/python/conf/wal-master.ini b/python/conf/wal-master.ini
new file mode 100644
index 00000000..5ae8cb2b
--- /dev/null
+++ b/python/conf/wal-master.ini
@@ -0,0 +1,18 @@
+[wal-master]
+logfile = master.log
+use_skylog = 0
+
+master_db = dbname=template1
+master_data = /var/lib/postgresql/8.0/main
+master_config = /etc/postgresql/8.0/main/postgresql.conf
+
+
+slave = slave:/var/lib/postgresql/walshipping
+
+completed_wals = %(slave)s/logs.complete
+partial_wals = %(slave)s/logs.partial
+full_backup = %(slave)s/data.master
+
+# syncdaemon update frequency
+loop_delay = 10.0
+
diff --git a/python/conf/wal-slave.ini b/python/conf/wal-slave.ini
new file mode 100644
index 00000000..912bf756
--- /dev/null
+++ b/python/conf/wal-slave.ini
@@ -0,0 +1,15 @@
+[wal-slave]
+logfile = slave.log
+use_skylog = 0
+
+slave_data = /var/lib/postgresql/8.0/main
+slave_stop_cmd = /etc/init.d/postgresql-8.0 stop
+slave_start_cmd = /etc/init.d/postgresql-8.0 start
+
+slave = /var/lib/postgresql/walshipping
+completed_wals = %(slave)s/logs.complete
+partial_wals = %(slave)s/logs.partial
+full_backup = %(slave)s/data.master
+
+keep_old_logs = 0
+
diff --git a/python/londiste.py b/python/londiste.py
new file mode 100755
index 00000000..9e5684ea
--- /dev/null
+++ b/python/londiste.py
@@ -0,0 +1,130 @@
+#! /usr/bin/env python
+
+"""Londiste launcher.
+"""
+
+import sys, os, optparse, skytools
+
+# python 2.3 will try londiste.py first...
+import sys, os.path
+if os.path.exists(os.path.join(sys.path[0], 'londiste.py')) \
+ and not os.path.exists(os.path.join(sys.path[0], 'londiste')):
+ del sys.path[0]
+
+from londiste import *
+
+command_usage = """
+%prog [options] INI CMD [subcmd args]
+
+commands:
+ provider install installs modules, creates queue
+ provider add TBL ... add table to queue
+ provider remove TBL ... remove table from queue
+ provider tables show all tables linked to queue
+ provider seqs show all sequences on provider
+
+ subscriber install installs schema
+ subscriber add TBL ... add table to subscriber
+ subscriber remove TBL ... remove table from subscriber
+ subscriber tables list tables subscriber has attached to
+  subscriber seqs           list sequences subscriber is interested in
+ subscriber missing list tables subscriber has not yet attached to
+ subscriber link QUE create mirror queue
+  subscriber unlink           stop mirroring queue
+
+ replay replay events to subscriber
+
+ switchover switch the roles between provider & subscriber
+ compare [TBL ...] compare table contents on both sides
+ repair [TBL ...] repair data on subscriber
+ copy full copy of table, internal cmd
+ subscriber check compare table structure on both sides
+ subscriber fkeys print out fkey drop/create commands
+ subscriber resync TBL ... do full copy again
+ subscriber register attaches subscriber to queue (also done by replay)
+ subscriber unregister detach subscriber from queue
+"""
+
+"""switchover:
+goal is to launch a replay with reverse config
+
+- should link be required? (link should guarantee all tables/seqs same?)
+- should link auto-add tables for subscriber
+
+1. lock all tables on provider, in order specified by 'nr'
+2. wait until old replay is past the point
+3. sync seq
+4. replace queue triggers on provider with deny triggers
+5. replace deny triggers on subscriber with queue triggers
+6. sync pgq tick seqs? change pgq config?
+
+"""
+
+class Londiste(skytools.DBScript):
+ def __init__(self, args):
+ skytools.DBScript.__init__(self, 'londiste', args)
+
+ if self.options.rewind or self.options.reset:
+ self.script = Replicator(args)
+ return
+
+ if len(self.args) < 2:
+ print "need command"
+ sys.exit(1)
+ cmd = self.args[1]
+
+        if cmd == "provider":
+ script = ProviderSetup(args)
+ elif cmd == "subscriber":
+ script = SubscriberSetup(args)
+ elif cmd == "replay":
+ method = self.cf.get('method', 'direct')
+ if method == 'direct':
+ script = Replicator(args)
+ elif method == 'file_write':
+ script = FileWrite(args)
+            elif method == 'file_read':
+                script = FileRead(args)
+ else:
+ print "unknown method, quitting"
+ sys.exit(1)
+ elif cmd == "copy":
+ script = CopyTable(args)
+ elif cmd == "compare":
+ script = Comparator(args)
+ elif cmd == "repair":
+ script = Repairer(args)
+ elif cmd == "upgrade":
+ script = UpgradeV2(args)
+ else:
+ print "Unknown command '%s', use --help for help" % cmd
+ sys.exit(1)
+
+ self.script = script
+
+ def start(self):
+ self.script.start()
+
+ def init_optparse(self, parser=None):
+ p = skytools.DBScript.init_optparse(self, parser)
+ p.set_usage(command_usage.strip())
+
+ g = optparse.OptionGroup(p, "expert options")
+ g.add_option("--force", action="store_true",
+ help = "add: ignore table differences, repair: ignore lag")
+ g.add_option("--expect-sync", action="store_true", dest="expect_sync",
+ help = "add: no copy needed", default=False)
+ g.add_option("--skip-truncate", action="store_true", dest="skip_truncate",
+ help = "copy: keep old data", default=False)
+ g.add_option("--rewind", action="store_true",
+ help = "replay: sync queue pos with subscriber")
+ g.add_option("--reset", action="store_true",
+ help = "replay: forget queue pos on subscriber")
+ p.add_option_group(g)
+
+ return p
+
+if __name__ == '__main__':
+ script = Londiste(sys.argv[1:])
+ script.start()
+
diff --git a/python/londiste/__init__.py b/python/londiste/__init__.py
new file mode 100644
index 00000000..97d67433
--- /dev/null
+++ b/python/londiste/__init__.py
@@ -0,0 +1,12 @@
+
+"""Replication on top of PgQ."""
+
+from playback import *
+from compare import *
+from file_read import *
+from file_write import *
+from setup import *
+from table_copy import *
+from installer import *
+from repair import *
+
diff --git a/python/londiste/compare.py b/python/londiste/compare.py
new file mode 100644
index 00000000..0029665b
--- /dev/null
+++ b/python/londiste/compare.py
@@ -0,0 +1,45 @@
+#! /usr/bin/env python
+
+"""Compares tables in replication set.
+
+Currently just does count(1) on both sides.
+"""
+
+import sys, os, time, skytools
+
+__all__ = ['Comparator']
+
+from syncer import Syncer
+
+class Comparator(Syncer):
+ def process_sync(self, tbl, src_db, dst_db):
+ """Actual comparision."""
+
+ src_curs = src_db.cursor()
+ dst_curs = dst_db.cursor()
+
+ self.log.info('Counting %s' % tbl)
+
+ q = "select count(1) from only _TABLE_"
+ q = self.cf.get('compare_sql', q)
+ q = q.replace('_TABLE_', tbl)
+
+ self.log.debug("srcdb: " + q)
+ src_curs.execute(q)
+ src_row = src_curs.fetchone()
+ src_str = ", ".join(map(str, src_row))
+ self.log.info("srcdb: res = %s" % src_str)
+
+ self.log.debug("dstdb: " + q)
+ dst_curs.execute(q)
+ dst_row = dst_curs.fetchone()
+ dst_str = ", ".join(map(str, dst_row))
+ self.log.info("dstdb: res = %s" % dst_str)
+
+ if src_str != dst_str:
+ self.log.warning("%s: Results do not match!" % tbl)
+
+if __name__ == '__main__':
+ script = Comparator(sys.argv[1:])
+ script.start()
+
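The count query can be replaced through a compare_sql setting in the ini; whatever is configured simply has _TABLE_ substituted before execution. A small sketch of the mechanism (the alternative query is a hypothetical example, not something the module ships):

    default_q = "select count(1) from only _TABLE_"
    custom_q = "select count(1), max(id) from only _TABLE_"   # hypothetical compare_sql value
    print custom_q.replace('_TABLE_', 'public.mytable')
    # -> select count(1), max(id) from only public.mytable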
diff --git a/python/londiste/file_read.py b/python/londiste/file_read.py
new file mode 100644
index 00000000..2902bda5
--- /dev/null
+++ b/python/londiste/file_read.py
@@ -0,0 +1,52 @@
+
+"""Reads events from file instead of db queue."""
+
+import sys, os, re, skytools
+
+from playback import *
+from table_copy import *
+
+__all__ = ['FileRead']
+
+file_regex = r"^tick_0*([0-9]+)\.sql$"
+file_rc = re.compile(file_regex)
+
+
+class FileRead(CopyTable):
+ """Reads events from file instead of db queue.
+
+ Incomplete implementation.
+ """
+
+ def __init__(self, args, log = None):
+ CopyTable.__init__(self, args, log, copy_thread = 0)
+
+ def launch_copy(self, tbl):
+        # copy immediately
+        self.do_copy(tbl)
+
+ def work(self):
+ last_batch = self.get_last_batch(curs)
+        list = self.get_list()
+
+ def get_list(self):
+ """Return list of (first_batch, full_filename) pairs."""
+
+ src_dir = self.cf.get('file_src')
+ list = os.listdir(src_dir)
+ list.sort()
+ res = []
+ for fn in list:
+ m = file_rc.match(fn)
+ if not m:
+ self.log.debug("Ignoring file: %s" % fn)
+ continue
+ full = os.path.join(src_dir, fn)
+ batch_id = int(m.group(1))
+ res.append((batch_id, full))
+ return res
+
+if __name__ == '__main__':
+ script = Replicator(sys.argv[1:])
+ script.start()
+
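A quick check of the tick-file pattern used by get_list(): leading zeros are stripped and the numeric part becomes the batch id (the file name below is made up for illustration):

    import re

    file_rc = re.compile(r"^tick_0*([0-9]+)\.sql$")
    m = file_rc.match("tick_0000000042.sql")
    print int(m.group(1))   # -> 42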
diff --git a/python/londiste/file_write.py b/python/londiste/file_write.py
new file mode 100644
index 00000000..86e16aae
--- /dev/null
+++ b/python/londiste/file_write.py
@@ -0,0 +1,67 @@
+
+"""Writes events into file."""
+
+import sys, os, skytools
+from cStringIO import StringIO
+from playback import *
+
+__all__ = ['FileWrite']
+
+class FileWrite(Replicator):
+ """Writes events into file.
+
+ Incomplete implementation.
+ """
+
+ last_successful_batch = None
+
+ def load_state(self, batch_id):
+ # maybe check if batch exists on filesystem?
+ self.cur_tick = self.cur_batch_info['tick_id']
+ self.prev_tick = self.cur_batch_info['prev_tick_id']
+ return 1
+
+ def process_batch(self, db, batch_id, ev_list):
+ pass
+
+ def save_state(self, do_commit):
+ # nothing to save
+ pass
+
+ def sync_tables(self, dst_db):
+ # nothing to sync
+ return 1
+
+ def interesting(self, ev):
+ # wants all of them
+ return 1
+
+ def handle_data_event(self, ev):
+ fmt = self.sql_command[ev.type]
+ sql = fmt % (ev.ev_extra1, ev.data)
+ row = "%s -- txid:%d" % (sql, ev.txid)
+ self.sql_list.append(row)
+ ev.tag_done()
+
+ def handle_system_event(self, ev):
+ row = "-- sysevent:%s txid:%d data:%s" % (
+ ev.type, ev.txid, ev.data)
+ self.sql_list.append(row)
+ ev.tag_done()
+
+ def flush_sql(self):
+ self.sql_list.insert(0, "-- tick:%d prev:%s" % (
+ self.cur_tick, self.prev_tick))
+ self.sql_list.append("-- end_tick:%d\n" % self.cur_tick)
+ # store result
+ dir = self.cf.get("file_dst")
+ fn = os.path.join(dir, "tick_%010d.sql" % self.cur_tick)
+ f = open(fn, "w")
+ buf = "\n".join(self.sql_list)
+ f.write(buf)
+ f.close()
+
+if __name__ == '__main__':
+ script = Replicator(sys.argv[1:])
+ script.start()
+
diff --git a/python/londiste/installer.py b/python/londiste/installer.py
new file mode 100644
index 00000000..6f190ab2
--- /dev/null
+++ b/python/londiste/installer.py
@@ -0,0 +1,26 @@
+
+"""Functions to install londiste and its depentencies into database."""
+
+import os, skytools
+
+__all__ = ['install_provider', 'install_subscriber']
+
+provider_object_list = [
+ skytools.DBFunction('logtriga', 0, sql_file = "logtriga.sql"),
+ skytools.DBFunction('get_current_snapshot', 0, sql_file = "txid.sql"),
+ skytools.DBSchema('pgq', sql_file = "pgq.sql"),
+ skytools.DBSchema('londiste', sql_file = "londiste.sql")
+]
+
+subscriber_object_list = [
+ skytools.DBSchema('londiste', sql_file = "londiste.sql")
+]
+
+def install_provider(curs, log):
+ """Installs needed code into provider db."""
+ skytools.db_install(curs, provider_object_list, log)
+
+def install_subscriber(curs, log):
+ """Installs needed code into subscriber db."""
+ skytools.db_install(curs, subscriber_object_list, log)
+
diff --git a/python/londiste/playback.py b/python/londiste/playback.py
new file mode 100644
index 00000000..2bcb1bc7
--- /dev/null
+++ b/python/londiste/playback.py
@@ -0,0 +1,558 @@
+#! /usr/bin/env python
+
+"""Basic replication core."""
+
+import sys, os, time
+import skytools, pgq
+
+__all__ = ['Replicator', 'TableState',
+ 'TABLE_MISSING', 'TABLE_IN_COPY', 'TABLE_CATCHING_UP',
+ 'TABLE_WANNA_SYNC', 'TABLE_DO_SYNC', 'TABLE_OK']
+
+# state # owner - who is allowed to change
+TABLE_MISSING = 0 # main
+TABLE_IN_COPY = 1 # copy
+TABLE_CATCHING_UP = 2 # copy
+TABLE_WANNA_SYNC = 3 # main
+TABLE_DO_SYNC = 4 # copy
+TABLE_OK = 5 # setup
+
+SYNC_OK = 0 # continue with batch
+SYNC_LOOP = 1 # sleep, try again
+SYNC_EXIT = 2  # nothing to do, exit script
+
+class Counter(object):
+ """Counts table statuses."""
+
+ missing = 0
+ copy = 0
+ catching_up = 0
+ wanna_sync = 0
+ do_sync = 0
+ ok = 0
+
+ def __init__(self, tables):
+ """Counts and sanity checks."""
+ for t in tables:
+ if t.state == TABLE_MISSING:
+ self.missing += 1
+ elif t.state == TABLE_IN_COPY:
+ self.copy += 1
+ elif t.state == TABLE_CATCHING_UP:
+ self.catching_up += 1
+ elif t.state == TABLE_WANNA_SYNC:
+ self.wanna_sync += 1
+ elif t.state == TABLE_DO_SYNC:
+ self.do_sync += 1
+ elif t.state == TABLE_OK:
+ self.ok += 1
+ # only one table is allowed to have in-progress copy
+ if self.copy + self.catching_up + self.wanna_sync + self.do_sync > 1:
+ raise Exception('Bad table state')
+
+class TableState(object):
+ """Keeps state about one table."""
+ def __init__(self, name, log):
+ self.name = name
+ self.log = log
+ self.forget()
+ self.changed = 0
+
+ def forget(self):
+ self.state = TABLE_MISSING
+ self.str_snapshot = None
+ self.from_snapshot = None
+ self.sync_tick_id = None
+ self.ok_batch_count = 0
+ self.last_tick = 0
+ self.changed = 1
+
+ def change_snapshot(self, str_snapshot, tag_changed = 1):
+ if self.str_snapshot == str_snapshot:
+ return
+ self.log.debug("%s: change_snapshot to %s" % (self.name, str_snapshot))
+ self.str_snapshot = str_snapshot
+ if str_snapshot:
+ self.from_snapshot = skytools.Snapshot(str_snapshot)
+ else:
+ self.from_snapshot = None
+
+ if tag_changed:
+ self.ok_batch_count = 0
+ self.last_tick = None
+ self.changed = 1
+
+ def change_state(self, state, tick_id = None):
+ if self.state == state and self.sync_tick_id == tick_id:
+ return
+ self.state = state
+ self.sync_tick_id = tick_id
+ self.changed = 1
+ self.log.debug("%s: change_state to %s" % (self.name,
+ self.render_state()))
+
+ def render_state(self):
+ """Make a string to be stored in db."""
+
+ if self.state == TABLE_MISSING:
+ return None
+ elif self.state == TABLE_IN_COPY:
+ return 'in-copy'
+ elif self.state == TABLE_CATCHING_UP:
+ return 'catching-up'
+ elif self.state == TABLE_WANNA_SYNC:
+ return 'wanna-sync:%d' % self.sync_tick_id
+ elif self.state == TABLE_DO_SYNC:
+ return 'do-sync:%d' % self.sync_tick_id
+ elif self.state == TABLE_OK:
+ return 'ok'
+
+ def parse_state(self, merge_state):
+ """Read state from string."""
+
+ state = -1
+ if merge_state == None:
+ state = TABLE_MISSING
+ elif merge_state == "in-copy":
+ state = TABLE_IN_COPY
+ elif merge_state == "catching-up":
+ state = TABLE_CATCHING_UP
+ elif merge_state == "ok":
+ state = TABLE_OK
+ elif merge_state == "?":
+ state = TABLE_OK
+ else:
+ tmp = merge_state.split(':')
+ if len(tmp) == 2:
+ self.sync_tick_id = int(tmp[1])
+ if tmp[0] == 'wanna-sync':
+ state = TABLE_WANNA_SYNC
+ elif tmp[0] == 'do-sync':
+ state = TABLE_DO_SYNC
+
+ if state < 0:
+ raise Exception("Bad table state: %s" % merge_state)
+
+ return state
+
+ def loaded_state(self, merge_state, str_snapshot):
+ self.log.debug("loaded_state: %s: %s / %s" % (
+ self.name, merge_state, str_snapshot))
+ self.change_snapshot(str_snapshot, 0)
+ self.state = self.parse_state(merge_state)
+ self.changed = 0
+ if merge_state == "?":
+ self.changed = 1
+
+ def interesting(self, ev, tick_id, copy_thread):
+ """Check if table wants this event."""
+
+ if copy_thread:
+ if self.state not in (TABLE_CATCHING_UP, TABLE_DO_SYNC):
+ return False
+ else:
+ if self.state != TABLE_OK:
+ return False
+
+ # if no snapshot tracking, then accept always
+ if not self.from_snapshot:
+ return True
+
+ # uninteresting?
+ if self.from_snapshot.contains(ev.txid):
+ return False
+
+        # after a couple of interesting batches there is no need to check snapshot
+ # as there can be only one partially interesting batch
+ if tick_id != self.last_tick:
+ self.last_tick = tick_id
+ self.ok_batch_count += 1
+
+ # disable batch tracking
+ if self.ok_batch_count > 3:
+ self.change_snapshot(None)
+ return True
+
+class SeqCache(object):
+ def __init__(self):
+ self.seq_list = []
+ self.val_cache = {}
+
+ def set_seq_list(self, seq_list):
+ self.seq_list = seq_list
+ new_cache = {}
+ for seq in seq_list:
+ val = self.val_cache.get(seq)
+ if val:
+ new_cache[seq] = val
+ self.val_cache = new_cache
+
+ def resync(self, src_curs, dst_curs):
+ if len(self.seq_list) == 0:
+ return
+ dat = ".last_value, ".join(self.seq_list)
+ dat += ".last_value"
+ q = "select %s from %s" % (dat, ",".join(self.seq_list))
+ src_curs.execute(q)
+ row = src_curs.fetchone()
+ for i in range(len(self.seq_list)):
+ seq = self.seq_list[i]
+ cur = row[i]
+ old = self.val_cache.get(seq)
+ if old != cur:
+ q = "select setval(%s, %s)"
+ dst_curs.execute(q, [seq, cur])
+ self.val_cache[seq] = cur
+
+class Replicator(pgq.SerialConsumer):
+ """Replication core."""
+
+ sql_command = {
+ 'I': "insert into %s %s;",
+ 'U': "update only %s set %s;",
+ 'D': "delete from only %s where %s;",
+ }
+
+ # batch info
+ cur_tick = 0
+ prev_tick = 0
+
+ def __init__(self, args):
+ pgq.SerialConsumer.__init__(self, 'londiste', 'provider_db', 'subscriber_db', args)
+
+ # tick table in dst for SerialConsumer(). keep londiste stuff under one schema
+ self.dst_completed_table = "londiste.completed"
+
+ self.table_list = []
+ self.table_map = {}
+
+ self.copy_thread = 0
+ self.maint_time = 0
+ self.seq_cache = SeqCache()
+ self.maint_delay = self.cf.getint('maint_delay', 600)
+ self.mirror_queue = self.cf.get('mirror_queue', '')
+
+ def process_remote_batch(self, src_db, batch_id, ev_list, dst_db):
+ "All work for a batch. Entry point from SerialConsumer."
+
+ # this part can play freely with transactions
+
+ dst_curs = dst_db.cursor()
+
+ self.cur_tick = self.cur_batch_info['tick_id']
+ self.prev_tick = self.cur_batch_info['prev_tick_id']
+
+ self.load_table_state(dst_curs)
+ self.sync_tables(dst_db)
+
+ # now the actual event processing happens.
+ # they must be done all in one tx in dst side
+ # and the transaction must be kept open so that
+ # the SerialConsumer can save last tick and commit.
+
+ self.handle_seqs(dst_curs)
+ self.handle_events(dst_curs, ev_list)
+ self.save_table_state(dst_curs)
+
+ def handle_seqs(self, dst_curs):
+ if self.copy_thread:
+ return
+
+ q = "select * from londiste.subscriber_get_seq_list(%s)"
+ dst_curs.execute(q, [self.pgq_queue_name])
+ seq_list = []
+ for row in dst_curs.fetchall():
+ seq_list.append(row[0])
+
+ self.seq_cache.set_seq_list(seq_list)
+
+ src_curs = self.get_database('provider_db').cursor()
+ self.seq_cache.resync(src_curs, dst_curs)
+
+ def sync_tables(self, dst_db):
+ """Table sync loop.
+
+        Calls the appropriate handler, which is expected to
+ return one of SYNC_* constants."""
+
+ self.log.debug('Sync tables')
+ while 1:
+ cnt = Counter(self.table_list)
+ if self.copy_thread:
+ res = self.sync_from_copy_thread(cnt, dst_db)
+ else:
+ res = self.sync_from_main_thread(cnt, dst_db)
+
+ if res == SYNC_EXIT:
+ self.log.debug('Sync tables: exit')
+ self.detach()
+ sys.exit(0)
+ elif res == SYNC_OK:
+ return
+ elif res != SYNC_LOOP:
+ raise Exception('Program error')
+
+ self.log.debug('Sync tables: sleeping')
+ time.sleep(3)
+ dst_db.commit()
+ self.load_table_state(dst_db.cursor())
+ dst_db.commit()
+
+ def sync_from_main_thread(self, cnt, dst_db):
+ "Main thread sync logic."
+
+ #
+        # decide what to do - order is important
+ #
+ if cnt.do_sync:
+ # wait for copy thread to catch up
+ return SYNC_LOOP
+ elif cnt.wanna_sync:
+ # copy thread wants sync, if not behind, do it
+ t = self.get_table_by_state(TABLE_WANNA_SYNC)
+ if self.cur_tick >= t.sync_tick_id:
+ self.change_table_state(dst_db, t, TABLE_DO_SYNC, self.cur_tick)
+ return SYNC_LOOP
+ else:
+ return SYNC_OK
+ elif cnt.catching_up:
+            # active copy, don't worry
+ return SYNC_OK
+ elif cnt.copy:
+            # active copy, don't worry
+ return SYNC_OK
+ elif cnt.missing:
+ # seems there is no active copy thread, launch new
+ t = self.get_table_by_state(TABLE_MISSING)
+ self.change_table_state(dst_db, t, TABLE_IN_COPY)
+
+            # the copy _may_ happen immediately
+ self.launch_copy(t)
+
+ # there cannot be interesting events in current batch
+ # but maybe there's several tables, lets do them in one go
+ return SYNC_LOOP
+ else:
+ # seems everything is in sync
+ return SYNC_OK
+
+ def sync_from_copy_thread(self, cnt, dst_db):
+ "Copy thread sync logic."
+
+ #
+        # decide what to do - order is important
+ #
+ if cnt.do_sync:
+            # main thread is waiting, catch up, then hand over
+ t = self.get_table_by_state(TABLE_DO_SYNC)
+ if self.cur_tick == t.sync_tick_id:
+ self.change_table_state(dst_db, t, TABLE_OK)
+ return SYNC_EXIT
+ elif self.cur_tick < t.sync_tick_id:
+ return SYNC_OK
+ else:
+ self.log.error("copy_sync: cur_tick=%d sync_tick=%d" % (
+ self.cur_tick, t.sync_tick_id))
+ raise Exception('Invalid table state')
+ elif cnt.wanna_sync:
+ # wait for main thread to react
+ return SYNC_LOOP
+ elif cnt.catching_up:
+ # is there more work?
+ if self.work_state:
+ return SYNC_OK
+
+            # seems we have caught up
+ t = self.get_table_by_state(TABLE_CATCHING_UP)
+ self.change_table_state(dst_db, t, TABLE_WANNA_SYNC, self.cur_tick)
+ return SYNC_LOOP
+ elif cnt.copy:
+ # table is not copied yet, do it
+ t = self.get_table_by_state(TABLE_IN_COPY)
+ self.do_copy(t)
+
+ # forget previous value
+ self.work_state = 1
+
+ return SYNC_LOOP
+ else:
+ # nothing to do
+ return SYNC_EXIT
+
+ def handle_events(self, dst_curs, ev_list):
+ "Actual event processing happens here."
+
+ ignored_events = 0
+ self.sql_list = []
+ mirror_list = []
+ for ev in ev_list:
+ if not self.interesting(ev):
+ ignored_events += 1
+ ev.tag_done()
+ continue
+
+ if ev.type in ('I', 'U', 'D'):
+ self.handle_data_event(ev, dst_curs)
+ else:
+ self.handle_system_event(ev, dst_curs)
+
+ if self.mirror_queue:
+ mirror_list.append(ev)
+
+ # finalize table changes
+ self.flush_sql(dst_curs)
+ self.stat_add('ignored', ignored_events)
+
+ # put events into mirror queue if requested
+ if self.mirror_queue:
+ self.fill_mirror_queue(mirror_list, dst_curs)
+
+ def handle_data_event(self, ev, dst_curs):
+ fmt = self.sql_command[ev.type]
+ sql = fmt % (ev.extra1, ev.data)
+ self.sql_list.append(sql)
+ if len(self.sql_list) > 200:
+ self.flush_sql(dst_curs)
+ ev.tag_done()
+
+ def flush_sql(self, dst_curs):
+ if len(self.sql_list) == 0:
+ return
+
+ buf = "\n".join(self.sql_list)
+ self.sql_list = []
+
+ dst_curs.execute(buf)
+
+ def interesting(self, ev):
+ if ev.type not in ('I', 'U', 'D'):
+ return 1
+ t = self.get_table_by_name(ev.extra1)
+ if t:
+ return t.interesting(ev, self.cur_tick, self.copy_thread)
+ else:
+ return 0
+
+ def handle_system_event(self, ev, dst_curs):
+ "System event."
+
+ if ev.type == "T":
+ self.log.info("got new table event: "+ev.data)
+ # check tables to be dropped
+ name_list = []
+ for name in ev.data.split(','):
+ name_list.append(name.strip())
+
+ del_list = []
+ for tbl in self.table_list:
+ if tbl.name in name_list:
+ continue
+ del_list.append(tbl)
+
+ # separate loop to avoid changing while iterating
+ for tbl in del_list:
+ self.log.info("Removing table %s from set" % tbl.name)
+ self.remove_table(tbl, dst_curs)
+
+ ev.tag_done()
+ else:
+            self.log.warning("Unknown op %s" % ev.type)
+ ev.tag_failed("Unknown operation")
+
+ def remove_table(self, tbl, dst_curs):
+ del self.table_map[tbl.name]
+ self.table_list.remove(tbl)
+ q = "select londiste.subscriber_remove_table(%s, %s)"
+ dst_curs.execute(q, [self.pgq_queue_name, tbl.name])
+
+ def load_table_state(self, curs):
+ """Load table state from database.
+
+ Todo: if all tables are OK, there is no need
+ to load state on every batch.
+ """
+
+ q = """select table_name, snapshot, merge_state
+ from londiste.subscriber_get_table_list(%s)
+ """
+ curs.execute(q, [self.pgq_queue_name])
+
+ new_list = []
+ new_map = {}
+ for row in curs.dictfetchall():
+ t = self.get_table_by_name(row['table_name'])
+ if not t:
+ t = TableState(row['table_name'], self.log)
+ t.loaded_state(row['merge_state'], row['snapshot'])
+ new_list.append(t)
+ new_map[t.name] = t
+
+ self.table_list = new_list
+ self.table_map = new_map
+
+ def save_table_state(self, curs):
+ """Store changed table state in database."""
+
+ for t in self.table_list:
+ if not t.changed:
+ continue
+ merge_state = t.render_state()
+ self.log.info("storing state of %s: copy:%d new_state:%s" % (
+ t.name, self.copy_thread, merge_state))
+ q = "select londiste.subscriber_set_table_state(%s, %s, %s, %s)"
+ curs.execute(q, [self.pgq_queue_name,
+ t.name, t.str_snapshot, merge_state])
+ t.changed = 0
+
+ def change_table_state(self, dst_db, tbl, state, tick_id = None):
+ tbl.change_state(state, tick_id)
+ self.save_table_state(dst_db.cursor())
+ dst_db.commit()
+
+ self.log.info("Table %s status changed to '%s'" % (
+ tbl.name, tbl.render_state()))
+
+ def get_table_by_state(self, state):
+ "get first table with specific state"
+
+ for t in self.table_list:
+ if t.state == state:
+ return t
+ raise Exception('No table was found with state: %d' % state)
+
+ def get_table_by_name(self, name):
+ if name.find('.') < 0:
+ name = "public.%s" % name
+ if name in self.table_map:
+ return self.table_map[name]
+ return None
+
+ def fill_mirror_queue(self, ev_list, dst_curs):
+ # insert events
+ rows = []
+ fields = ['ev_type', 'ev_data', 'ev_extra1']
+        for ev in ev_list:
+ rows.append((ev.type, ev.data, ev.extra1))
+ pgq.bulk_insert_events(dst_curs, rows, fields, self.mirror_queue)
+
+ # create tick
+ q = "select pgq.ticker(%s, %s)"
+ dst_curs.execute(q, [self.mirror_queue, self.cur_tick])
+
+ def launch_copy(self, tbl_stat):
+ self.log.info("Launching copy process")
+ script = sys.argv[0]
+ conf = self.cf.filename
+ if self.options.verbose:
+ cmd = "%s -d -v %s copy"
+ else:
+ cmd = "%s -d %s copy"
+ cmd = cmd % (script, conf)
+ self.log.debug("Launch args: "+repr(cmd))
+ res = os.system(cmd)
+ self.log.debug("Launch result: "+repr(res))
+
+if __name__ == '__main__':
+ script = Replicator(sys.argv[1:])
+ script.start()
+
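The merge_state strings stored on the subscriber are just the render_state()/parse_state() encoding defined above. A tiny round-trip sketch, assuming the londiste and pgq packages from this tree are importable and using a stdlib logger for the log argument:

    import logging
    from londiste.playback import TableState, TABLE_WANNA_SYNC

    t = TableState('public.mytable', logging.getLogger('demo'))
    t.change_state(TABLE_WANNA_SYNC, 17)
    print t.render_state()                                     # -> wanna-sync:17
    print t.parse_state('wanna-sync:17') == TABLE_WANNA_SYNC   # -> True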
diff --git a/python/londiste/repair.py b/python/londiste/repair.py
new file mode 100644
index 00000000..ec4bd404
--- /dev/null
+++ b/python/londiste/repair.py
@@ -0,0 +1,284 @@
+
+"""Repair data on subscriber.
+
+Walks tables by primary key and searches for
+missing inserts/updates/deletes.
+"""
+
+import sys, os, time, psycopg, skytools
+
+from syncer import Syncer
+
+__all__ = ['Repairer']
+
+def unescape(s):
+ return skytools.unescape_copy(s)
+
+def get_pkey_list(curs, tbl):
+ """Get list of pkey fields in right order."""
+
+ oid = skytools.get_table_oid(curs, tbl)
+ q = """SELECT k.attname FROM pg_index i, pg_attribute k
+ WHERE i.indrelid = %s AND k.attrelid = i.indexrelid
+ AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped
+ ORDER BY k.attnum"""
+ curs.execute(q, [oid])
+ list = []
+ for row in curs.fetchall():
+ list.append(row[0])
+ return list
+
+def get_column_list(curs, tbl):
+ """Get list of columns in right order."""
+
+ oid = skytools.get_table_oid(curs, tbl)
+ q = """SELECT a.attname FROM pg_attribute a
+ WHERE a.attrelid = %s
+ AND a.attnum > 0 AND NOT a.attisdropped
+ ORDER BY a.attnum"""
+ curs.execute(q, [oid])
+ list = []
+ for row in curs.fetchall():
+ list.append(row[0])
+ return list
+
+class Repairer(Syncer):
+ """Walks tables in primary key order and checks if data matches."""
+
+
+ def process_sync(self, tbl, src_db, dst_db):
+ """Actual comparision."""
+
+ src_curs = src_db.cursor()
+ dst_curs = dst_db.cursor()
+
+ self.log.info('Checking %s' % tbl)
+
+ self.common_fields = []
+ self.pkey_list = []
+ copy_tbl = self.gen_copy_tbl(tbl, src_curs, dst_curs)
+
+ dump_src = tbl + ".src"
+ dump_dst = tbl + ".dst"
+
+ self.log.info("Dumping src table: %s" % tbl)
+ self.dump_table(tbl, copy_tbl, src_curs, dump_src)
+ src_db.commit()
+ self.log.info("Dumping dst table: %s" % tbl)
+ self.dump_table(tbl, copy_tbl, dst_curs, dump_dst)
+ dst_db.commit()
+
+ self.log.info("Sorting src table: %s" % tbl)
+
+ s_in, s_out = os.popen4("sort --version")
+ s_ver = s_out.read()
+ del s_in, s_out
+ if s_ver.find("coreutils") > 0:
+ args = "-S 30%"
+ else:
+ args = ""
+ os.system("sort %s -T . -o %s.sorted %s" % (args, dump_src, dump_src))
+ self.log.info("Sorting dst table: %s" % tbl)
+ os.system("sort %s -T . -o %s.sorted %s" % (args, dump_dst, dump_dst))
+
+ self.dump_compare(tbl, dump_src + ".sorted", dump_dst + ".sorted")
+
+ os.unlink(dump_src)
+ os.unlink(dump_dst)
+ os.unlink(dump_src + ".sorted")
+ os.unlink(dump_dst + ".sorted")
+
+ def gen_copy_tbl(self, tbl, src_curs, dst_curs):
+ self.pkey_list = get_pkey_list(src_curs, tbl)
+ dst_pkey = get_pkey_list(dst_curs, tbl)
+ if dst_pkey != self.pkey_list:
+ self.log.error('pkeys do not match')
+ sys.exit(1)
+
+ src_cols = get_column_list(src_curs, tbl)
+ dst_cols = get_column_list(dst_curs, tbl)
+ field_list = []
+ for f in self.pkey_list:
+ field_list.append(f)
+ for f in src_cols:
+ if f in self.pkey_list:
+ continue
+ if f in dst_cols:
+ field_list.append(f)
+
+ self.common_fields = field_list
+
+ tbl_expr = "%s (%s)" % (tbl, ",".join(field_list))
+
+ self.log.debug("using copy expr: %s" % tbl_expr)
+
+ return tbl_expr
+
+ def dump_table(self, tbl, copy_tbl, curs, fn):
+ f = open(fn, "w", 64*1024)
+ curs.copy_to(f, copy_tbl)
+ size = f.tell()
+ f.close()
+ self.log.info('Got %d bytes' % size)
+
+ def get_row(self, ln):
+ t = ln[:-1].split('\t')
+ row = {}
+ for i in range(len(self.common_fields)):
+ row[self.common_fields[i]] = t[i]
+ return row
+
+ def dump_compare(self, tbl, src_fn, dst_fn):
+ self.log.info("Comparing dumps: %s" % tbl)
+ self.cnt_insert = 0
+ self.cnt_update = 0
+ self.cnt_delete = 0
+ self.total_src = 0
+ self.total_dst = 0
+ f1 = open(src_fn, "r", 64*1024)
+ f2 = open(dst_fn, "r", 64*1024)
+ src_ln = f1.readline()
+ dst_ln = f2.readline()
+ if src_ln: self.total_src += 1
+ if dst_ln: self.total_dst += 1
+
+ fix = "fix.%s.sql" % tbl
+ if os.path.isfile(fix):
+ os.unlink(fix)
+
+ while src_ln or dst_ln:
+ keep_src = keep_dst = 0
+ if src_ln != dst_ln:
+ src_row = self.get_row(src_ln)
+ dst_row = self.get_row(dst_ln)
+
+ cmp = self.cmp_keys(src_row, dst_row)
+ if cmp > 0:
+ # src > dst
+ self.got_missed_delete(tbl, dst_row)
+ keep_src = 1
+ elif cmp < 0:
+ # src < dst
+ self.got_missed_insert(tbl, src_row)
+ keep_dst = 1
+ else:
+ if self.cmp_data(src_row, dst_row) != 0:
+ self.got_missed_update(tbl, src_row, dst_row)
+
+ if not keep_src:
+ src_ln = f1.readline()
+ if src_ln: self.total_src += 1
+ if not keep_dst:
+ dst_ln = f2.readline()
+ if dst_ln: self.total_dst += 1
+
+ self.log.info("finished %s: src: %d rows, dst: %d rows,"\
+ " missed: %d inserts, %d updates, %d deletes" % (
+ tbl, self.total_src, self.total_dst,
+ self.cnt_insert, self.cnt_update, self.cnt_delete))
+
+ def got_missed_insert(self, tbl, src_row):
+ self.cnt_insert += 1
+ fld_list = self.common_fields
+ val_list = []
+ for f in fld_list:
+ v = unescape(src_row[f])
+ val_list.append(skytools.quote_literal(v))
+ q = "insert into %s (%s) values (%s);" % (
+ tbl, ", ".join(fld_list), ", ".join(val_list))
+ self.show_fix(tbl, q, 'insert')
+
+ def got_missed_update(self, tbl, src_row, dst_row):
+ self.cnt_update += 1
+ fld_list = self.common_fields
+ set_list = []
+ whe_list = []
+ for f in self.pkey_list:
+ self.addcmp(whe_list, f, unescape(src_row[f]))
+ for f in fld_list:
+ v1 = src_row[f]
+ v2 = dst_row[f]
+ if self.cmp_value(v1, v2) == 0:
+ continue
+
+ self.addeq(set_list, f, unescape(v1))
+ self.addcmp(whe_list, f, unescape(v2))
+
+ q = "update only %s set %s where %s;" % (
+ tbl, ", ".join(set_list), " and ".join(whe_list))
+ self.show_fix(tbl, q, 'update')
+
+    def got_missed_delete(self, tbl, dst_row):
+ self.cnt_delete += 1
+ whe_list = []
+ for f in self.pkey_list:
+ self.addcmp(whe_list, f, unescape(dst_row[f]))
+ q = "delete from only %s where %s;" % (tbl, " and ".join(whe_list))
+ self.show_fix(tbl, q, 'delete')
+
+ def show_fix(self, tbl, q, desc):
+ #self.log.warning("missed %s: %s" % (desc, q))
+ fn = "fix.%s.sql" % tbl
+ open(fn, "a").write("%s\n" % q)
+
+ def addeq(self, list, f, v):
+ vq = skytools.quote_literal(v)
+ s = "%s = %s" % (f, vq)
+ list.append(s)
+
+ def addcmp(self, list, f, v):
+ if v is None:
+ s = "%s is null" % f
+ else:
+ vq = skytools.quote_literal(v)
+ s = "%s = %s" % (f, vq)
+ list.append(s)
+
+ def cmp_data(self, src_row, dst_row):
+ for k in self.common_fields:
+ v1 = src_row[k]
+ v2 = dst_row[k]
+ if self.cmp_value(v1, v2) != 0:
+ return -1
+ return 0
+
+ def cmp_value(self, v1, v2):
+ if v1 == v2:
+ return 0
+
+ # try to work around tz vs. notz
+ z1 = len(v1)
+ z2 = len(v2)
+ if z1 == z2 + 3 and z2 >= 19 and v1[z2] == '+':
+ v1 = v1[:-3]
+ if v1 == v2:
+ return 0
+ elif z1 + 3 == z2 and z1 >= 19 and v2[z1] == '+':
+ v2 = v2[:-3]
+ if v1 == v2:
+ return 0
+
+ return -1
+
+ def cmp_keys(self, src_row, dst_row):
+ """Compare primary keys of the rows.
+
+ Returns 1 if src > dst, -1 if src < dst and 0 if src == dst"""
+
+ # None means table is done. tag it larger than any existing row.
+ if src_row is None:
+ if dst_row is None:
+ return 0
+ return 1
+ elif dst_row is None:
+ return -1
+
+ for k in self.pkey_list:
+ v1 = src_row[k]
+ v2 = dst_row[k]
+ if v1 < v2:
+ return -1
+ elif v1 > v2:
+ return 1
+ return 0
+
diff --git a/python/londiste/setup.py b/python/londiste/setup.py
new file mode 100644
index 00000000..ed44b093
--- /dev/null
+++ b/python/londiste/setup.py
@@ -0,0 +1,580 @@
+#! /usr/bin/env python
+
+"""Londiste setup and sanity checker.
+
+"""
+import sys, os, skytools
+from installer import *
+
+__all__ = ['ProviderSetup', 'SubscriberSetup']
+
+def find_column_types(curs, table):
+ table_oid = skytools.get_table_oid(curs, table)
+ if table_oid == None:
+ return None
+
+ key_sql = """
+ SELECT k.attname FROM pg_index i, pg_attribute k
+ WHERE i.indrelid = %d AND k.attrelid = i.indexrelid
+ AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped
+ """ % table_oid
+
+ # find columns
+ q = """
+ SELECT a.attname as name,
+ CASE WHEN k.attname IS NOT NULL
+ THEN 'k' ELSE 'v' END AS type
+ FROM pg_attribute a LEFT JOIN (%s) k ON (k.attname = a.attname)
+ WHERE a.attrelid = %d AND a.attnum > 0 AND NOT a.attisdropped
+ ORDER BY a.attnum
+ """ % (key_sql, table_oid)
+ curs.execute(q)
+ rows = curs.dictfetchall()
+ return rows
+
+def make_type_string(col_rows):
+ res = map(lambda x: x['type'], col_rows)
+ return "".join(res)
+
+class CommonSetup(skytools.DBScript):
+ def __init__(self, args):
+ skytools.DBScript.__init__(self, 'londiste', args)
+ self.set_single_loop(1)
+ self.pidfile = self.pidfile + ".setup"
+
+ self.pgq_queue_name = self.cf.get("pgq_queue_name")
+ self.consumer_id = self.cf.get("pgq_consumer_id", self.job_name)
+ self.fake = self.cf.getint('fake', 0)
+
+ if len(self.args) < 3:
+ self.log.error("need subcommand")
+ sys.exit(1)
+
+ def run(self):
+ self.admin()
+
+ def fetch_provider_table_list(self, curs):
+ q = """select table_name, trigger_name
+ from londiste.provider_get_table_list(%s)"""
+ curs.execute(q, [self.pgq_queue_name])
+ return curs.dictfetchall()
+
+ def get_provider_table_list(self):
+ src_db = self.get_database('provider_db')
+ src_curs = src_db.cursor()
+ list = self.fetch_provider_table_list(src_curs)
+ src_db.commit()
+ res = []
+ for row in list:
+ res.append(row['table_name'])
+ return res
+
+ def get_provider_seqs(self, curs):
+ q = """SELECT * from londiste.provider_get_seq_list(%s)"""
+ curs.execute(q, [self.pgq_queue_name])
+ res = []
+ for row in curs.fetchall():
+ res.append(row[0])
+ return res
+
+ def get_all_seqs(self, curs):
+ q = """SELECT n.nspname || '.'|| c.relname
+ from pg_class c, pg_namespace n
+ where n.oid = c.relnamespace
+ and c.relkind = 'S'
+ order by 1"""
+ curs.execute(q)
+ res = []
+ for row in curs.fetchall():
+ res.append(row[0])
+ return res
+
+ def check_provider_queue(self):
+ src_db = self.get_database('provider_db')
+ src_curs = src_db.cursor()
+ q = "select count(1) from pgq.get_queue_info(%s)"
+ src_curs.execute(q, [self.pgq_queue_name])
+ ok = src_curs.fetchone()[0]
+ src_db.commit()
+
+ if not ok:
+ self.log.error('Event queue does not exist yet')
+ sys.exit(1)
+
+ def fetch_subscriber_tables(self, curs):
+ q = "select * from londiste.subscriber_get_table_list(%s)"
+ curs.execute(q, [self.pgq_queue_name])
+ return curs.dictfetchall()
+
+ def get_subscriber_table_list(self):
+ dst_db = self.get_database('subscriber_db')
+ dst_curs = dst_db.cursor()
+ list = self.fetch_subscriber_tables(dst_curs)
+ dst_db.commit()
+ res = []
+ for row in list:
+ res.append(row['table_name'])
+ return res
+
+ def init_optparse(self, parser=None):
+ p = skytools.DBScript.init_optparse(self, parser)
+ p.add_option("--expect-sync", action="store_true", dest="expect_sync",
+ help = "no copy needed", default=False)
+ p.add_option("--force", action="store_true",
+ help="force", default=False)
+ return p
+
+
+#
+# Provider commands
+#
+
+class ProviderSetup(CommonSetup):
+
+ def admin(self):
+ cmd = self.args[2]
+ if cmd == "tables":
+ self.provider_show_tables()
+ elif cmd == "add":
+ self.provider_add_tables(self.args[3:])
+ elif cmd == "remove":
+ self.provider_remove_tables(self.args[3:])
+ elif cmd == "add-seq":
+ for seq in self.args[3:]:
+ self.provider_add_seq(seq)
+ self.provider_notify_change()
+ elif cmd == "remove-seq":
+ for seq in self.args[3:]:
+ self.provider_remove_seq(seq)
+ self.provider_notify_change()
+ elif cmd == "install":
+ self.provider_install()
+ elif cmd == "seqs":
+ self.provider_list_seqs()
+ else:
+ self.log.error('bad subcommand')
+ sys.exit(1)
+
+ def provider_list_seqs(self):
+ src_db = self.get_database('provider_db')
+ src_curs = src_db.cursor()
+ list = self.get_provider_seqs(src_curs)
+ src_db.commit()
+
+ for seq in list:
+ print seq
+
+ def provider_install(self):
+ src_db = self.get_database('provider_db')
+ src_curs = src_db.cursor()
+ install_provider(src_curs, self.log)
+
+ # create event queue
+ q = "select pgq.create_queue(%s)"
+ self.exec_provider(q, [self.pgq_queue_name])
+
+ def provider_add_tables(self, table_list):
+ self.check_provider_queue()
+
+ cur_list = self.get_provider_table_list()
+ for tbl in table_list:
+ if tbl.find('.') < 0:
+ tbl = "public." + tbl
+ if tbl not in cur_list:
+ self.log.info('Adding %s' % tbl)
+ self.provider_add_table(tbl)
+ else:
+ self.log.info("Table %s already added" % tbl)
+ self.provider_notify_change()
+
+ def provider_remove_tables(self, table_list):
+ self.check_provider_queue()
+
+ cur_list = self.get_provider_table_list()
+ for tbl in table_list:
+ if tbl.find('.') < 0:
+ tbl = "public." + tbl
+ if tbl not in cur_list:
+ self.log.info('%s already removed' % tbl)
+ else:
+ self.log.info("Removing %s" % tbl)
+ self.provider_remove_table(tbl)
+ self.provider_notify_change()
+
+ def provider_add_table(self, tbl):
+ q = "select londiste.provider_add_table(%s, %s)"
+ self.exec_provider(q, [self.pgq_queue_name, tbl])
+
+ def provider_remove_table(self, tbl):
+ q = "select londiste.provider_remove_table(%s, %s)"
+ self.exec_provider(q, [self.pgq_queue_name, tbl])
+
+ def provider_show_tables(self):
+ self.check_provider_queue()
+ list = self.get_provider_table_list()
+ for tbl in list:
+ print tbl
+
+ def provider_notify_change(self):
+ q = "select londiste.provider_notify_change(%s)"
+ self.exec_provider(q, [self.pgq_queue_name])
+
+ def provider_add_seq(self, seq):
+ seq = skytools.fq_name(seq)
+ q = "select londiste.provider_add_seq(%s, %s)"
+ self.exec_provider(q, [self.pgq_queue_name, seq])
+
+ def provider_remove_seq(self, seq):
+ seq = skytools.fq_name(seq)
+ q = "select londiste.provider_remove_seq(%s, %s)"
+ self.exec_provider(q, [self.pgq_queue_name, seq])
+
+ def exec_provider(self, sql, args):
+ src_db = self.get_database('provider_db')
+ src_curs = src_db.cursor()
+
+ src_curs.execute(sql, args)
+
+ if self.fake:
+ src_db.rollback()
+ else:
+ src_db.commit()
+
+#
+# Subscriber commands
+#
+
+class SubscriberSetup(CommonSetup):
+
+ def admin(self):
+ cmd = self.args[2]
+ if cmd == "tables":
+ self.subscriber_show_tables()
+ elif cmd == "missing":
+ self.subscriber_missing_tables()
+ elif cmd == "add":
+ self.subscriber_add_tables(self.args[3:])
+ elif cmd == "remove":
+ self.subscriber_remove_tables(self.args[3:])
+ elif cmd == "resync":
+ self.subscriber_resync_tables(self.args[3:])
+ elif cmd == "register":
+ self.subscriber_register()
+ elif cmd == "unregister":
+ self.subscriber_unregister()
+ elif cmd == "install":
+ self.subscriber_install()
+ elif cmd == "check":
+ self.check_tables(self.get_provider_table_list())
+ elif cmd == "fkeys":
+ self.collect_fkeys(self.get_provider_table_list())
+ elif cmd == "seqs":
+ self.subscriber_list_seqs()
+ elif cmd == "add-seq":
+ self.subscriber_add_seq(self.args[3:])
+ elif cmd == "remove-seq":
+ self.subscriber_remove_seq(self.args[3:])
+ else:
+ self.log.error('bad subcommand: ' + cmd)
+ sys.exit(1)
+
+ def collect_fkeys(self, table_list):
+ dst_db = self.get_database('subscriber_db')
+ dst_curs = dst_db.cursor()
+
+ oid_list = []
+ for tbl in table_list:
+ try:
+ oid = skytools.get_table_oid(dst_curs, tbl)
+ if oid:
+ oid_list.append(str(oid))
+ except:
+ pass
+ if len(oid_list) == 0:
+ print "No tables"
+ return
+ oid_str = ",".join(oid_list)
+
+ q = "SELECT n.nspname || '.' || t.relname as tbl, c.conname as con,"\
+ " pg_get_constraintdef(c.oid) as def"\
+ " FROM pg_constraint c, pg_class t, pg_namespace n"\
+ " WHERE c.contype = 'f' and c.conrelid in (%s)"\
+ " AND t.oid = c.conrelid AND n.oid = t.relnamespace" % oid_str
+ dst_curs.execute(q)
+ res = dst_curs.dictfetchall()
+ dst_db.commit()
+
+ print "-- dropping"
+ for row in res:
+ q = "ALTER TABLE ONLY %(tbl)s DROP CONSTRAINT %(con)s;"
+ print q % row
+ print "-- creating"
+ for row in res:
+ q = "ALTER TABLE ONLY %(tbl)s ADD CONSTRAINT %(con)s %(def)s;"
+ print q % row
+
+ def check_tables(self, table_list):
+ src_db = self.get_database('provider_db')
+ src_curs = src_db.cursor()
+ dst_db = self.get_database('subscriber_db')
+ dst_curs = dst_db.cursor()
+
+ failed = 0
+ for tbl in table_list:
+ self.log.info('Checking %s' % tbl)
+ if not skytools.exists_table(src_curs, tbl):
+ self.log.error('Table %s missing from provider side' % tbl)
+ failed += 1
+ elif not skytools.exists_table(dst_curs, tbl):
+ self.log.error('Table %s missing from subscriber side' % tbl)
+ failed += 1
+ else:
+ failed += self.check_table_columns(src_curs, dst_curs, tbl)
+ failed += self.check_table_triggers(dst_curs, tbl)
+
+ src_db.commit()
+ dst_db.commit()
+
+ return failed
+
+ def check_table_triggers(self, dst_curs, tbl):
+ oid = skytools.get_table_oid(dst_curs, tbl)
+ if not oid:
+ self.log.error('Table %s not found' % tbl)
+ return 1
+ q = "select count(1) from pg_trigger where tgrelid = %s"
+ dst_curs.execute(q, [oid])
+ got = dst_curs.fetchone()[0]
+ if got:
+ self.log.error('found trigger on table %s (%s)' % (tbl, str(oid)))
+ return 1
+ else:
+ return 0
+
+ def check_table_columns(self, src_curs, dst_curs, tbl):
+ src_colrows = find_column_types(src_curs, tbl)
+ dst_colrows = find_column_types(dst_curs, tbl)
+
+ src_cols = make_type_string(src_colrows)
+ dst_cols = make_type_string(dst_colrows)
+ if src_cols.find('k') < 0:
+ self.log.error('provider table %s has no primary key (%s)' % (
+ tbl, src_cols))
+ return 1
+ if dst_cols.find('k') < 0:
+ self.log.error('subscriber table %s has no primary key (%s)' % (
+ tbl, dst_cols))
+ return 1
+
+ if src_cols != dst_cols:
+ self.log.warning('table %s structure is not same (%s/%s)'\
+ ', trying to continue' % (tbl, src_cols, dst_cols))
+
+ err = 0
+ for row in src_colrows:
+ found = 0
+ for row2 in dst_colrows:
+ if row2['name'] == row['name']:
+ found = 1
+ break
+ if not found:
+ err = 1
+ self.log.error('%s: column %s on provider not on subscriber'
+ % (tbl, row['name']))
+ elif row['type'] != row2['type']:
+ err = 1
+ self.log.error('%s: pk different on column %s'
+ % (tbl, row['name']))
+
+ return err
+
+ def subscriber_install(self):
+ dst_db = self.get_database('subscriber_db')
+ dst_curs = dst_db.cursor()
+
+ install_subscriber(dst_curs, self.log)
+
+ if self.fake:
+ self.log.debug('rollback')
+ dst_db.rollback()
+ else:
+ self.log.debug('commit')
+ dst_db.commit()
+
+ def subscriber_register(self):
+ src_db = self.get_database('provider_db')
+ src_curs = src_db.cursor()
+ src_curs.execute("select pgq.register_consumer(%s, %s)",
+ [self.pgq_queue_name, self.consumer_id])
+ src_db.commit()
+
+ def subscriber_unregister(self):
+ q = "select londiste.subscriber_set_table_state(%s, %s, NULL, NULL)"
+
+ dst_db = self.get_database('subscriber_db')
+ dst_curs = dst_db.cursor()
+ tbl_rows = self.fetch_subscriber_tables(dst_curs)
+ for row in tbl_rows:
+ dst_curs.execute(q, [self.pgq_queue_name, row['table_name']])
+ dst_db.commit()
+
+ src_db = self.get_database('provider_db')
+ src_curs = src_db.cursor()
+ src_curs.execute("select pgq.unregister_consumer(%s, %s)",
+ [self.pgq_queue_name, self.consumer_id])
+ src_db.commit()
+
+ def subscriber_show_tables(self):
+ list = self.get_subscriber_table_list()
+ for tbl in list:
+ print tbl
+
+ def subscriber_missing_tables(self):
+ provider_tables = self.get_provider_table_list()
+ subscriber_tables = self.get_subscriber_table_list()
+ for tbl in provider_tables:
+ if tbl not in subscriber_tables:
+ print tbl
+
+ def subscriber_add_tables(self, table_list):
+ provider_tables = self.get_provider_table_list()
+ subscriber_tables = self.get_subscriber_table_list()
+
+ err = 0
+ for tbl in table_list:
+ tbl = skytools.fq_name(tbl)
+ if tbl not in provider_tables:
+ err = 1
+ self.log.error("Table %s not attached to queue" % tbl)
+ if err:
+ if self.options.force:
+ self.log.warning('--force used, ignoring errors')
+ else:
+ sys.exit(1)
+
+ err = self.check_tables(table_list)
+ if err:
+ if self.options.force:
+ self.log.warning('--force used, ignoring errors')
+ else:
+ sys.exit(1)
+
+ for tbl in table_list:
+ tbl = skytools.fq_name(tbl)
+ if tbl in subscriber_tables:
+ self.log.info("Table %s already added" % tbl)
+ else:
+ self.log.info("Adding %s" % tbl)
+ self.subscriber_add_one_table(tbl)
+
+ def subscriber_remove_tables(self, table_list):
+ subscriber_tables = self.get_subscriber_table_list()
+ for tbl in table_list:
+ tbl = skytools.fq_name(tbl)
+ if tbl in subscriber_tables:
+ self.subscriber_remove_one_table(tbl)
+ else:
+ self.log.info("Table %s already removed" % tbl)
+
+ def subscriber_resync_tables(self, table_list):
+ dst_db = self.get_database('subscriber_db')
+ dst_curs = dst_db.cursor()
+ list = self.fetch_subscriber_tables(dst_curs)
+ for tbl in table_list:
+ tbl = skytools.fq_name(tbl)
+ tbl_row = None
+ for row in list:
+ if row['table_name'] == tbl:
+ tbl_row = row
+ break
+ if not tbl_row:
+ self.log.warning("Table %s not found" % tbl)
+ elif tbl_row['merge_state'] != 'ok':
+ self.log.warning("Table %s is not in stable state" % tbl)
+ else:
+ self.log.info("Resyncing %s" % tbl)
+ q = "select londiste.subscriber_set_table_state(%s, %s, NULL, NULL)"
+ dst_curs.execute(q, [self.pgq_queue_name, tbl])
+ dst_db.commit()
+
+ def subscriber_add_one_table(self, tbl):
+ q = "select londiste.subscriber_add_table(%s, %s)"
+
+ dst_db = self.get_database('subscriber_db')
+ dst_curs = dst_db.cursor()
+ dst_curs.execute(q, [self.pgq_queue_name, tbl])
+ if self.options.expect_sync:
+ q = "select londiste.subscriber_set_table_state(%s, %s, null, 'ok')"
+ dst_curs.execute(q, [self.pgq_queue_name, tbl])
+ dst_db.commit()
+
+ def subscriber_remove_one_table(self, tbl):
+ q = "select londiste.subscriber_remove_table(%s, %s)"
+
+ dst_db = self.get_database('subscriber_db')
+ dst_curs = dst_db.cursor()
+ dst_curs.execute(q, [self.pgq_queue_name, tbl])
+ dst_db.commit()
+
+ def get_subscriber_seq_list(self):
+ dst_db = self.get_database('subscriber_db')
+ dst_curs = dst_db.cursor()
+ q = "SELECT * from londiste.subscriber_get_seq_list(%s)"
+ dst_curs.execute(q, [self.pgq_queue_name])
+ list = dst_curs.fetchall()
+ dst_db.commit()
+ res = []
+ for row in list:
+ res.append(row[0])
+ return res
+
+ def subscriber_list_seqs(self):
+ list = self.get_subscriber_seq_list()
+ for seq in list:
+ print seq
+
+ def subscriber_add_seq(self, seq_list):
+ src_db = self.get_database('provider_db')
+ src_curs = src_db.cursor()
+ dst_db = self.get_database('subscriber_db')
+ dst_curs = dst_db.cursor()
+
+ prov_list = self.get_provider_seqs(src_curs)
+ src_db.commit()
+
+ full_list = self.get_all_seqs(dst_curs)
+ cur_list = self.get_subscriber_seq_list()
+
+ for seq in seq_list:
+ seq = skytools.fq_name(seq)
+ if seq not in prov_list:
+ self.log.error('Seq %s does not exist on provider side' % seq)
+ continue
+ if seq not in full_list:
+ self.log.error('Seq %s does not exist on subscriber side' % seq)
+ continue
+ if seq in cur_list:
+ self.log.info('Seq %s already subscribed' % seq)
+ continue
+
+ self.log.info('Adding sequence: %s' % seq)
+ q = "select londiste.subscriber_add_seq(%s, %s)"
+ dst_curs.execute(q, [self.pgq_queue_name, seq])
+
+ dst_db.commit()
+
+ def subscriber_remove_seq(self, seq_list):
+ dst_db = self.get_database('subscriber_db')
+ dst_curs = dst_db.cursor()
+ cur_list = self.get_subscriber_seq_list()
+
+ for seq in seq_list:
+ seq = skytools.fq_name(seq)
+ if seq not in cur_list:
+                self.log.warning('Seq %s not subscribed' % seq)
+ else:
+ self.log.info('Removing sequence: %s' % seq)
+ q = "select londiste.subscriber_remove_seq(%s, %s)"
+ dst_curs.execute(q, [self.pgq_queue_name, seq])
+ dst_db.commit()
+
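check_table_columns() compares the two sides through the short 'k'/'v' type strings built by find_column_types() and make_type_string(): each column becomes 'k' if it belongs to the primary key and 'v' otherwise, so one string comparison catches missing keys or differing column layout. A tiny illustration with made-up column rows (normally they come from the catalog query above):

    from londiste.setup import make_type_string

    cols = [{'name': 'id', 'type': 'k'}, {'name': 'data', 'type': 'v'}]
    print make_type_string(cols)   # -> 'kv'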
diff --git a/python/londiste/syncer.py b/python/londiste/syncer.py
new file mode 100644
index 00000000..eaee3468
--- /dev/null
+++ b/python/londiste/syncer.py
@@ -0,0 +1,177 @@
+
+"""Catch moment when tables are in sync on master and slave.
+"""
+
+import sys, time, skytools
+
+class Syncer(skytools.DBScript):
+ """Walks tables in primary key order and checks if data matches."""
+
+ def __init__(self, args):
+ skytools.DBScript.__init__(self, 'londiste', args)
+ self.set_single_loop(1)
+
+ self.pgq_queue_name = self.cf.get("pgq_queue_name")
+ self.pgq_consumer_id = self.cf.get('pgq_consumer_id', self.job_name)
+
+ if self.pidfile:
+ self.pidfile += ".repair"
+
+ def init_optparse(self, p=None):
+ p = skytools.DBScript.init_optparse(self, p)
+ p.add_option("--force", action="store_true", help="ignore lag")
+ return p
+
+ def check_consumer(self, src_db):
+ src_curs = src_db.cursor()
+
+ # before locking anything check if consumer is working ok
+ q = "select extract(epoch from ticker_lag) from pgq.get_queue_list()"\
+ " where queue_name = %s"
+ src_curs.execute(q, [self.pgq_queue_name])
+ ticker_lag = src_curs.fetchone()[0]
+ q = "select extract(epoch from lag)"\
+ " from pgq.get_consumer_list()"\
+ " where queue_name = %s"\
+ " and consumer_name = %s"
+ src_curs.execute(q, [self.pgq_queue_name, self.pgq_consumer_id])
+ res = src_curs.fetchall()
+ src_db.commit()
+
+ if len(res) == 0:
+ self.log.error('No such consumer')
+ sys.exit(1)
+ consumer_lag = res[0][0]
+
+ if consumer_lag > ticker_lag + 10 and not self.options.force:
+ self.log.error('Consumer lagging too much, cannot proceed')
+ sys.exit(1)
+
+ def get_subscriber_table_state(self):
+ dst_db = self.get_database('subscriber_db')
+ dst_curs = dst_db.cursor()
+ q = "select * from londiste.subscriber_get_table_list(%s)"
+ dst_curs.execute(q, [self.pgq_queue_name])
+ res = dst_curs.dictfetchall()
+ dst_db.commit()
+ return res
+
+ def work(self):
+ src_loc = self.cf.get('provider_db')
+ lock_db = self.get_database('provider_db', cache='lock_db')
+ src_db = self.get_database('provider_db')
+ dst_db = self.get_database('subscriber_db')
+
+ self.check_consumer(src_db)
+
+ state_list = self.get_subscriber_table_state()
+ state_map = {}
+ full_list = []
+ for ts in state_list:
+ name = ts['table_name']
+ full_list.append(name)
+ state_map[name] = ts
+
+ if len(self.args) > 2:
+ tlist = self.args[2:]
+ else:
+ tlist = full_list
+
+ for tbl in tlist:
+ if not tbl in state_map:
+ self.log.warning('Table not subscribed: %s' % tbl)
+ continue
+ st = state_map[tbl]
+ if st['merge_state'] != 'ok':
+ self.log.info('Table %s not synced yet, no point' % tbl)
+ continue
+ self.check_table(tbl, lock_db, src_db, dst_db)
+ lock_db.commit()
+ src_db.commit()
+ dst_db.commit()
+
+ def check_table(self, tbl, lock_db, src_db, dst_db):
+ """Get transaction to same state, then process."""
+
+
+ lock_curs = lock_db.cursor()
+ src_curs = src_db.cursor()
+ dst_curs = dst_db.cursor()
+
+ if not skytools.exists_table(src_curs, tbl):
+ self.log.warning("Table %s does not exist on provider side" % tbl)
+ return
+ if not skytools.exists_table(dst_curs, tbl):
+ self.log.warning("Table %s does not exist on subscriber side" % tbl)
+ return
+
+ # lock table in separate connection
+ self.log.info('Locking %s' % tbl)
+ lock_db.commit()
+ lock_curs.execute("LOCK TABLE %s IN SHARE MODE" % tbl)
+ lock_time = time.time()
+
+        # now wait until the consumer has updated the target table up to the lock moment
+ self.log.info('Syncing %s' % tbl)
+
+        # consumer must get further than this tick
+ src_curs.execute("select pgq.ticker(%s)", [self.pgq_queue_name])
+ tick_id = src_curs.fetchone()[0]
+ src_db.commit()
+ # avoid depending on ticker by inserting second tick also
+ time.sleep(0.1)
+ src_curs.execute("select pgq.ticker(%s)", [self.pgq_queue_name])
+ src_db.commit()
+ src_curs.execute("select to_char(now(), 'YYYY-MM-DD HH24:MI:SS.MS')")
+ tpos = src_curs.fetchone()[0]
+ src_db.commit()
+ # now wait
+ while 1:
+ time.sleep(0.2)
+
+ q = """select now() - lag > %s, now(), lag
+ from pgq.get_consumer_list()
+ where consumer_name = %s
+ and queue_name = %s"""
+ src_curs.execute(q, [tpos, self.pgq_consumer_id, self.pgq_queue_name])
+ res = src_curs.fetchall()
+ src_db.commit()
+
+ if len(res) == 0:
+ raise Exception('No such consumer')
+
+ row = res[0]
+ self.log.debug("tpos=%s now=%s lag=%s ok=%s" % (tpos, row[1], row[2], row[0]))
+ if row[0]:
+ break
+
+ # loop max 10 secs
+ if time.time() > lock_time + 10 and not self.options.force:
+ self.log.error('Consumer lagging too much, exiting')
+ lock_db.rollback()
+ sys.exit(1)
+
+ # take snapshot on provider side
+ src_curs.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE")
+ src_curs.execute("SELECT 1")
+
+ # take snapshot on subscriber side
+ dst_db.commit()
+ dst_curs.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE")
+ dst_curs.execute("SELECT 1")
+
+ # release lock
+ lock_db.commit()
+
+ # do work
+ self.process_sync(tbl, src_db, dst_db)
+
+ # done
+ src_db.commit()
+ dst_db.commit()
+
+ def process_sync(self, tbl, src_db, dst_db):
+        """Gets two connections in a state where tbl should be identical on both.
+        """
+ raise Exception('process_sync not implemented')
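+
+# Concrete scripts subclass Syncer and implement process_sync() to do the
+# actual row-level comparison between the two prepared connections.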
+
diff --git a/python/londiste/table_copy.py b/python/londiste/table_copy.py
new file mode 100644
index 00000000..1754baaf
--- /dev/null
+++ b/python/londiste/table_copy.py
@@ -0,0 +1,107 @@
+#! /usr/bin/env python
+
+"""Do a full table copy.
+
+For internal usage.
+"""
+
+import sys, os, skytools
+
+from skytools.dbstruct import *
+from playback import *
+
+__all__ = ['CopyTable']
+
+class CopyTable(Replicator):
+ def __init__(self, args, copy_thread = 1):
+ Replicator.__init__(self, args)
+
+ if copy_thread:
+ self.pidfile += ".copy"
+ self.consumer_id += "_copy"
+ self.copy_thread = 1
+
+ def init_optparse(self, parser=None):
+ p = Replicator.init_optparse(self, parser)
+ p.add_option("--skip-truncate", action="store_true", dest="skip_truncate",
+ help = "avoid truncate", default=False)
+ return p
+
+ def do_copy(self, tbl_stat):
+ src_db = self.get_database('provider_db')
+ dst_db = self.get_database('subscriber_db')
+
+ # it should not matter to pgq
+ src_db.commit()
+ dst_db.commit()
+
+ # change to SERIALIZABLE isolation level
+ src_db.set_isolation_level(2)
+ src_db.commit()
+
+ # initial sync copy
+ src_curs = src_db.cursor()
+ dst_curs = dst_db.cursor()
+
+ self.log.info("Starting full copy of %s" % tbl_stat.name)
+
+ # find dst struct
+ src_struct = TableStruct(src_curs, tbl_stat.name)
+ dst_struct = TableStruct(dst_curs, tbl_stat.name)
+
+ # check if columns match
+ dlist = dst_struct.get_column_list()
+ for c in src_struct.get_column_list():
+ if c not in dlist:
+ raise Exception('Column %s does not exist on dest side' % c)
+
+ # drop unnecessary stuff
+ objs = T_CONSTRAINT | T_INDEX | T_TRIGGER | T_RULE
+ dst_struct.drop(dst_curs, objs, log = self.log)
+
+ # do truncate & copy
+ self.real_copy(src_curs, dst_curs, tbl_stat.name)
+
+ # get snapshot
+ src_curs.execute("select get_current_snapshot()")
+ snapshot = src_curs.fetchone()[0]
+ src_db.commit()
+
+ # restore READ COMMITTED behaviour
+ src_db.set_isolation_level(1)
+ src_db.commit()
+
+ # create previously dropped objects
+ dst_struct.create(dst_curs, objs, log = self.log)
+
+ # set state
+ tbl_stat.change_snapshot(snapshot)
+ if self.copy_thread:
+ tbl_stat.change_state(TABLE_CATCHING_UP)
+ else:
+ tbl_stat.change_state(TABLE_OK)
+ self.save_table_state(dst_curs)
+ dst_db.commit()
+
+ def real_copy(self, srccurs, dstcurs, tablename):
+ "Main copy logic."
+
+ # drop data
+ if self.options.skip_truncate:
+ self.log.info("%s: skipping truncate" % tablename)
+ else:
+ self.log.info("%s: truncating" % tablename)
+ dstcurs.execute("truncate " + tablename)
+
+ # do copy
+ self.log.info("%s: start copy" % tablename)
+ col_list = skytools.get_table_columns(srccurs, tablename)
+ stats = skytools.full_copy(tablename, srccurs, dstcurs, col_list)
+ if stats:
+ self.log.info("%s: copy finished: %d bytes, %d rows" % (
+ tablename, stats[0], stats[1]))
+
+if __name__ == '__main__':
+ script = CopyTable(sys.argv[1:])
+ script.start()
+
diff --git a/python/pgq/__init__.py b/python/pgq/__init__.py
new file mode 100644
index 00000000..f0e9c1a6
--- /dev/null
+++ b/python/pgq/__init__.py
@@ -0,0 +1,6 @@
+"""PgQ framework for Python."""
+
+from pgq.event import *
+from pgq.consumer import *
+from pgq.producer import *
+
diff --git a/python/pgq/consumer.py b/python/pgq/consumer.py
new file mode 100644
index 00000000..bd49dccf
--- /dev/null
+++ b/python/pgq/consumer.py
@@ -0,0 +1,410 @@
+
+"""PgQ consumer framework for Python.
+
+API problems(?):
+ - process_event() and process_batch() should have db as argument.
+ - should ev.tag*() update db immediately?
+
+"""
+
+import sys, time, skytools
+
+from pgq.event import *
+
+__all__ = ['Consumer', 'RemoteConsumer', 'SerialConsumer']
+
+class Consumer(skytools.DBScript):
+ """Consumer base class.
+ """
+
+ def __init__(self, service_name, db_name, args):
+ """Initialize new consumer.
+
+ @param service_name: service_name for DBScript
+ @param db_name: name of database for get_database()
+ @param args: cmdline args for DBScript
+ """
+
+ skytools.DBScript.__init__(self, service_name, args)
+
+ self.db_name = db_name
+ self.reg_list = []
+ self.consumer_id = self.cf.get("pgq_consumer_id", self.job_name)
+ self.pgq_queue_name = self.cf.get("pgq_queue_name")
+
+ def attach(self):
+ """Attach consumer to interesting queues."""
+ res = self.register_consumer(self.pgq_queue_name)
+ return res
+
+ def detach(self):
+ """Detach consumer from all queues."""
+ tmp = self.reg_list[:]
+ for q in tmp:
+ self.unregister_consumer(q)
+
+ def process_event(self, db, event):
+ """Process one event.
+
+        Should be overridden by user code.
+
+        Event should be tagged as done, retry or failed.
+        If not, it will be tagged for retry.
+ """
+ raise Exception("needs to be implemented")
+
+ def process_batch(self, db, batch_id, event_list):
+ """Process all events in batch.
+
+ By default calls process_event for each.
+        Can be overridden by user code.
+
+        Events should be tagged as done, retry or failed.
+        If not, they will be tagged for retry.
+ """
+ for ev in event_list:
+ self.process_event(db, ev)
+
+ def work(self):
+ """Do the work loop, once (internal)."""
+
+ if len(self.reg_list) == 0:
+ self.log.debug("Attaching")
+ self.attach()
+
+ db = self.get_database(self.db_name)
+ curs = db.cursor()
+
+ data_avail = 0
+ for queue in self.reg_list:
+ self.stat_start()
+
+ # acquire batch
+ batch_id = self._load_next_batch(curs, queue)
+ db.commit()
+ if batch_id == None:
+ continue
+ data_avail = 1
+
+ # load events
+ list = self._load_batch_events(curs, batch_id, queue)
+ db.commit()
+
+ # process events
+ self._launch_process_batch(db, batch_id, list)
+
+ # done
+ self._finish_batch(curs, batch_id, list)
+ db.commit()
+ self.stat_end(len(list))
+
+ # if false, script sleeps
+ return data_avail
+
+ def register_consumer(self, queue_name):
+ db = self.get_database(self.db_name)
+ cx = db.cursor()
+ cx.execute("select pgq.register_consumer(%s, %s)",
+ [queue_name, self.consumer_id])
+ res = cx.fetchone()[0]
+ db.commit()
+
+ self.reg_list.append(queue_name)
+
+ return res
+
+ def unregister_consumer(self, queue_name):
+ db = self.get_database(self.db_name)
+ cx = db.cursor()
+ cx.execute("select pgq.unregister_consumer(%s, %s)",
+ [queue_name, self.consumer_id])
+ db.commit()
+
+ self.reg_list.remove(queue_name)
+
+ def _launch_process_batch(self, db, batch_id, list):
+ self.process_batch(db, batch_id, list)
+
+ def _load_batch_events(self, curs, batch_id, queue_name):
+ """Fetch all events for this batch."""
+
+ # load events
+ sql = "select * from pgq.get_batch_events(%d)" % batch_id
+ curs.execute(sql)
+ rows = curs.dictfetchall()
+
+ # map them to python objects
+ list = []
+ for r in rows:
+ ev = Event(queue_name, r)
+ list.append(ev)
+
+ return list
+
+ def _load_next_batch(self, curs, queue_name):
+ """Allocate next batch. (internal)"""
+
+ q = "select pgq.next_batch(%s, %s)"
+ curs.execute(q, [queue_name, self.consumer_id])
+ return curs.fetchone()[0]
+
+ def _finish_batch(self, curs, batch_id, list):
+ """Tag events and notify that the batch is done."""
+
+ retry = failed = 0
+ for ev in list:
+ if ev.status == EV_FAILED:
+ self._tag_failed(curs, batch_id, ev)
+ failed += 1
+ elif ev.status == EV_RETRY:
+ self._tag_retry(curs, batch_id, ev)
+ retry += 1
+ curs.execute("select pgq.finish_batch(%s)", [batch_id])
+
+ def _tag_failed(self, curs, batch_id, ev):
+ """Tag event as failed. (internal)"""
+ curs.execute("select pgq.event_failed(%s, %s, %s)",
+ [batch_id, ev.id, ev.fail_reason])
+
+ def _tag_retry(self, cx, batch_id, ev):
+ """Tag event for retry. (internal)"""
+ cx.execute("select pgq.event_retry(%s, %s, %s)",
+ [batch_id, ev.id, ev.retry_time])
+
+ def get_batch_info(self, batch_id):
+ """Get info about batch.
+
+ @return: Return value is a dict of:
+
+ - queue_name: queue name
+        - consumer_name: consumer's name
+ - batch_start: batch start time
+ - batch_end: batch end time
+ - tick_id: end tick id
+ - prev_tick_id: start tick id
+ - lag: how far is batch_end from current moment.
+ """
+ db = self.get_database(self.db_name)
+ cx = db.cursor()
+ q = "select queue_name, consumer_name, batch_start, batch_end,"\
+ " prev_tick_id, tick_id, lag"\
+ " from pgq.get_batch_info(%s)"
+ cx.execute(q, [batch_id])
+ row = cx.dictfetchone()
+ db.commit()
+ return row
+
+ def stat_start(self):
+ self.stat_batch_start = time.time()
+
+ def stat_end(self, count):
+ t = time.time()
+ self.stat_add('count', count)
+ self.stat_add('duration', t - self.stat_batch_start)
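+
+# Minimal usage sketch (names are illustrative; the config must provide
+# pgq_queue_name and the 'src_db' connect string used below):
+#
+#   class LogConsumer(Consumer):
+#       def process_event(self, db, ev):
+#           self.log.info('%s: %s' % (ev.type, ev.data))
+#           ev.tag_done()
+#
+#   if __name__ == '__main__':
+#       LogConsumer('log_consumer', 'src_db', sys.argv[1:]).start()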
+
+
+class RemoteConsumer(Consumer):
+ """Helper for doing event processing in another database.
+
+ Requires that whole batch is processed in one TX.
+ """
+
+ def __init__(self, service_name, db_name, remote_db, args):
+ Consumer.__init__(self, service_name, db_name, args)
+ self.remote_db = remote_db
+
+ def process_batch(self, db, batch_id, event_list):
+ """Process all events in batch.
+
+ By default calls process_event for each.
+ """
+ dst_db = self.get_database(self.remote_db)
+ curs = dst_db.cursor()
+
+ if self.is_last_batch(curs, batch_id):
+ for ev in event_list:
+ ev.tag_done()
+ return
+
+ self.process_remote_batch(db, batch_id, event_list, dst_db)
+
+ self.set_last_batch(curs, batch_id)
+ dst_db.commit()
+
+ def is_last_batch(self, dst_curs, batch_id):
+ """Helper function to keep track of last successful batch
+ in external database.
+ """
+ q = "select pgq_ext.is_batch_done(%s, %s)"
+ dst_curs.execute(q, [ self.consumer_id, batch_id ])
+ return dst_curs.fetchone()[0]
+
+ def set_last_batch(self, dst_curs, batch_id):
+ """Helper function to set last successful batch
+ in external database.
+ """
+ q = "select pgq_ext.set_batch_done(%s, %s)"
+ dst_curs.execute(q, [ self.consumer_id, batch_id ])
+
+ def process_remote_batch(self, db, batch_id, event_list, dst_db):
+ raise Exception('process_remote_batch not implemented')
+
+class SerialConsumer(Consumer):
+ """Consumer that applies batches sequentially in second database.
+
+ Requirements:
+ - Whole batch in one TX.
+ - Must not use retry queue.
+
+ Features:
+ - Can detect if several batches are already applied to dest db.
+    - If some ticks are lost, allows seeking back on the queue.
+      Whether that succeeds depends on pgq configuration.
+ """
+
+ def __init__(self, service_name, db_name, remote_db, args):
+ Consumer.__init__(self, service_name, db_name, args)
+ self.remote_db = remote_db
+ self.dst_completed_table = "pgq_ext.completed_tick"
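+        # assumed layout: pgq_ext.completed_tick(consumer_id, last_tick_id)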
+ self.cur_batch_info = None
+
+ def startup(self):
+ if self.options.rewind:
+ self.rewind()
+ sys.exit(0)
+ if self.options.reset:
+ self.dst_reset()
+ sys.exit(0)
+ return Consumer.startup(self)
+
+ def init_optparse(self, parser = None):
+ p = Consumer.init_optparse(self, parser)
+ p.add_option("--rewind", action = "store_true",
+ help = "change queue position according to destination")
+ p.add_option("--reset", action = "store_true",
+ help = "reset queue pos on destination side")
+ return p
+
+ def process_batch(self, db, batch_id, event_list):
+ """Process all events in batch.
+ """
+
+ dst_db = self.get_database(self.remote_db)
+ curs = dst_db.cursor()
+
+ self.cur_batch_info = self.get_batch_info(batch_id)
+
+ # check if done
+ if self.is_batch_done(curs):
+ for ev in event_list:
+ ev.tag_done()
+ return
+
+ # actual work
+ self.process_remote_batch(db, batch_id, event_list, dst_db)
+
+ # make sure no retry events
+ for ev in event_list:
+ if ev.status == EV_RETRY:
+ raise Exception("SerialConsumer must not use retry queue")
+
+ # finish work
+ self.set_batch_done(curs)
+ dst_db.commit()
+
+ def is_batch_done(self, dst_curs):
+ """Helper function to keep track of last successful batch
+ in external database.
+ """
+
+ prev_tick = self.cur_batch_info['prev_tick_id']
+
+ q = "select last_tick_id from %s where consumer_id = %%s" % (
+ self.dst_completed_table ,)
+ dst_curs.execute(q, [self.consumer_id])
+ res = dst_curs.fetchone()
+
+ if not res or not res[0]:
+ # seems this consumer has not run yet against dst_db
+ return False
+ dst_tick = res[0]
+
+ if prev_tick == dst_tick:
+ # on track
+ return False
+
+ if prev_tick < dst_tick:
+ self.log.warning('Got tick %d, dst has %d - skipping' % (prev_tick, dst_tick))
+ return True
+ else:
+ self.log.error('Got tick %d, dst has %d - ticks lost' % (prev_tick, dst_tick))
+ raise Exception('Lost ticks')
+
+ def set_batch_done(self, dst_curs):
+ """Helper function to set last successful batch
+ in external database.
+ """
+ tick_id = self.cur_batch_info['tick_id']
+ q = "delete from %s where consumer_id = %%s; "\
+ "insert into %s (consumer_id, last_tick_id) values (%%s, %%s)" % (
+ self.dst_completed_table,
+ self.dst_completed_table)
+ dst_curs.execute(q, [ self.consumer_id,
+ self.consumer_id, tick_id ])
+
+ def attach(self):
+ new = Consumer.attach(self)
+ if new:
+ self.clean_completed_tick()
+
+ def detach(self):
+ """If detaching, also clean completed tick table on dest."""
+
+ Consumer.detach(self)
+ self.clean_completed_tick()
+
+ def clean_completed_tick(self):
+ self.log.info("removing completed tick from dst")
+ dst_db = self.get_database(self.remote_db)
+ dst_curs = dst_db.cursor()
+
+ q = "delete from %s where consumer_id = %%s" % (
+ self.dst_completed_table,)
+ dst_curs.execute(q, [self.consumer_id])
+ dst_db.commit()
+
+ def process_remote_batch(self, db, batch_id, event_list, dst_db):
+ raise Exception('process_remote_batch not implemented')
+
+ def rewind(self):
+ self.log.info("Rewinding queue")
+ src_db = self.get_database(self.db_name)
+ dst_db = self.get_database(self.remote_db)
+ src_curs = src_db.cursor()
+ dst_curs = dst_db.cursor()
+
+ q = "select last_tick_id from %s where consumer_id = %%s" % (
+ self.dst_completed_table,)
+ dst_curs.execute(q, [self.consumer_id])
+ row = dst_curs.fetchone()
+ if row:
+ dst_tick = row[0]
+ q = "select pgq.register_consumer(%s, %s, %s)"
+ src_curs.execute(q, [self.pgq_queue_name, self.consumer_id, dst_tick])
+ else:
+ self.log.warning('No tick found on dst side')
+
+ dst_db.commit()
+ src_db.commit()
+
+ def dst_reset(self):
+ self.log.info("Resetting queue tracking on dst side")
+ dst_db = self.get_database(self.remote_db)
+ dst_curs = dst_db.cursor()
+
+ q = "delete from %s where consumer_id = %%s" % (
+ self.dst_completed_table,)
+ dst_curs.execute(q, [self.consumer_id])
+ dst_db.commit()
+
+
diff --git a/python/pgq/event.py b/python/pgq/event.py
new file mode 100644
index 00000000..d7b2d7ee
--- /dev/null
+++ b/python/pgq/event.py
@@ -0,0 +1,60 @@
+
+"""PgQ event container.
+"""
+
+__all__ = ('EV_RETRY', 'EV_DONE', 'EV_FAILED', 'Event')
+
+# Event status codes
+EV_RETRY = 0
+EV_DONE = 1
+EV_FAILED = 2
+
+_fldmap = {
+ 'ev_id': 'ev_id',
+ 'ev_txid': 'ev_txid',
+ 'ev_time': 'ev_time',
+ 'ev_type': 'ev_type',
+ 'ev_data': 'ev_data',
+ 'ev_extra1': 'ev_extra1',
+ 'ev_extra2': 'ev_extra2',
+ 'ev_extra3': 'ev_extra3',
+ 'ev_extra4': 'ev_extra4',
+
+ 'id': 'ev_id',
+ 'txid': 'ev_txid',
+ 'time': 'ev_time',
+ 'type': 'ev_type',
+ 'data': 'ev_data',
+ 'extra1': 'ev_extra1',
+ 'extra2': 'ev_extra2',
+ 'extra3': 'ev_extra3',
+ 'extra4': 'ev_extra4',
+}
+
+class Event(object):
+ """Event data for consumers.
+
+ Consumer is supposed to tag them after processing.
+ If not, events will stay in retry queue.
+ """
+ def __init__(self, queue_name, row):
+ self._event_row = row
+ self.status = EV_RETRY
+ self.retry_time = 60
+ self.fail_reason = "Buggy consumer"
+ self.queue_name = queue_name
+
+ def __getattr__(self, key):
+ return self._event_row[_fldmap[key]]
+
+ def tag_done(self):
+ self.status = EV_DONE
+
+ def tag_retry(self, retry_time = 60):
+ self.status = EV_RETRY
+ self.retry_time = retry_time
+
+ def tag_failed(self, reason):
+ self.status = EV_FAILED
+ self.fail_reason = reason
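+
+# Rough usage sketch - rows come from pgq.get_batch_events(), the handler
+# name below is made up for illustration:
+#
+#   ev = Event('myqueue', row)
+#   try:
+#       handle_payload(ev.type, ev.data)
+#       ev.tag_done()
+#   except Exception:
+#       ev.tag_retry(120)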
+
diff --git a/python/pgq/maint.py b/python/pgq/maint.py
new file mode 100644
index 00000000..4636f74f
--- /dev/null
+++ b/python/pgq/maint.py
@@ -0,0 +1,99 @@
+"""PgQ maintenance functions."""
+
+import skytools, time
+
+def get_pgq_api_version(curs):
+ q = "select count(1) from pg_proc p, pg_namespace n"\
+ " where n.oid = p.pronamespace and n.nspname='pgq'"\
+ " and p.proname='version';"
+ curs.execute(q)
+ if not curs.fetchone()[0]:
+ return '1.0.0'
+
+ curs.execute("select pgq.version()")
+ return curs.fetchone()[0]
+
+def version_ge(curs, want_ver):
+    """Check if installed pgq version is at least want_ver (same major required)."""
+ db_ver = get_pgq_api_version(curs)
+ want_tuple = map(int, want_ver.split('.'))
+ db_tuple = map(int, db_ver.split('.'))
+ if db_tuple[0] != want_tuple[0]:
+ raise Exception('Wrong major version')
+ if db_tuple[1] >= want_tuple[1]:
+ return 1
+ return 0
+
+class MaintenanceJob(skytools.DBScript):
+ """Periodic maintenance."""
+ def __init__(self, ticker, args):
+ skytools.DBScript.__init__(self, 'pgqadm', args)
+ self.ticker = ticker
+        self.last_time = 0 # start immediately
+ self.last_ticks = 0
+ self.clean_ticks = 1
+ self.maint_delay = 5*60
+
+ def startup(self):
+ # disable regular DBScript startup()
+ pass
+
+ def reload(self):
+ skytools.DBScript.reload(self)
+
+ # force loop_delay
+ self.loop_delay = 5
+
+ self.maint_delay = 60 * self.cf.getfloat('maint_delay_min', 5)
+ self.maint_delay = self.cf.getfloat('maint_delay', self.maint_delay)
+
+ def work(self):
+ t = time.time()
+ if self.last_time + self.maint_delay > t:
+ return
+
+ self.do_maintenance()
+
+ self.last_time = t
+ duration = time.time() - t
+ self.stat_add('maint_duration', duration)
+
+ def do_maintenance(self):
+ """Helper function for running maintenance."""
+
+ db = self.get_database('db', autocommit=1)
+ cx = db.cursor()
+
+ if skytools.exists_function(cx, "pgq.maint_rotate_tables_step1", 1):
+ # rotate each queue in own TX
+ q = "select queue_name from pgq.get_queue_info()"
+ cx.execute(q)
+ for row in cx.fetchall():
+ cx.execute("select pgq.maint_rotate_tables_step1(%s)", [row[0]])
+ res = cx.fetchone()[0]
+ if res:
+ self.log.info('Rotating %s' % row[0])
+ else:
+ cx.execute("select pgq.maint_rotate_tables_step1();")
+
+ # finish rotation
+ cx.execute("select pgq.maint_rotate_tables_step2();")
+
+ # move retry events to main queue in small blocks
+ rcount = 0
+ while 1:
+ cx.execute('select pgq.maint_retry_events();')
+ res = cx.fetchone()[0]
+ rcount += res
+ if res == 0:
+ break
+ if rcount:
+ self.log.info('Got %d events for retry' % rcount)
+
+ # vacuum tables that are needed
+ cx.execute('set maintenance_work_mem = 32768')
+ cx.execute('select * from pgq.maint_tables_to_vacuum()')
+ for row in cx.fetchall():
+ cx.execute('vacuum %s;' % row[0])
+
+
diff --git a/python/pgq/producer.py b/python/pgq/producer.py
new file mode 100644
index 00000000..81e1ca4f
--- /dev/null
+++ b/python/pgq/producer.py
@@ -0,0 +1,41 @@
+
+"""PgQ producer helpers for Python.
+"""
+
+import skytools
+
+_fldmap = {
+ 'id': 'ev_id',
+ 'time': 'ev_time',
+ 'type': 'ev_type',
+ 'data': 'ev_data',
+ 'extra1': 'ev_extra1',
+ 'extra2': 'ev_extra2',
+ 'extra3': 'ev_extra3',
+ 'extra4': 'ev_extra4',
+
+ 'ev_id': 'ev_id',
+ 'ev_time': 'ev_time',
+ 'ev_type': 'ev_type',
+ 'ev_data': 'ev_data',
+ 'ev_extra1': 'ev_extra1',
+ 'ev_extra2': 'ev_extra2',
+ 'ev_extra3': 'ev_extra3',
+ 'ev_extra4': 'ev_extra4',
+}
+
+def bulk_insert_events(curs, rows, fields, queue_name):
+ q = "select pgq.current_event_table(%s)"
+ curs.execute(q, [queue_name])
+ tbl = curs.fetchone()[0]
+ db_fields = map(_fldmap.get, fields)
+ skytools.magic_insert(curs, tbl, rows, db_fields)
+
+def insert_event(curs, queue, ev_type, ev_data,
+ extra1=None, extra2=None,
+ extra3=None, extra4=None):
+ q = "select pgq.insert_event(%s, %s, %s, %s, %s, %s, %s)"
+ curs.execute(q, [queue, ev_type, ev_data,
+ extra1, extra2, extra3, extra4])
+ return curs.fetchone()[0]
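+
+# Example (sketch): insert_event(curs, 'myqueue', 'note', 'hello') queues one
+# event and returns its ev_id; extra1..extra4 stay NULL unless given.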
+
diff --git a/python/pgq/status.py b/python/pgq/status.py
new file mode 100644
index 00000000..2214045f
--- /dev/null
+++ b/python/pgq/status.py
@@ -0,0 +1,93 @@
+
+"""Status display.
+"""
+
+import sys, os, skytools
+
+def ival(data, as_name = None):
+    "Format interval for output"
+    if not as_name:
+        as_name = data.split('.')[-1]
+    numfmt = 'FM9999999'
+    expr = "coalesce(to_char(extract(epoch from %s), '%s') || 's', 'NULL') as %s"
+    return expr % (data, numfmt, as_name)
+
+class PGQStatus(skytools.DBScript):
+ def __init__(self, args, check = 0):
+ skytools.DBScript.__init__(self, 'pgqadm', args)
+
+ self.show_status()
+
+ sys.exit(0)
+
+ def show_status(self):
+ db = self.get_database("db", autocommit=1)
+ cx = db.cursor()
+
+ cx.execute("show server_version")
+ pgver = cx.fetchone()[0]
+ cx.execute("select pgq.version()")
+ qver = cx.fetchone()[0]
+ print "Postgres version: %s PgQ version: %s" % (pgver, qver)
+
+        q = """select f.queue_name, f.num_tables, %s, %s,
+                q.queue_ticker_max_lag, q.queue_ticker_max_amount,
+                q.queue_ticker_idle_interval
+                from pgq.get_queue_info() f, pgq.queue q
+                where q.queue_name = f.queue_name""" % (
+                    ival('f.rotation_delay'),
+                    ival('f.ticker_lag'),
+                )
+ cx.execute(q)
+ event_rows = cx.dictfetchall()
+
+        q = """select queue_name, consumer_name, %s, %s
+                from pgq.get_consumer_info()""" % (
+                    ival('lag'),
+                    ival('last_seen'),
+                )
+ cx.execute(q)
+ consumer_rows = cx.dictfetchall()
+
+        print "\n%-32s %9s %13s %6s" % ('Event queue',
+                            'Rotation', 'Ticker', 'TLag')
+ print '-' * 78
+ for ev_row in event_rows:
+ tck = "%s/%ss/%ss" % (ev_row['queue_ticker_max_amount'],
+ ev_row['queue_ticker_max_lag'],
+ ev_row['queue_ticker_idle_interval'])
+ rot = "%s/%s" % (ev_row['queue_ntables'], ev_row['queue_rotation_period'])
+            print "%-32s %9s %13s %6s" % (
+                ev_row['queue_name'],
+                rot,
+                tck,
+                ev_row['ticker_lag'],
+            )
+ print '-' * 78
+ print "\n%-42s %9s %9s" % (
+ 'Consumer', 'Lag', 'LastSeen')
+ print '-' * 78
+ for ev_row in event_rows:
+ cons = self.pick_consumers(ev_row, consumer_rows)
+ self.show_queue(ev_row, cons)
+ print '-' * 78
+ db.commit()
+
+ def show_consumer(self, cons):
+ print " %-48s %9s %9s" % (
+ cons['consumer_name'],
+            cons['lag'], cons['last_seen'])
+
+    def show_queue(self, ev_row, consumer_rows):
+ print "%(queue_name)s:" % ev_row
+ for cons in consumer_rows:
+ self.show_consumer(cons)
+
+
+ def pick_consumers(self, ev_row, consumer_rows):
+ res = []
+ for con in consumer_rows:
+ if con['queue_name'] != ev_row['queue_name']:
+ continue
+ res.append(con)
+ return res
+
diff --git a/python/pgq/ticker.py b/python/pgq/ticker.py
new file mode 100644
index 00000000..c218eaf1
--- /dev/null
+++ b/python/pgq/ticker.py
@@ -0,0 +1,172 @@
+"""PgQ ticker.
+
+It will also launch maintenance job.
+"""
+
+import sys, os, time, threading
+import skytools
+
+from maint import MaintenanceJob
+
+__all__ = ['SmartTicker']
+
+def is_txid_sane(curs):
+ curs.execute("select get_current_txid()")
+ txid = curs.fetchone()[0]
+
+    # on 8.2 there's no such table
+ if not skytools.exists_table(curs, 'txid.epoch'):
+ return 1
+
+ curs.execute("select epoch, last_value from txid.epoch")
+ epoch, last_val = curs.fetchone()
+ stored_val = (epoch << 32) | last_val
+
+ if stored_val <= txid:
+ return 1
+ else:
+ return 0
+
+class QueueStatus(object):
+ def __init__(self, name):
+ self.queue_name = name
+ self.seq_name = None
+ self.idle_period = 60
+ self.max_lag = 3
+ self.max_count = 200
+ self.last_tick_time = 0
+ self.last_count = 0
+ self.quiet_count = 0
+
+ def set_data(self, row):
+ self.seq_name = row['queue_event_seq']
+ self.idle_period = row['queue_ticker_idle_period']
+ self.max_lag = row['queue_ticker_max_lag']
+ self.max_count = row['queue_ticker_max_count']
+
+ def need_tick(self, cur_count, cur_time):
+ # check if tick is needed
+ need_tick = 0
+ lag = cur_time - self.last_tick_time
+
+ if cur_count == self.last_count:
+ # totally idle database
+
+            # don't go immediately to big delays, as seq grows before commit
+ if self.quiet_count < 5:
+ if lag >= self.max_lag:
+ need_tick = 1
+ self.quiet_count += 1
+ else:
+ if lag >= self.idle_period:
+ need_tick = 1
+ else:
+ self.quiet_count = 0
+ # somewhat loaded machine
+ if cur_count - self.last_count >= self.max_count:
+ need_tick = 1
+ elif lag >= self.max_lag:
+ need_tick = 1
+ if need_tick:
+ self.last_tick_time = cur_time
+ self.last_count = cur_count
+ return need_tick
+
+class SmartTicker(skytools.DBScript):
+ last_tick_event = 0
+ last_tick_time = 0
+ quiet_count = 0
+ tick_count = 0
+ maint_thread = None
+
+ def __init__(self, args):
+ skytools.DBScript.__init__(self, 'pgqadm', args)
+
+ self.ticker_log_time = 0
+ self.ticker_log_delay = 5*60
+ self.queue_map = {}
+ self.refresh_time = 0
+
+ def reload(self):
+ skytools.DBScript.reload(self)
+ self.ticker_log_delay = self.cf.getfloat("ticker_log_delay", 5*60)
+
+ def startup(self):
+ if self.maint_thread:
+ return
+
+ db = self.get_database("db", autocommit = 1)
+ cx = db.cursor()
+ ok = is_txid_sane(cx)
+ if not ok:
+ self.log.error('txid in bad state')
+ sys.exit(1)
+
+ self.maint_thread = MaintenanceJob(self, [self.cf.filename])
+ t = threading.Thread(name = 'maint_thread',
+ target = self.maint_thread.run)
+ t.setDaemon(1)
+ t.start()
+
+ def refresh_queues(self, cx):
+ q = "select queue_name, queue_event_seq, queue_ticker_idle_period,"\
+ " queue_ticker_max_lag, queue_ticker_max_count"\
+ " from pgq.queue"\
+ " where not queue_external_ticker"
+ cx.execute(q)
+ new_map = {}
+ data_list = []
+ from_list = []
+ for row in cx.dictfetchall():
+ queue_name = row['queue_name']
+ try:
+ que = self.queue_map[queue_name]
+ except KeyError, x:
+ que = QueueStatus(queue_name)
+ que.set_data(row)
+ new_map[queue_name] = que
+
+ p1 = "'%s', %s.last_value" % (queue_name, que.seq_name)
+ data_list.append(p1)
+ from_list.append(que.seq_name)
+
+ self.queue_map = new_map
+ self.seq_query = "select %s from %s" % (
+ ", ".join(data_list),
+ ", ".join(from_list))
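+        # e.g. with queues q1, q2 this builds roughly:
+        #   select 'q1', q1_ev_seq.last_value, 'q2', q2_ev_seq.last_value
+        #   from q1_ev_seq, q2_ev_seq
+        # (sequence names here are just examples)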
+
+ if len(from_list) == 0:
+ self.seq_query = None
+
+ self.refresh_time = time.time()
+
+ def work(self):
+ db = self.get_database("db", autocommit = 1)
+ cx = db.cursor()
+
+ cur_time = time.time()
+
+ if cur_time >= self.refresh_time + 30:
+ self.refresh_queues(cx)
+
+ if not self.seq_query:
+ return
+
+ # now check seqs
+ cx.execute(self.seq_query)
+ res = cx.fetchone()
+ pos = 0
+ while pos < len(res):
+ id = res[pos]
+ val = res[pos + 1]
+ pos += 2
+ que = self.queue_map[id]
+ if que.need_tick(val, cur_time):
+ cx.execute("select pgq.ticker(%s)", [que.queue_name])
+ self.tick_count += 1
+
+ if cur_time > self.ticker_log_time + self.ticker_log_delay:
+ self.ticker_log_time = cur_time
+ self.stat_add('ticks', self.tick_count)
+ self.tick_count = 0
+
diff --git a/python/pgqadm.py b/python/pgqadm.py
new file mode 100755
index 00000000..78f513dc
--- /dev/null
+++ b/python/pgqadm.py
@@ -0,0 +1,162 @@
+#! /usr/bin/env python
+
+"""PgQ ticker and maintenance.
+"""
+
+import sys
+import skytools
+
+from pgq.ticker import SmartTicker
+from pgq.status import PGQStatus
+#from pgq.admin import PGQAdmin
+
+"""TODO:
+pgqadm ini check
+"""
+
+command_usage = """
+%prog [options] INI CMD [subcmd args]
+
+commands:
+ ticker start ticking & maintenance process
+
+ status show overview of queue health
+ check show problematic consumers
+
+ install install code into db
+ create QNAME create queue
+ drop QNAME drop queue
+  register QNAME CONS    register consumer on queue
+  unregister QNAME CONS  unregister consumer from queue
+ config QNAME [VAR=VAL] show or change queue config
+"""
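+
+# example invocations (sketch):
+#   pgqadm.py pgqadm.ini install
+#   pgqadm.py pgqadm.ini create myqueue
+#   pgqadm.py pgqadm.ini register myqueue my_consumer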
+
+config_allowed_list = [
+ 'queue_ticker_max_lag', 'queue_ticker_max_amount',
+ 'queue_ticker_idle_interval', 'queue_rotation_period']
+
+class PGQAdmin(skytools.DBScript):
+ def __init__(self, args):
+ skytools.DBScript.__init__(self, 'pgqadm', args)
+ self.set_single_loop(1)
+
+ if len(self.args) < 2:
+ print "need command"
+ sys.exit(1)
+
+ int_cmds = {
+ 'create': self.create_queue,
+ 'drop': self.drop_queue,
+ 'register': self.register,
+ 'unregister': self.unregister,
+ 'install': self.installer,
+ 'config': self.change_config,
+ }
+
+ cmd = self.args[1]
+ if cmd == "ticker":
+ script = SmartTicker(args)
+ elif cmd == "status":
+ script = PGQStatus(args)
+ elif cmd == "check":
+ script = PGQStatus(args, check = 1)
+ elif cmd in int_cmds:
+ script = None
+ self.work = int_cmds[cmd]
+ else:
+ print "unknown command"
+ sys.exit(1)
+
+ if self.pidfile:
+ self.pidfile += ".admin"
+ self.run_script = script
+
+ def start(self):
+ if self.run_script:
+ self.run_script.start()
+ else:
+ skytools.DBScript.start(self)
+
+ def init_optparse(self, parser=None):
+ p = skytools.DBScript.init_optparse(self, parser)
+ p.set_usage(command_usage.strip())
+ return p
+
+ def installer(self):
+ objs = [
+ skytools.DBLanguage("plpgsql"),
+ skytools.DBLanguage("plpythonu"),
+ skytools.DBFunction("get_current_txid", 0, sql_file="txid.sql"),
+ skytools.DBSchema("pgq", sql_file="pgq.sql"),
+ ]
+
+ db = self.get_database('db')
+ curs = db.cursor()
+ skytools.db_install(curs, objs, self.log)
+ db.commit()
+
+ def create_queue(self):
+ qname = self.args[2]
+ self.log.info('Creating queue: %s' % qname)
+ self.exec_sql("select pgq.create_queue(%s)", [qname])
+
+ def drop_queue(self):
+ qname = self.args[2]
+ self.log.info('Dropping queue: %s' % qname)
+ self.exec_sql("select pgq.drop_queue(%s)", [qname])
+
+ def register(self):
+ qname = self.args[2]
+ cons = self.args[3]
+ self.log.info('Registering consumer %s on queue %s' % (cons, qname))
+ self.exec_sql("select pgq.register_consumer(%s, %s)", [qname, cons])
+
+ def unregister(self):
+ qname = self.args[2]
+ cons = self.args[3]
+ self.log.info('Unregistering consumer %s from queue %s' % (cons, qname))
+ self.exec_sql("select pgq.unregister_consumer(%s, %s)", [qname, cons])
+
+ def change_config(self):
+ qname = self.args[2]
+ if len(self.args) == 3:
+ self.show_config(qname)
+ return
+ alist = []
+ for el in self.args[3:]:
+ k, v = el.split('=')
+ if k not in config_allowed_list:
+ raise Exception('unknown config var: '+k)
+ expr = "%s=%s" % (k, skytools.quote_literal(v))
+ alist.append(expr)
+ self.log.info('Change queue %s config to: %s' % (qname, ", ".join(alist)))
+ sql = "update pgq.queue set %s where queue_name = %s" % (
+ ", ".join(alist), skytools.quote_literal(qname))
+ self.exec_sql(sql, [])
+
+ def exec_sql(self, q, args):
+ self.log.debug(q)
+ db = self.get_database('db')
+ curs = db.cursor()
+ curs.execute(q, args)
+ db.commit()
+
+ def show_config(self, qname):
+ klist = ",".join(config_allowed_list)
+ q = "select * from pgq.queue where queue_name = %s"
+ db = self.get_database('db')
+ curs = db.cursor()
+ curs.execute(q, [qname])
+ res = curs.dictfetchone()
+ db.commit()
+
+ print qname
+ for k in config_allowed_list:
+ print " %s=%s" % (k, res[k])
+
+if __name__ == '__main__':
+ script = PGQAdmin(sys.argv[1:])
+ script.start()
+
+
+
diff --git a/python/skytools/__init__.py b/python/skytools/__init__.py
new file mode 100644
index 00000000..ed2b39bc
--- /dev/null
+++ b/python/skytools/__init__.py
@@ -0,0 +1,10 @@
+
+"""Tools for Python database scripts."""
+
+from config import *
+from dbstruct import *
+from gzlog import *
+from quoting import *
+from scripting import *
+from sqltools import *
+
diff --git a/python/skytools/config.py b/python/skytools/config.py
new file mode 100644
index 00000000..de420322
--- /dev/null
+++ b/python/skytools/config.py
@@ -0,0 +1,139 @@
+
+"""Nicer config class."""
+
+import sys, os, ConfigParser, socket
+
+__all__ = ['Config']
+
+class Config(object):
+    """Slightly improved ConfigParser.
+
+    Additional features:
+     - Remembers section.
+     - Accepts defaults in get() functions.
+ - List value support.
+ """
+ def __init__(self, main_section, filename, sane_config = 1):
+ """Initialize Config and read from file.
+
+ @param sane_config: chooses between ConfigParser/SafeConfigParser.
+ """
+ defs = {
+ 'job_name': main_section,
+ 'service_name': main_section,
+ 'host_name': socket.gethostname(),
+ }
+ if not os.path.isfile(filename):
+ raise Exception('Config file not found: '+filename)
+
+ self.filename = filename
+ self.sane_config = sane_config
+ if sane_config:
+ self.cf = ConfigParser.SafeConfigParser(defs)
+ else:
+ self.cf = ConfigParser.ConfigParser(defs)
+ self.cf.read(filename)
+ self.main_section = main_section
+ if not self.cf.has_section(main_section):
+ raise Exception("Wrong config file, no section '%s'"%main_section)
+
+ def reload(self):
+ """Re-reads config file."""
+ self.cf.read(self.filename)
+
+ def get(self, key, default=None):
+ """Reads string value, if not set then default."""
+ try:
+ return self.cf.get(self.main_section, key)
+ except ConfigParser.NoOptionError, det:
+ if default == None:
+ raise Exception("Config value not set: " + key)
+ return default
+
+ def getint(self, key, default=None):
+ """Reads int value, if not set then default."""
+ try:
+ return self.cf.getint(self.main_section, key)
+ except ConfigParser.NoOptionError, det:
+ if default == None:
+ raise Exception("Config value not set: " + key)
+ return default
+
+ def getboolean(self, key, default=None):
+ """Reads boolean value, if not set then default."""
+ try:
+ return self.cf.getboolean(self.main_section, key)
+ except ConfigParser.NoOptionError, det:
+ if default == None:
+ raise Exception("Config value not set: " + key)
+ return default
+
+ def getfloat(self, key, default=None):
+ """Reads float value, if not set then default."""
+ try:
+ return self.cf.getfloat(self.main_section, key)
+ except ConfigParser.NoOptionError, det:
+ if default == None:
+ raise Exception("Config value not set: " + key)
+ return default
+
+ def getlist(self, key, default=None):
+ """Reads comma-separated list from key."""
+ try:
+ s = self.cf.get(self.main_section, key).strip()
+ res = []
+ if not s:
+ return res
+ for v in s.split(","):
+ res.append(v.strip())
+ return res
+ except ConfigParser.NoOptionError, det:
+ if default == None:
+ raise Exception("Config value not set: " + key)
+ return default
+
+ def getfile(self, key, default=None):
+ """Reads filename from config.
+
+ In addition to reading string value, expands ~ to user directory.
+ """
+ fn = self.get(key, default)
+ if fn == "" or fn == "-":
+ return fn
+ # simulate that the cwd is script location
+ #path = os.path.dirname(sys.argv[0])
+ # seems bad idea, cwd should be cwd
+
+ fn = os.path.expanduser(fn)
+
+ return fn
+
+ def get_wildcard(self, key, values=[], default=None):
+ """Reads a wildcard property from conf and returns its string value, if not set then default."""
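+        # e.g. key 'limit.*.*' with values ['europe', 'db1'] is tried as
+        # 'limit.europe.db1', then 'limit.europe.*', then 'limit.*.*'
+        # (the key names above are just illustrative)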
+
+ orig_key = key
+ keys = [key]
+
+ for wild in values:
+ key = key.replace('*', wild, 1)
+ keys.append(key)
+ keys.reverse()
+
+ for key in keys:
+ try:
+ return self.cf.get(self.main_section, key)
+ except ConfigParser.NoOptionError, det:
+ pass
+
+ if default == None:
+ raise Exception("Config value not set: " + orig_key)
+ return default
+
+ def sections(self):
+ """Returns list of sections in config file, excluding DEFAULT."""
+ return self.cf.sections()
+
+ def clone(self, main_section):
+ """Return new Config() instance with new main section on same config file."""
+ return Config(main_section, self.filename, self.sane_config)
+
diff --git a/python/skytools/dbstruct.py b/python/skytools/dbstruct.py
new file mode 100644
index 00000000..22333429
--- /dev/null
+++ b/python/skytools/dbstruct.py
@@ -0,0 +1,380 @@
+"""Find table structure and allow CREATE/DROP elements from it.
+"""
+
+import sys, re
+
+from sqltools import fq_name_parts, get_table_oid
+
+__all__ = ['TableStruct',
+ 'T_TABLE', 'T_CONSTRAINT', 'T_INDEX', 'T_TRIGGER',
+ 'T_RULE', 'T_GRANT', 'T_OWNER', 'T_PKEY', 'T_ALL']
+
+T_TABLE = 1 << 0
+T_CONSTRAINT = 1 << 1
+T_INDEX = 1 << 2
+T_TRIGGER = 1 << 3
+T_RULE = 1 << 4
+T_GRANT = 1 << 5
+T_OWNER = 1 << 6
+T_PKEY = 1 << 20 # special, one of constraints
+T_ALL = ( T_TABLE | T_CONSTRAINT | T_INDEX
+ | T_TRIGGER | T_RULE | T_GRANT | T_OWNER )
+
+#
+# Utility functions
+#
+
+def find_new_name(curs, name):
+    """Create a new object name in case the old one exists.
+
+    Needed when creating a new table alongside the old one.
+    """
+ # cut off previous numbers
+ m = re.search('_[0-9]+$', name)
+ if m:
+ name = name[:m.start()]
+
+ # now loop
+ for i in range(1, 1000):
+ tname = "%s_%d" % (name, i)
+ q = "select count(1) from pg_class where relname = %s"
+ curs.execute(q, [tname])
+ if curs.fetchone()[0] == 0:
+ return tname
+
+ # failed
+ raise Exception('find_new_name failed')
+
+def rx_replace(rx, sql, new_part):
+ """Find a regex match and replace that part with new_part."""
+ m = re.search(rx, sql, re.I)
+ if not m:
+ raise Exception('rx_replace failed')
+ p1 = sql[:m.start()]
+ p2 = sql[m.end():]
+ return p1 + new_part + p2
+
+#
+# Schema objects
+#
+
+class TElem(object):
+ """Keeps info about one metadata object."""
+ SQL = ""
+ type = 0
+ def get_create_sql(self, curs):
+ """Return SQL statement for creating or None if not supported."""
+ return None
+ def get_drop_sql(self, curs):
+        """Return SQL statement for dropping or None if not supported."""
+ return None
+
+class TConstraint(TElem):
+ """Info about constraint."""
+ type = T_CONSTRAINT
+ SQL = """
+ SELECT conname as name, pg_get_constraintdef(oid) as def, contype
+ FROM pg_constraint WHERE conrelid = %(oid)s
+ """
+ def __init__(self, table_name, row):
+ self.table_name = table_name
+ self.name = row['name']
+ self.defn = row['def']
+ self.contype = row['contype']
+
+ # tag pkeys
+ if self.contype == 'p':
+ self.type += T_PKEY
+
+ def get_create_sql(self, curs, new_table_name=None):
+ fmt = "ALTER TABLE ONLY %s ADD CONSTRAINT %s %s;"
+ if new_table_name:
+ name = self.name
+ if self.contype in ('p', 'u'):
+ name = find_new_name(curs, self.name)
+ sql = fmt % (new_table_name, name, self.defn)
+ else:
+ sql = fmt % (self.table_name, self.name, self.defn)
+ return sql
+
+ def get_drop_sql(self, curs):
+ fmt = "ALTER TABLE ONLY %s DROP CONSTRAINT %s;"
+ sql = fmt % (self.table_name, self.name)
+ return sql
+
+class TIndex(TElem):
+ """Info about index."""
+ type = T_INDEX
+ SQL = """
+ SELECT n.nspname || '.' || c.relname as name,
+ pg_get_indexdef(i.indexrelid) as defn
+ FROM pg_index i, pg_class c, pg_namespace n
+ WHERE c.oid = i.indexrelid AND i.indrelid = %(oid)s
+ AND n.oid = c.relnamespace
+ AND NOT EXISTS
+ (select objid from pg_depend
+ where classid = %(pg_class_oid)s
+ and objid = c.oid
+ and deptype = 'i')
+ """
+ def __init__(self, table_name, row):
+ self.name = row['name']
+ self.defn = row['defn'] + ';'
+
+ def get_create_sql(self, curs, new_table_name = None):
+ if not new_table_name:
+ return self.defn
+ name = find_new_name(curs, self.name)
+ pnew = "INDEX %s ON %s " % (name, new_table_name)
+ rx = r"\bINDEX[ ][a-z0-9._]+[ ]ON[ ][a-z0-9._]+[ ]"
+ sql = rx_replace(rx, self.defn, pnew)
+ return sql
+ def get_drop_sql(self, curs):
+ return 'DROP INDEX %s;' % self.name
+
+class TRule(TElem):
+ """Info about rule."""
+ type = T_RULE
+ SQL = """
+ SELECT rulename as name, pg_get_ruledef(oid) as def
+ FROM pg_rewrite
+ WHERE ev_class = %(oid)s AND rulename <> '_RETURN'::name
+ """
+ def __init__(self, table_name, row, new_name = None):
+ self.table_name = table_name
+ self.name = row['name']
+ self.defn = row['def']
+
+ def get_create_sql(self, curs, new_table_name = None):
+ if not new_table_name:
+ return self.defn
+ rx = r"\bTO[ ][a-z0-9._]+[ ]DO[ ]"
+ pnew = "TO %s DO " % new_table_name
+ return rx_replace(rx, self.defn, pnew)
+
+ def get_drop_sql(self, curs):
+ return 'DROP RULE %s ON %s' % (self.name, self.table_name)
+
+class TTrigger(TElem):
+ """Info about trigger."""
+ type = T_TRIGGER
+ SQL = """
+ SELECT tgname as name, pg_get_triggerdef(oid) as def
+ FROM pg_trigger
+ WHERE tgrelid = %(oid)s AND NOT tgisconstraint
+ """
+ def __init__(self, table_name, row):
+ self.table_name = table_name
+ self.name = row['name']
+ self.defn = row['def'] + ';'
+
+ def get_create_sql(self, curs, new_table_name = None):
+ if not new_table_name:
+ return self.defn
+
+ rx = r"\bON[ ][a-z0-9._]+[ ]"
+ pnew = "ON %s " % new_table_name
+ return rx_replace(rx, self.defn, pnew)
+
+ def get_drop_sql(self, curs):
+ return 'DROP TRIGGER %s ON %s' % (self.name, self.table_name)
+
+class TOwner(TElem):
+ """Info about table owner."""
+ type = T_OWNER
+ SQL = """
+ SELECT pg_get_userbyid(relowner) as owner FROM pg_class
+ WHERE oid = %(oid)s
+ """
+ def __init__(self, table_name, row, new_name = None):
+ self.table_name = table_name
+ self.name = 'Owner'
+ self.owner = row['owner']
+
+ def get_create_sql(self, curs, new_name = None):
+ if not new_name:
+ new_name = self.table_name
+ return 'ALTER TABLE %s OWNER TO %s;' % (new_name, self.owner)
+
+class TGrant(TElem):
+ """Info about permissions."""
+ type = T_GRANT
+ SQL = "SELECT relacl FROM pg_class where oid = %(oid)s"
+ acl_map = {
+ 'r': 'SELECT', 'w': 'UPDATE', 'a': 'INSERT', 'd': 'DELETE',
+ 'R': 'RULE', 'x': 'REFERENCES', 't': 'TRIGGER', 'X': 'EXECUTE',
+ 'U': 'USAGE', 'C': 'CREATE', 'T': 'TEMPORARY'
+ }
+ def acl_to_grants(self, acl):
+ if acl == "arwdRxt": # ALL for tables
+ return "ALL"
+ return ", ".join([ self.acl_map[c] for c in acl ])
+
+ def parse_relacl(self, relacl):
+ if relacl is None:
+ return []
+ if len(relacl) > 0 and relacl[0] == '{' and relacl[-1] == '}':
+ relacl = relacl[1:-1]
+ list = []
+ for f in relacl.split(','):
+ user, tmp = f.strip('"').split('=')
+ acl, who = tmp.split('/')
+ list.append((user, acl, who))
+ return list
+
+ def __init__(self, table_name, row, new_name = None):
+ self.name = table_name
+ self.acl_list = self.parse_relacl(row['relacl'])
+
+ def get_create_sql(self, curs, new_name = None):
+ if not new_name:
+ new_name = self.name
+
+ list = []
+ for user, acl, who in self.acl_list:
+ astr = self.acl_to_grants(acl)
+ sql = "GRANT %s ON %s TO %s;" % (astr, new_name, user)
+ list.append(sql)
+ return "\n".join(list)
+
+ def get_drop_sql(self, curs):
+ list = []
+ for user, acl, who in self.acl_list:
+            sql = "REVOKE ALL ON %s FROM %s;" % (self.name, user)
+ list.append(sql)
+ return "\n".join(list)
+
+class TColumn(TElem):
+ """Info about table column."""
+ SQL = """
+ select a.attname as name,
+ a.attname || ' '
+ || format_type(a.atttypid, a.atttypmod)
+ || case when a.attnotnull then ' not null' else '' end
+            || case when a.atthasdef then ' default ' || d.adsrc else '' end
+ as def
+ from pg_attribute a left join pg_attrdef d
+ on (d.adrelid = a.attrelid and d.adnum = a.attnum)
+ where a.attrelid = %(oid)s
+ and not a.attisdropped
+ and a.attnum > 0
+ order by a.attnum;
+ """
+ def __init__(self, table_name, row):
+ self.name = row['name']
+ self.column_def = row['def']
+
+class TTable(TElem):
+ """Info about table only (columns)."""
+ type = T_TABLE
+ def __init__(self, table_name, col_list):
+ self.name = table_name
+ self.col_list = col_list
+
+ def get_create_sql(self, curs, new_name = None):
+ if not new_name:
+ new_name = self.name
+ sql = "create table %s (" % new_name
+ sep = "\n\t"
+ for c in self.col_list:
+ sql += sep + c.column_def
+ sep = ",\n\t"
+ sql += "\n);"
+ return sql
+
+ def get_drop_sql(self, curs):
+ return "DROP TABLE %s;" % self.name
+
+#
+# Main table object, loads all the others
+#
+
+class TableStruct(object):
+ """Collects and manages all info about table.
+
+    Allows issuing CREATE/DROP statements for any
+    group of elements.
+ """
+ def __init__(self, curs, table_name):
+ """Initializes class by loading info about table_name from database."""
+
+ self.table_name = table_name
+
+ # fill args
+ schema, name = fq_name_parts(table_name)
+ args = {
+ 'schema': schema,
+ 'table': name,
+ 'oid': get_table_oid(curs, table_name),
+ 'pg_class_oid': get_table_oid(curs, 'pg_catalog.pg_class'),
+ }
+
+ # load table struct
+ self.col_list = self._load_elem(curs, args, TColumn)
+ self.object_list = [ TTable(table_name, self.col_list) ]
+
+ # load additional objects
+ to_load = [TConstraint, TIndex, TTrigger, TRule, TGrant, TOwner]
+ for eclass in to_load:
+ self.object_list += self._load_elem(curs, args, eclass)
+
+ def _load_elem(self, curs, args, eclass):
+ list = []
+ curs.execute(eclass.SQL % args)
+ for row in curs.dictfetchall():
+ list.append(eclass(self.table_name, row))
+ return list
+
+ def create(self, curs, objs, new_table_name = None, log = None):
+ """Issues CREATE statements for requested set of objects.
+
+        If new_table_name is given, creates table under that name
+ and also tries to rename all indexes/constraints that conflict
+ with existing table.
+ """
+
+ for o in self.object_list:
+ if o.type & objs:
+ sql = o.get_create_sql(curs, new_table_name)
+ if not sql:
+ continue
+ if log:
+ log.info('Creating %s' % o.name)
+ log.debug(sql)
+ curs.execute(sql)
+
+ def drop(self, curs, objs, log = None):
+ """Issues DROP statements for requested set of objects."""
+ for o in self.object_list:
+ if o.type & objs:
+ sql = o.get_drop_sql(curs)
+ if not sql:
+ continue
+ if log:
+ log.info('Dropping %s' % o.name)
+ log.debug(sql)
+ curs.execute(sql)
+
+ def get_column_list(self):
+ """Returns list of column names the table has."""
+
+ res = []
+ for c in self.col_list:
+ res.append(c.name)
+ return res
+
+def test():
+ import psycopg
+ db = psycopg.connect("dbname=fooz")
+ curs = db.cursor()
+
+ s = TableStruct(curs, "public.data1")
+
+ s.drop(curs, T_ALL)
+ s.create(curs, T_ALL)
+ s.create(curs, T_ALL, "data1_new")
+ s.create(curs, T_PKEY)
+
+if __name__ == '__main__':
+ test()
+
diff --git a/python/skytools/gzlog.py b/python/skytools/gzlog.py
new file mode 100644
index 00000000..558e2813
--- /dev/null
+++ b/python/skytools/gzlog.py
@@ -0,0 +1,39 @@
+
+"""Atomic append of gzipped data.
+
+The point is - if several gzip streams are concatenated, they
+are read back as one whole stream.
+"""
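+
+# Example (sketch): repeated gzip_append('events.gz', line) calls append
+# separate gzip members; standard tools such as "gzip -dc events.gz" read
+# the result back as one continuous stream.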
+
+import gzip
+from cStringIO import StringIO
+
+__all__ = ['gzip_append']
+
+#
+# gzip storage
+#
+def gzip_append(filename, data, level = 6):
+ """Append a block of data to file with safety checks."""
+
+ # compress data
+ buf = StringIO()
+ g = gzip.GzipFile(fileobj = buf, compresslevel = level, mode = "w")
+ g.write(data)
+ g.close()
+ zdata = buf.getvalue()
+
+ # append, safely
+ f = open(filename, "a+", 0)
+ f.seek(0, 2)
+ pos = f.tell()
+ try:
+ f.write(zdata)
+ f.close()
+ except Exception, ex:
+ # rollback on error
+ f.seek(pos, 0)
+ f.truncate()
+ f.close()
+ raise ex
+
diff --git a/python/skytools/quoting.py b/python/skytools/quoting.py
new file mode 100644
index 00000000..96b0b022
--- /dev/null
+++ b/python/skytools/quoting.py
@@ -0,0 +1,156 @@
+# quoting.py
+
+"""Various helpers for string quoting/unquoting."""
+
+import psycopg, urllib, re
+
+#
+# SQL quoting
+#
+
+def quote_literal(s):
+ """Quote a literal value for SQL.
+
+    Surrounds it with single-quotes.
+ """
+
+ if s == None:
+ return "null"
+ s = psycopg.QuotedString(str(s))
+ return str(s)
+
+def quote_copy(s):
+ """Quoting for copy command."""
+
+ if s == None:
+ return "\\N"
+ s = str(s)
+ s = s.replace("\\", "\\\\")
+ s = s.replace("\t", "\\t")
+ s = s.replace("\n", "\\n")
+ s = s.replace("\r", "\\r")
+ return s
+
+def quote_bytea_raw(s):
+ """Quoting for bytea parser."""
+
+ if s == None:
+ return None
+ return s.replace("\\", "\\\\").replace("\0", "\\000")
+
+def quote_bytea_literal(s):
+ """Quote bytea for regular SQL."""
+
+ return quote_literal(quote_bytea_raw(s))
+
+def quote_bytea_copy(s):
+ """Quote bytea for COPY."""
+
+ return quote_copy(quote_bytea_raw(s))
+
+def quote_statement(sql, dict):
+    """Quote whole statement.
+
+ Data values are taken from dict.
+ """
+ xdict = {}
+ for k, v in dict.items():
+ xdict[k] = quote_literal(v)
+ return sql % xdict
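+
+# Example (sketch): quote_statement("insert into t (a, b) values (%(a)s, %(b)s)",
+# {'a': 'x', 'b': None}) fills in quoted literals, giving roughly
+# "insert into t (a, b) values ('x', null)".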
+
+#
+# quoting for JSON strings
+#
+
+_jsre = re.compile(r'[\x00-\x1F\\/"]')
+_jsmap = { "\b": "\\b", "\f": "\\f", "\n": "\\n", "\r": "\\r",
+ "\t": "\\t", "\\": "\\\\", '"': '\\"',
+ "/": "\\/", # to avoid html attacks
+}
+
+def _json_quote_char(m):
+ c = m.group(0)
+ try:
+ return _jsmap[c]
+ except KeyError:
+ return r"\u%04x" % ord(c)
+
+def quote_json(s):
+ """JSON style quoting."""
+ if s is None:
+ return "null"
+ return '"%s"' % _jsre.sub(_json_quote_char, s)
+
+#
+# Database specific urlencode and urldecode.
+#
+
+def db_urlencode(dict):
+ """Database specific urlencode.
+
+ Encode None as key without '='. That means that in "foo&bar=",
+ foo is NULL and bar is empty string.
+ """
+
+ elem_list = []
+ for k, v in dict.items():
+ if v is None:
+ elem = urllib.quote_plus(str(k))
+ else:
+ elem = urllib.quote_plus(str(k)) + '=' + urllib.quote_plus(str(v))
+ elem_list.append(elem)
+ return '&'.join(elem_list)
+
+def db_urldecode(qs):
+ """Database specific urldecode.
+
+ Decode key without '=' as None.
+    Repeated keys are not supported (the last value wins).
+ """
+
+ res = {}
+ for elem in qs.split('&'):
+ if not elem:
+ continue
+ pair = elem.split('=', 1)
+ name = urllib.unquote_plus(pair[0])
+ if len(pair) == 1:
+ res[name] = None
+ else:
+ res[name] = urllib.unquote_plus(pair[1])
+ return res
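+
+# Example: db_urlencode({'data': 'x y', 'note': None}) gives "data=x+y&note"
+# (key order may vary); db_urldecode("data=x+y&note") returns
+# {'data': 'x y', 'note': None}.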
+
+#
+# Remove C-like backslash escapes
+#
+
+_esc_re = r"\\([0-7][0-7][0-7]|.)"
+_esc_rc = re.compile(_esc_re)
+_esc_map = {
+ 't': '\t',
+ 'n': '\n',
+ 'r': '\r',
+ 'a': '\a',
+ 'b': '\b',
+ "'": "'",
+ '"': '"',
+ '\\': '\\',
+}
+
+def _sub_unescape(m):
+ v = m.group(1)
+ if len(v) == 1:
+ return _esc_map[v]
+ else:
+ return chr(int(v, 8))
+
+def unescape(val):
+ """Removes C-style escapes from string."""
+ return _esc_rc.sub(_sub_unescape, val)
+
+def unescape_copy(val):
+ """Removes C-style escapes, also converts "\N" to None."""
+ if val == r"\N":
+ return None
+ return unescape(val)
+
diff --git a/python/skytools/scripting.py b/python/skytools/scripting.py
new file mode 100644
index 00000000..cf976801
--- /dev/null
+++ b/python/skytools/scripting.py
@@ -0,0 +1,523 @@
+
+"""Useful functions and classes for database scripts."""
+
+import sys, os, signal, psycopg, optparse, traceback, time
+import logging, logging.handlers, logging.config
+
+from skytools.config import *
+import skytools.skylog
+
+__all__ = ['daemonize', 'run_single_process', 'DBScript',
+ 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_SERIALIZABLE']
+
+#
+# daemon mode
+#
+
+def daemonize():
+    """Turn the process into a daemon.
+
+    Goes to background and redirects all standard i/o to /dev/null.
+ """
+
+ # launch new process, kill parent
+ pid = os.fork()
+ if pid != 0:
+ os._exit(0)
+
+ # start new session
+ os.setsid()
+
+ # stop i/o
+ fd = os.open("/dev/null", os.O_RDWR)
+ os.dup2(fd, 0)
+ os.dup2(fd, 1)
+ os.dup2(fd, 2)
+ if fd > 2:
+ os.close(fd)
+
+#
+# Pidfile locking+cleanup & daemonization combined
+#
+
+def _write_pidfile(pidfile):
+ pid = os.getpid()
+ f = open(pidfile, 'w')
+ f.write(str(pid))
+ f.close()
+
+def run_single_process(runnable, daemon, pidfile):
+ """Run runnable class, possibly daemonized, locked on pidfile."""
+
+ # check if another process is running
+ if pidfile and os.path.isfile(pidfile):
+ print "Pidfile exists, another process running?"
+ sys.exit(1)
+
+ # daemonize if needed and write pidfile
+ if daemon:
+ daemonize()
+ if pidfile:
+ _write_pidfile(pidfile)
+
+ # Catch SIGTERM to cleanup pidfile
+ def sigterm_hook(signum, frame):
+ try:
+ os.remove(pidfile)
+ except: pass
+ sys.exit(0)
+ # attach it to signal
+ if pidfile:
+ signal.signal(signal.SIGTERM, sigterm_hook)
+
+ # run
+ try:
+ runnable.run()
+ finally:
+ # another try of cleaning up
+ if pidfile:
+ try:
+ os.remove(pidfile)
+ except: pass
+
+#
+# logging setup
+#
+
+_log_config_done = 0
+_log_init_done = {}
+
+def _init_log(job_name, cf, log_level):
+ """Logging setup happens here."""
+ global _log_init_done, _log_config_done
+
+ got_skylog = 0
+ use_skylog = cf.getint("use_skylog", 0)
+
+ # load logging config if needed
+ if use_skylog and not _log_config_done:
+ # python logging.config braindamage:
+        # cannot specify external classes without such a hack
+ logging.skylog = skytools.skylog
+
+ # load general config
+ list = ['skylog.ini', '~/.skylog.ini', '/etc/skylog.ini']
+ for fn in list:
+ fn = os.path.expanduser(fn)
+ if os.path.isfile(fn):
+ defs = {'job_name': job_name}
+ logging.config.fileConfig(fn, defs)
+ got_skylog = 1
+ break
+ _log_config_done = 1
+ if not got_skylog:
+ sys.stderr.write("skylog.ini not found!\n")
+ sys.exit(1)
+
+ # avoid duplicate logging init for job_name
+ log = logging.getLogger(job_name)
+ if job_name in _log_init_done:
+ return log
+ _log_init_done[job_name] = 1
+
+ # compatibility: specify ini file in script config
+ logfile = cf.getfile("logfile", "")
+ if logfile:
+ fmt = logging.Formatter('%(asctime)s %(process)s %(levelname)s %(message)s')
+ size = cf.getint('log_size', 10*1024*1024)
+ num = cf.getint('log_count', 3)
+ hdlr = logging.handlers.RotatingFileHandler(
+ logfile, 'a', size, num)
+ hdlr.setFormatter(fmt)
+ log.addHandler(hdlr)
+
+ # if skylog.ini is disabled or not available, log at least to stderr
+ if not got_skylog:
+ hdlr = logging.StreamHandler()
+ fmt = logging.Formatter('%(asctime)s %(process)s %(levelname)s %(message)s')
+ hdlr.setFormatter(fmt)
+ log.addHandler(hdlr)
+
+ log.setLevel(log_level)
+
+ return log
+
+#: max age of a cached connection before it gets closed
+DEF_CONN_AGE = 20*60 # 20 min
+
+#: isolation level not set
+I_DEFAULT = -1
+
+#: isolation level constant for AUTOCOMMIT
+I_AUTOCOMMIT = 0
+#: isolation level constant for READ COMMITTED
+I_READ_COMMITTED = 1
+#: isolation level constant for SERIALIZABLE
+I_SERIALIZABLE = 2
+
+class DBCachedConn(object):
+ """Cache a db connection."""
+ def __init__(self, name, loc, max_age = DEF_CONN_AGE):
+ self.name = name
+ self.loc = loc
+ self.conn = None
+ self.conn_time = 0
+ self.max_age = max_age
+ self.autocommit = -1
+ self.isolation_level = -1
+
+ def get_connection(self, autocommit = 0, isolation_level = -1):
+        # autocommit overrides isolation_level
+ if autocommit:
+ isolation_level = I_AUTOCOMMIT
+
+ # default isolation_level is READ COMMITTED
+ if isolation_level < 0:
+ isolation_level = I_READ_COMMITTED
+
+ # new conn?
+ if not self.conn:
+ self.isolation_level = isolation_level
+ self.conn = psycopg.connect(self.loc)
+
+ self.conn.set_isolation_level(isolation_level)
+ self.conn_time = time.time()
+ else:
+ if self.isolation_level != isolation_level:
+ raise Exception("Conflict in isolation_level")
+
+ # done
+ return self.conn
+
+ def refresh(self):
+ if not self.conn:
+ return
+ #for row in self.conn.notifies():
+ # if row[0].lower() == "reload":
+ # self.reset()
+ # return
+ if not self.max_age:
+ return
+ if time.time() - self.conn_time >= self.max_age:
+ self.reset()
+
+ def reset(self):
+ if not self.conn:
+ return
+
+ # drop reference
+ conn = self.conn
+ self.conn = None
+
+ if self.isolation_level == I_AUTOCOMMIT:
+ return
+
+ # rollback & close
+ try:
+ conn.rollback()
+ except: pass
+ try:
+ conn.close()
+ except: pass
+
+class DBScript(object):
+ """Base class for database scripts.
+
+ Handles logging, daemonizing, config, errors.
+ """
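+    # Typical use (sketch): subclass DBScript, override work(), then run
+    #   MyScript('my_service', sys.argv[1:]).start()
+    # with the config file name as the first command line argument.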
+ service_name = None
+ job_name = None
+ cf = None
+ log = None
+
+ def __init__(self, service_name, args):
+ """Script setup.
+
+ User class should override work() and optionally __init__(), startup(),
+ reload(), reset() and init_optparse().
+
+ NB: in case of daemon, the __init__() and startup()/work() will be
+ run in different processes. So nothing fancy should be done in __init__().
+
+ @param service_name: unique name for script.
+ It will be also default job_name, if not specified in config.
+        @param args: cmdline args (sys.argv[1:]), but can be overridden
+ """
+ self.service_name = service_name
+ self.db_cache = {}
+ self.go_daemon = 0
+ self.do_single_loop = 0
+ self.looping = 1
+ self.need_reload = 1
+ self.stat_dict = {}
+ self.log_level = logging.INFO
+ self.work_state = 1
+
+ # parse command line
+ parser = self.init_optparse()
+ self.options, self.args = parser.parse_args(args)
+
+ # check args
+ if self.options.daemon:
+ self.go_daemon = 1
+ if self.options.quiet:
+ self.log_level = logging.WARNING
+ if self.options.verbose:
+ self.log_level = logging.DEBUG
+ if len(self.args) < 1:
+ print "need config file"
+ sys.exit(1)
+ conf_file = self.args[0]
+
+ # load config
+ self.cf = Config(self.service_name, conf_file)
+ self.job_name = self.cf.get("job_name", self.service_name)
+ self.pidfile = self.cf.getfile("pidfile", '')
+
+ self.reload()
+
+ # init logging
+ self.log = _init_log(self.job_name, self.cf, self.log_level)
+
+ # send signal, if needed
+ if self.options.cmd == "kill":
+ self.send_signal(signal.SIGTERM)
+ elif self.options.cmd == "stop":
+ self.send_signal(signal.SIGINT)
+ elif self.options.cmd == "reload":
+ self.send_signal(signal.SIGHUP)
+
+ def init_optparse(self, parser = None):
+        """Initialize an OptionParser() instance that will be used to
+ parse command line arguments.
+
+        Note that it can be overridden in both directions - either DBScript
+        will initialize an instance and pass it to user code, or user code
+        can initialize one and then pass it to DBScript.init_optparse().
+
+ @param parser: optional OptionParser() instance,
+            where DBScript should attach its own arguments.
+ @return: initialized OptionParser() instance.
+ """
+ if parser:
+ p = parser
+ else:
+ p = optparse.OptionParser()
+ p.set_usage("%prog [options] INI")
+ # generic options
+ p.add_option("-q", "--quiet", action="store_true",
+ help = "make program silent")
+ p.add_option("-v", "--verbose", action="store_true",
+ help = "make program verbose")
+ p.add_option("-d", "--daemon", action="store_true",
+ help = "go background")
+
+ # control options
+ g = optparse.OptionGroup(p, 'control running process')
+ g.add_option("-r", "--reload",
+ action="store_const", const="reload", dest="cmd",
+ help = "reload config (send SIGHUP)")
+ g.add_option("-s", "--stop",
+ action="store_const", const="stop", dest="cmd",
+ help = "stop program safely (send SIGINT)")
+ g.add_option("-k", "--kill",
+ action="store_const", const="kill", dest="cmd",
+            help = "kill program immediately (send SIGTERM)")
+ p.add_option_group(g)
+
+ return p
+
+ def send_signal(self, sig):
+ if not self.pidfile:
+            self.log.warning("No pidfile in config, nothing to do")
+ sys.exit(0)
+ if not os.path.isfile(self.pidfile):
+ self.log.warning("No pidfile, process not running")
+ sys.exit(0)
+ pid = int(open(self.pidfile, "r").read())
+ os.kill(pid, sig)
+ sys.exit(0)
+
+ def set_single_loop(self, do_single_loop):
+ """Changes whether the script will loop or not."""
+ self.do_single_loop = do_single_loop
+
+ def start(self):
+ """This will launch main processing thread."""
+ if self.go_daemon:
+ if not self.pidfile:
+ self.log.error("Daemon needs pidfile")
+ sys.exit(1)
+ run_single_process(self, self.go_daemon, self.pidfile)
+
+ def stop(self):
+ """Safely stops processing loop."""
+ self.looping = 0
+
+ def reload(self):
+ "Reload config."
+ self.cf.reload()
+ self.loop_delay = self.cf.getfloat("loop_delay", 1.0)
+
+ def hook_sighup(self, sig, frame):
+ "Internal SIGHUP handler. Minimal code here."
+ self.need_reload = 1
+
+ def hook_sigint(self, sig, frame):
+ "Internal SIGINT handler. Minimal code here."
+ self.stop()
+
+ def stat_add(self, key, value):
+ self.stat_put(key, value)
+
+ def stat_put(self, key, value):
+ """Sets a stat value."""
+ self.stat_dict[key] = value
+
+ def stat_increase(self, key, increase = 1):
+ """Increases a stat value."""
+ if key in self.stat_dict:
+ self.stat_dict[key] += increase
+ else:
+ self.stat_dict[key] = increase
+
+ def send_stats(self):
+ "Send statistics to log."
+
+ res = []
+ for k, v in self.stat_dict.items():
+ res.append("%s: %s" % (k, str(v)))
+
+ if len(res) == 0:
+ return
+
+ logmsg = "{%s}" % ", ".join(res)
+ self.log.info(logmsg)
+ self.stat_dict = {}
+
+ def get_database(self, dbname, autocommit = 0, isolation_level = -1,
+ cache = None, max_age = DEF_CONN_AGE):
+ """Load cached database connection.
+
+        User must not keep a permanent reference to it,
+ as all connections will be invalidated on reset.
+ """
+
+ if not cache:
+ cache = dbname
+ if cache in self.db_cache:
+ dbc = self.db_cache[cache]
+ else:
+ loc = self.cf.get(dbname)
+ dbc = DBCachedConn(cache, loc, max_age)
+ self.db_cache[cache] = dbc
+
+ return dbc.get_connection(autocommit, isolation_level)
+
+ def close_database(self, dbname):
+ """Explicitly close a cached connection.
+
+ Next call to get_database() will reconnect.
+ """
+ if dbname in self.db_cache:
+ dbc = self.db_cache[dbname]
+ dbc.reset()
+
+ def reset(self):
+ "Something bad happened, reset all connections."
+ for dbc in self.db_cache.values():
+ dbc.reset()
+ self.db_cache = {}
+
+ def run(self):
+ "Thread main loop."
+
+ # run startup, safely
+ try:
+ self.startup()
+ except KeyboardInterrupt, det:
+ raise
+ except SystemExit, det:
+ raise
+ except Exception, det:
+ exc, msg, tb = sys.exc_info()
+ self.log.fatal("Job %s crashed: %s: '%s' (%s: %s)" % (
+ self.job_name, str(exc), str(msg).rstrip(),
+ str(tb), repr(traceback.format_tb(tb))))
+ del tb
+ self.reset()
+ sys.exit(1)
+
+ while self.looping:
+ # reload config, if needed
+ if self.need_reload:
+ self.reload()
+ self.need_reload = 0
+
+ # do some work
+ work = self.run_once()
+
+            # send stats that were added
+ self.send_stats()
+
+ # reconnect if needed
+ for dbc in self.db_cache.values():
+ dbc.refresh()
+
+ # exit if needed
+ if self.do_single_loop:
+ self.log.debug("Only single loop requested, exiting")
+ break
+
+ # remember work state
+ self.work_state = work
+ # should sleep?
+ if not work:
+ try:
+ time.sleep(self.loop_delay)
+ except Exception, d:
+ self.log.debug("sleep failed: "+str(d))
+ sys.exit(0)
+
+ def run_once(self):
+        "Run user's work function, safely."
+ try:
+ return self.work()
+ except SystemExit, d:
+ self.send_stats()
+ self.log.info("got SystemExit(%s), exiting" % str(d))
+ self.reset()
+ raise d
+ except KeyboardInterrupt, d:
+ self.send_stats()
+ self.log.info("got KeyboardInterrupt, exiting")
+ self.reset()
+ sys.exit(1)
+ except Exception, d:
+ self.send_stats()
+ exc, msg, tb = sys.exc_info()
+ self.log.fatal("Job %s crashed: %s: '%s' (%s: %s)" % (
+ self.job_name, str(exc), str(msg).rstrip(),
+ str(tb), repr(traceback.format_tb(tb))))
+ del tb
+ self.reset()
+ if self.looping:
+ time.sleep(20)
+ return 1
+
+ def work(self):
+        "User's processing should happen here."
+ raise Exception("Nothing implemented?")
+
+ def startup(self):
+ """Will be called just before entering main loop.
+
+        In case of daemon, it will be called in the same process as work(),
+ unlike __init__().
+ """
+
+ # set signals
+ signal.signal(signal.SIGHUP, self.hook_sighup)
+ signal.signal(signal.SIGINT, self.hook_sigint)
+
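+# Usage sketch (illustrative only, kept as a comment): a minimal DBScript
+# subclass.  The config file passed on the command line, its 'src_db'
+# connect string and table name are assumptions, as is DBScript being
+# importable as skytools.DBScript.
+#
+#   import sys, skytools
+#
+#   class MyScript(skytools.DBScript):
+#       def work(self):
+#           db = self.get_database("src_db")
+#           curs = db.cursor()
+#           curs.execute("select count(*) from mytable")
+#           self.stat_put("rows", curs.fetchone()[0])
+#           db.commit()
+#           return 1      # work was done, loop again without sleeping
+#
+#   if __name__ == '__main__':
+#       script = MyScript("myscript", sys.argv[1:])
+#       script.start()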
+
diff --git a/python/skytools/skylog.py b/python/skytools/skylog.py
new file mode 100644
index 00000000..2f6344ae
--- /dev/null
+++ b/python/skytools/skylog.py
@@ -0,0 +1,173 @@
+"""Our log handlers for Python's logging package.
+"""
+
+import sys, os, time, socket, psycopg
+import logging, logging.handlers
+
+from quoting import quote_json
+
+# configurable file logger
+class EasyRotatingFileHandler(logging.handlers.RotatingFileHandler):
+ """Easier setup for RotatingFileHandler."""
+ def __init__(self, filename, maxBytes = 10*1024*1024, backupCount = 3):
+ """Args same as for RotatingFileHandler, but in filename '~' is expanded."""
+ fn = os.path.expanduser(filename)
+ logging.handlers.RotatingFileHandler.__init__(self, fn, maxBytes=maxBytes, backupCount=backupCount)
+
+# send JSON message over UDP
+class UdpLogServerHandler(logging.handlers.DatagramHandler):
+ """Sends log records over UDP to logserver in JSON format."""
+
+ # map logging levels to logserver levels
+ _level_map = {
+ logging.DEBUG : 'DEBUG',
+ logging.INFO : 'INFO',
+ logging.WARNING : 'WARN',
+ logging.ERROR : 'ERROR',
+ logging.CRITICAL: 'FATAL',
+ }
+
+ # JSON message template
+ _log_template = '{\n\t'\
+ '"logger": "skytools.UdpLogServer",\n\t'\
+ '"timestamp": %.0f,\n\t'\
+ '"level": "%s",\n\t'\
+ '"thread": null,\n\t'\
+ '"message": %s,\n\t'\
+ '"properties": {"application":"%s", "hostname":"%s"}\n'\
+ '}'
+
+ # cut longer msgs
+ MAXMSG = 1024
+
+ def makePickle(self, record):
+ """Create message in JSON format."""
+ # get & cut msg
+ msg = self.format(record)
+ if len(msg) > self.MAXMSG:
+ msg = msg[:self.MAXMSG]
+ txt_level = self._level_map.get(record.levelno, "ERROR")
+ pkt = self._log_template % (time.time()*1000, txt_level,
+ quote_json(msg), record.name, socket.gethostname())
+ return pkt
+
+class LogDBHandler(logging.handlers.SocketHandler):
+ """Sends log records into PostgreSQL server.
+
+ Additionally, does some statistics aggregating,
+ to avoid overloading log server.
+
+    It subclasses SocketHandler to get throttling for
+ failed connections.
+ """
+
+ # map codes to string
+ _level_map = {
+ logging.DEBUG : 'DEBUG',
+ logging.INFO : 'INFO',
+ logging.WARNING : 'WARNING',
+ logging.ERROR : 'ERROR',
+ logging.CRITICAL: 'FATAL',
+ }
+
+ def __init__(self, connect_string):
+ """
+ Initializes the handler with a specific connection string.
+ """
+
+ logging.handlers.SocketHandler.__init__(self, None, None)
+ self.closeOnError = 1
+
+ self.connect_string = connect_string
+
+ self.stat_cache = {}
+ self.stat_flush_period = 60
+        # send first stat line immediately
+ self.last_stat_flush = 0
+
+ def createSocket(self):
+ try:
+ logging.handlers.SocketHandler.createSocket(self)
+ except:
+ self.sock = self.makeSocket()
+
+ def makeSocket(self):
+ """Create server connection.
+        In this case it's not a socket but a psycopg connection."""
+
+ db = psycopg.connect(self.connect_string)
+ db.autocommit(1)
+ return db
+
+ def emit(self, record):
+ """Process log record."""
+
+        # we do not want to log debug messages
+ if record.levelno < logging.INFO:
+ return
+
+ try:
+ self.process_rec(record)
+ except (SystemExit, KeyboardInterrupt):
+ raise
+ except:
+ self.handleError(record)
+
+ def process_rec(self, record):
+ """Aggregate stats if needed, and send to logdb."""
+ # render msg
+ msg = self.format(record)
+
+        # don't want to send stats too often
+ if record.levelno == logging.INFO and msg and msg[0] == "{":
+ self.aggregate_stats(msg)
+ if time.time() - self.last_stat_flush >= self.stat_flush_period:
+ self.flush_stats(record.name)
+ return
+
+ if record.levelno < logging.INFO:
+ self.flush_stats(record.name)
+
+        # don't send more than one line
+ ln = msg.find('\n')
+ if ln > 0:
+ msg = msg[:ln]
+
+ txt_level = self._level_map.get(record.levelno, "ERROR")
+ self.send_to_logdb(record.name, txt_level, msg)
+
+ def aggregate_stats(self, msg):
+ """Sum stats together, to lessen load on logdb."""
+
+ msg = msg[1:-1]
+ for rec in msg.split(", "):
+ k, v = rec.split(": ")
+ agg = self.stat_cache.get(k, 0)
+ if v.find('.') >= 0:
+ agg += float(v)
+ else:
+ agg += int(v)
+ self.stat_cache[k] = agg
+
+ def flush_stats(self, service):
+        """Send acquired stats to logdb."""
+ res = []
+ for k, v in self.stat_cache.items():
+ res.append("%s: %s" % (k, str(v)))
+ if len(res) > 0:
+ logmsg = "{%s}" % ", ".join(res)
+ self.send_to_logdb(service, "INFO", logmsg)
+ self.stat_cache = {}
+ self.last_stat_flush = time.time()
+
+ def send_to_logdb(self, service, type, msg):
+ """Actual sending is done here."""
+
+ if self.sock is None:
+ self.createSocket()
+
+ if self.sock:
+ logcur = self.sock.cursor()
+ query = "select * from log.add(%s, %s, %s)"
+ logcur.execute(query, [type, service, msg])
+
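+# Usage sketch (illustrative only, kept as a comment): attaching these handlers
+# by hand instead of through skylog.ini.  The log file path, host and port are
+# assumptions (and ~/log must already exist):
+#
+#   import logging
+#   from skytools.skylog import EasyRotatingFileHandler, UdpLogServerHandler
+#
+#   log = logging.getLogger("myjob")
+#   log.setLevel(logging.INFO)
+#   log.addHandler(EasyRotatingFileHandler("~/log/myjob.log"))
+#   log.addHandler(UdpLogServerHandler("logsrv.example.com", 6666))
+#   log.info("handlers attached")
+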
diff --git a/python/skytools/sqltools.py b/python/skytools/sqltools.py
new file mode 100644
index 00000000..75e209f1
--- /dev/null
+++ b/python/skytools/sqltools.py
@@ -0,0 +1,398 @@
+
+"""Database tools."""
+
+import os
+from cStringIO import StringIO
+from quoting import quote_copy, quote_literal
+
+#
+# Fully qualified table name
+#
+
+def fq_name_parts(tbl):
+ "Return fully qualified name parts."
+
+ tmp = tbl.split('.')
+ if len(tmp) == 1:
+ return ('public', tbl)
+ elif len(tmp) == 2:
+ return tmp
+ else:
+        raise Exception('Syntax error in table name: ' + tbl)
+
+def fq_name(tbl):
+ "Return fully qualified name."
+ return '.'.join(fq_name_parts(tbl))
+
+#
+# info about table
+#
+def get_table_oid(curs, table_name):
+ schema, name = fq_name_parts(table_name)
+ q = """select c.oid from pg_namespace n, pg_class c
+ where c.relnamespace = n.oid
+ and n.nspname = %s and c.relname = %s"""
+ curs.execute(q, [schema, name])
+ res = curs.fetchall()
+ if len(res) == 0:
+ raise Exception('Table not found: '+table_name)
+ return res[0][0]
+
+def get_table_pkeys(curs, tbl):
+ oid = get_table_oid(curs, tbl)
+ q = "SELECT k.attname FROM pg_index i, pg_attribute k"\
+ " WHERE i.indrelid = %s AND k.attrelid = i.indexrelid"\
+ " AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped"\
+ " ORDER BY k.attnum"
+ curs.execute(q, [oid])
+ return map(lambda x: x[0], curs.fetchall())
+
+def get_table_columns(curs, tbl):
+ oid = get_table_oid(curs, tbl)
+ q = "SELECT k.attname FROM pg_attribute k"\
+ " WHERE k.attrelid = %s"\
+ " AND k.attnum > 0 AND NOT k.attisdropped"\
+ " ORDER BY k.attnum"
+ curs.execute(q, [oid])
+ return map(lambda x: x[0], curs.fetchall())
+
+#
+# exist checks
+#
+def exists_schema(curs, schema):
+ q = "select count(1) from pg_namespace where nspname = %s"
+ curs.execute(q, [schema])
+ res = curs.fetchone()
+ return res[0]
+
+def exists_table(curs, table_name):
+ schema, name = fq_name_parts(table_name)
+ q = """select count(1) from pg_namespace n, pg_class c
+ where c.relnamespace = n.oid and c.relkind = 'r'
+ and n.nspname = %s and c.relname = %s"""
+ curs.execute(q, [schema, name])
+ res = curs.fetchone()
+ return res[0]
+
+def exists_type(curs, type_name):
+ schema, name = fq_name_parts(type_name)
+ q = """select count(1) from pg_namespace n, pg_type t
+ where t.typnamespace = n.oid
+ and n.nspname = %s and t.typname = %s"""
+ curs.execute(q, [schema, name])
+ res = curs.fetchone()
+ return res[0]
+
+def exists_function(curs, function_name, nargs):
+ # this does not check arg types, so may match several functions
+ schema, name = fq_name_parts(function_name)
+ q = """select count(1) from pg_namespace n, pg_proc p
+ where p.pronamespace = n.oid and p.pronargs = %s
+ and n.nspname = %s and p.proname = %s"""
+ curs.execute(q, [nargs, schema, name])
+ res = curs.fetchone()
+ return res[0]
+
+def exists_language(curs, lang_name):
+ q = """select count(1) from pg_language
+ where lanname = %s"""
+ curs.execute(q, [lang_name])
+ res = curs.fetchone()
+ return res[0]
+
+#
+# Support for PostgreSQL snapshot
+#
+
+class Snapshot(object):
+ "Represents a PostgreSQL snapshot."
+
+ def __init__(self, str):
+ "Create snapshot from string."
+
+ self.sn_str = str
+ tmp = str.split(':')
+ if len(tmp) != 3:
+ raise Exception('Unknown format for snapshot')
+ self.xmin = int(tmp[0])
+ self.xmax = int(tmp[1])
+ self.txid_list = []
+ if tmp[2] != "":
+ for s in tmp[2].split(','):
+ self.txid_list.append(int(s))
+
+ def contains(self, txid):
+ "Is txid visible in snapshot."
+
+ txid = int(txid)
+
+ if txid < self.xmin:
+ return True
+ if txid >= self.xmax:
+ return False
+ if txid in self.txid_list:
+ return False
+ return True
+
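+# Example (illustrative only, kept as a comment) of how contains() reads
+# a snapshot string 'xmin:xmax:txid_list':
+#
+#   sn = Snapshot('11:20:11,14')
+#   sn.contains(10)    # True  - finished before xmin
+#   sn.contains(14)    # False - listed as still in progress
+#   sn.contains(13)    # True  - between xmin and xmax, not in the list
+#   sn.contains(25)    # False - started at or after xmax
+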
+#
+# Copy helpers
+#
+
+def _gen_dict_copy(tbl, row, fields):
+ tmp = []
+ for f in fields:
+ v = row[f]
+ tmp.append(quote_copy(v))
+ return "\t".join(tmp)
+
+def _gen_dict_insert(tbl, row, fields):
+ tmp = []
+ for f in fields:
+ v = row[f]
+ tmp.append(quote_literal(v))
+ fmt = "insert into %s (%s) values (%s);"
+ return fmt % (tbl, ",".join(fields), ",".join(tmp))
+
+def _gen_list_copy(tbl, row, fields):
+ tmp = []
+ for i in range(len(fields)):
+ v = row[i]
+ tmp.append(quote_copy(v))
+ return "\t".join(tmp)
+
+def _gen_list_insert(tbl, row, fields):
+ tmp = []
+ for i in range(len(fields)):
+ v = row[i]
+ tmp.append(quote_literal(v))
+ fmt = "insert into %s (%s) values (%s);"
+ return fmt % (tbl, ",".join(fields), ",".join(tmp))
+
+def magic_insert(curs, tablename, data, fields = None, use_insert = 0):
+ """Copy/insert a list of dict/list data to database.
+
+ If curs == None, then the copy or insert statements are returned
+    as a string.  For a list of dicts the field list is optional, as it is
+    possible to guess it from the dict keys.
+ """
+ if len(data) == 0:
+ return
+
+ # decide how to process
+ if type(data[0]) == type({}):
+ if fields == None:
+ fields = data[0].keys()
+ if use_insert:
+ row_func = _gen_dict_insert
+ else:
+ row_func = _gen_dict_copy
+ else:
+ if fields == None:
+ raise Exception("Non-dict data needs field list")
+ if use_insert:
+ row_func = _gen_list_insert
+ else:
+ row_func = _gen_list_copy
+
+ # init processing
+ buf = StringIO()
+ if curs == None and use_insert == 0:
+ fmt = "COPY %s (%s) FROM STDIN;\n"
+ buf.write(fmt % (tablename, ",".join(fields)))
+
+ # process data
+ for row in data:
+ buf.write(row_func(tablename, row, fields))
+ buf.write("\n")
+
+ # if user needs only string, return it
+ if curs == None:
+ if use_insert == 0:
+ buf.write("\\.\n")
+ return buf.getvalue()
+
+ # do the actual copy/inserts
+ if use_insert:
+ curs.execute(buf.getvalue())
+ else:
+ buf.seek(0)
+ hdr = "%s (%s)" % (tablename, ",".join(fields))
+ curs.copy_from(buf, hdr)
+
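+# Usage sketch (illustrative only, kept as a comment).  Table and column names
+# are made up and curs is an already open cursor.  With a real cursor the data
+# is sent directly; with curs = None the statements come back as a string:
+#
+#   rows = [{'id': 1, 'name': "O'Brien"}, {'id': 2, 'name': None}]
+#   magic_insert(curs, 'public.people', rows)                   # COPY via cursor
+#   sql = magic_insert(None, 'public.people', rows, use_insert = 1)
+#   # sql now holds "insert into public.people (...) values (...);" lines
+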
+def db_copy_from_dict(curs, tablename, dict_list, fields = None):
+ """Do a COPY FROM STDIN using list of dicts as source."""
+
+ if len(dict_list) == 0:
+ return
+
+ if fields == None:
+ fields = dict_list[0].keys()
+
+ buf = StringIO()
+ for dat in dict_list:
+ row = []
+ for k in fields:
+ row.append(quote_copy(dat[k]))
+ buf.write("\t".join(row))
+ buf.write("\n")
+
+ buf.seek(0)
+ hdr = "%s (%s)" % (tablename, ",".join(fields))
+
+ curs.copy_from(buf, hdr)
+
+def db_copy_from_list(curs, tablename, row_list, fields):
+ """Do a COPY FROM STDIN using list of lists as source."""
+
+ if len(row_list) == 0:
+ return
+
+ if fields == None or len(fields) == 0:
+ raise Exception('Need field list')
+
+ buf = StringIO()
+ for dat in row_list:
+ row = []
+ for i in range(len(fields)):
+ row.append(quote_copy(dat[i]))
+ buf.write("\t".join(row))
+ buf.write("\n")
+
+ buf.seek(0)
+ hdr = "%s (%s)" % (tablename, ",".join(fields))
+
+ curs.copy_from(buf, hdr)
+
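+# Usage sketch (illustrative only, kept as a comment); table and column names
+# are made up, curs/db are an open cursor and its connection:
+#
+#   rows = [{'id': 1, 'name': 'foo'}, {'id': 2, 'name': 'bar'}]
+#   db_copy_from_dict(curs, 'public.mytable', rows, ['id', 'name'])
+#   db.commit()
+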
+#
+# Full COPY of table from one db to another
+#
+
+class CopyPipe(object):
+ "Splits one big COPY to chunks."
+
+ def __init__(self, dstcurs, tablename, limit = 512*1024, cancel_func=None):
+ self.tablename = tablename
+ self.dstcurs = dstcurs
+ self.buf = StringIO()
+ self.limit = limit
+        self.cancel_func = cancel_func
+ self.total_rows = 0
+ self.total_bytes = 0
+
+ def write(self, data):
+ "New data from psycopg"
+
+ self.total_bytes += len(data)
+ self.total_rows += data.count("\n")
+
+ if self.buf.tell() >= self.limit:
+ pos = data.find('\n')
+ if pos >= 0:
+ # split at newline
+ p1 = data[:pos + 1]
+ p2 = data[pos + 1:]
+ self.buf.write(p1)
+ self.flush()
+
+ data = p2
+
+ self.buf.write(data)
+
+ def flush(self):
+ "Send data out."
+
+ if self.cancel_func:
+ self.cancel_func()
+
+ if self.buf.tell() > 0:
+ self.buf.seek(0)
+ self.dstcurs.copy_from(self.buf, self.tablename)
+ self.buf.seek(0)
+ self.buf.truncate()
+
+def full_copy(tablename, src_curs, dst_curs, column_list = []):
+ """COPY table from one db to another."""
+
+ if column_list:
+ hdr = "%s (%s)" % (tablename, ",".join(column_list))
+ else:
+ hdr = tablename
+ buf = CopyPipe(dst_curs, hdr)
+ src_curs.copy_to(buf, hdr)
+ buf.flush()
+
+ return (buf.total_bytes, buf.total_rows)
+
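+# Usage sketch (illustrative only, kept as a comment).  The connect strings
+# and table name are assumptions:
+#
+#   import psycopg
+#   src = psycopg.connect("dbname=provider")
+#   dst = psycopg.connect("dbname=subscriber")
+#   bytes, rows = full_copy("public.mytable", src.cursor(), dst.cursor())
+#   dst.commit()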
+
+#
+# SQL installer
+#
+
+class DBObject(object):
+ """Base class for installable DB objects."""
+ name = None
+ sql = None
+ sql_file = None
+ def __init__(self, name, sql = None, sql_file = None):
+ self.name = name
+ self.sql = sql
+ self.sql_file = sql_file
+ def get_sql(self):
+ if self.sql:
+ return self.sql
+ if self.sql_file:
+ if self.sql_file[0] == "/":
+ fn = self.sql_file
+ else:
+ contrib_list = [
+ "/opt/pgsql/share/contrib",
+            "/usr/share/postgresql/8.0/contrib",
+ "/usr/share/postgresql/8.1/contrib",
+ "/usr/share/postgresql/8.2/contrib",
+ ]
+ for dir in contrib_list:
+ fn = os.path.join(dir, self.sql_file)
+ if os.path.isfile(fn):
+ return open(fn, "r").read()
+ raise Exception('File not found: '+self.sql_file)
+ raise Exception('object not defined')
+ def create(self, curs):
+ curs.execute(self.get_sql())
+
+class DBSchema(DBObject):
+ """Handles db schema."""
+ def exists(self, curs):
+ return exists_schema(curs, self.name)
+
+class DBTable(DBObject):
+ """Handles db table."""
+ def exists(self, curs):
+ return exists_table(curs, self.name)
+
+class DBFunction(DBObject):
+ """Handles db function."""
+ def __init__(self, name, nargs, sql = None, sql_file = None):
+ DBObject.__init__(self, name, sql, sql_file)
+ self.nargs = nargs
+ def exists(self, curs):
+ return exists_function(curs, self.name, self.nargs)
+
+class DBLanguage(DBObject):
+ """Handles db language."""
+ def __init__(self, name):
+ DBObject.__init__(self, name, sql = "create language %s" % name)
+ def exists(self, curs):
+ return exists_language(curs, self.name)
+
+def db_install(curs, list, log = None):
+ """Installs list of objects into db."""
+ for obj in list:
+ if not obj.exists(curs):
+ if log:
+ log.info('Installing %s' % obj.name)
+ obj.create(curs)
+ else:
+ if log:
+ log.info('%s is installed' % obj.name)
+
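+# Usage sketch (illustrative only, kept as a comment); the object names and
+# SQL snippets are made up, curs is an open cursor and log a logger:
+#
+#   objs = [
+#       DBLanguage("plpgsql"),
+#       DBSchema("myschema", sql = "create schema myschema"),
+#       DBTable("myschema.mytable",
+#               sql = "create table myschema.mytable (id int4 primary key)"),
+#   ]
+#   db_install(curs, objs, log)
+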
diff --git a/python/walmgr.py b/python/walmgr.py
new file mode 100755
index 00000000..8f43fd6d
--- /dev/null
+++ b/python/walmgr.py
@@ -0,0 +1,648 @@
+#! /usr/bin/env python
+
+"""WALShipping manager.
+
+walmgr [-n] COMMAND
+
+Master commands:
+ setup Configure PostgreSQL for WAL archiving
+ backup Copies all master data to slave
+ sync Copies in-progress WALs to slave
+ syncdaemon Daemon mode for regular syncing
+ stop Stop archiving - de-configure PostgreSQL
+
+Slave commands:
+ restore Stop postmaster, move new data dir to right
+ location and start postmaster in playback mode.
+ boot Stop playback, accept queries.
+ pause Just wait, don't play WAL-s
+ continue Start playing WAL-s again
+
+Internal commands:
+ xarchive archive one WAL file (master)
+ xrestore restore one WAL file (slave)
+
+Switches:
+ -n no action, just print commands
+"""
+
+import os, sys, skytools, getopt, re, signal, time, traceback
+
+MASTER = 1
+SLAVE = 0
+
+def usage(err):
+ if err > 0:
+ print >>sys.stderr, __doc__
+ else:
+ print __doc__
+ sys.exit(err)
+
+class WalMgr(skytools.DBScript):
+ def __init__(self, wtype, cf_file, not_really, internal = 0, go_daemon = 0):
+ self.not_really = not_really
+ self.pg_backup = 0
+
+ if wtype == MASTER:
+ service_name = "wal-master"
+ else:
+ service_name = "wal-slave"
+
+ if not os.path.isfile(cf_file):
+ print "Config not found:", cf_file
+ sys.exit(1)
+
+ if go_daemon:
+ s_args = ["-d", cf_file]
+ else:
+ s_args = [cf_file]
+
+ skytools.DBScript.__init__(self, service_name, s_args,
+ force_logfile = internal)
+
+ def pg_start_backup(self, code):
+        q = "select pg_start_backup('%s')" % code
+ self.log.info("Execute SQL: %s; [%s]" % (q, self.cf.get("master_db")))
+ if self.not_really:
+ self.pg_backup = 1
+ return
+ db = self.get_database("master_db")
+ db.cursor().execute(q)
+ db.commit()
+ self.close_database("master_db")
+ self.pg_backup = 1
+
+ def pg_stop_backup(self):
+ if not self.pg_backup:
+ return
+
+ q = "select pg_stop_backup()"
+ self.log.debug("Execute SQL: %s; [%s]" % (q, self.cf.get("master_db")))
+ if self.not_really:
+ return
+ db = self.get_database("master_db")
+ db.cursor().execute(q)
+ db.commit()
+ self.close_database("master_db")
+
+ def signal_postmaster(self, data_dir, sgn):
+ pidfile = os.path.join(data_dir, "postmaster.pid")
+ if not os.path.isfile(pidfile):
+ self.log.info("postmaster is not running")
+ return
+ buf = open(pidfile, "r").readline()
+ pid = int(buf.strip())
+ self.log.debug("Signal %d to process %d" % (sgn, pid))
+ if not self.not_really:
+ os.kill(pid, sgn)
+
+ def exec_big_rsync(self, cmdline):
+ cmd = "' '".join(cmdline)
+ self.log.debug("Execute big rsync cmd: '%s'" % (cmd))
+ if self.not_really:
+ return
+ res = os.spawnvp(os.P_WAIT, cmdline[0], cmdline)
+ if res == 24:
+            self.log.info("Some files vanished, but that's OK")
+ elif res != 0:
+ self.log.fatal("exec failed, res=%d" % res)
+ self.pg_stop_backup()
+ sys.exit(1)
+
+ def exec_cmd(self, cmdline):
+ cmd = "' '".join(cmdline)
+ self.log.debug("Execute cmd: '%s'" % (cmd))
+ if self.not_really:
+ return
+ res = os.spawnvp(os.P_WAIT, cmdline[0], cmdline)
+ if res != 0:
+ self.log.fatal("exec failed, res=%d" % res)
+ sys.exit(1)
+
+ def chdir(self, loc):
+ self.log.debug("chdir: '%s'" % (loc))
+ if self.not_really:
+ return
+ try:
+ os.chdir(loc)
+ except os.error:
+            self.log.fatal("chdir failed")
+ self.pg_stop_backup()
+ sys.exit(1)
+
+ def get_last_complete(self):
+ """Get the name of last xarchived segment."""
+
+ data_dir = self.cf.get("master_data")
+ fn = os.path.join(data_dir, ".walshipping.last")
+ try:
+ last = open(fn, "r").read().strip()
+ return last
+ except:
+ self.log.info("Failed to read %s" % fn)
+ return None
+
+ def set_last_complete(self, last):
+ """Set the name of last xarchived segment."""
+
+ data_dir = self.cf.get("master_data")
+ fn = os.path.join(data_dir, ".walshipping.last")
+ fn_tmp = fn + ".new"
+ try:
+ f = open(fn_tmp, "w")
+ f.write(last)
+ f.close()
+ os.rename(fn_tmp, fn)
+ except:
+ self.log.fatal("Cannot write to %s" % fn)
+
+ def master_setup(self):
+ self.log.info("Configuring WAL archiving")
+
+ script = os.path.abspath(sys.argv[0])
+ cf_file = os.path.abspath(self.cf.filename)
+ cf_val = "%s %s %s" % (script, cf_file, "xarchive %p %f")
+
+ self.master_configure_archiving(cf_val)
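+        # Illustrative example (paths are assumptions): with this script
+        # installed as /usr/local/bin/walmgr.py and the config in
+        # /etc/walmgr/master.ini, postgresql.conf ends up with a line like
+        #
+        #   archive_command = '/usr/local/bin/walmgr.py /etc/walmgr/master.ini xarchive %p %f'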
+
+ def master_stop(self):
+ self.log.info("Disabling WAL archiving")
+
+ self.master_configure_archiving('')
+
+ def master_configure_archiving(self, cf_val):
+ cf_file = self.cf.get("master_config")
+ data_dir = self.cf.get("master_data")
+ r_active = re.compile("^[ ]*archive_command[ ]*=[ ]*'(.*)'.*$", re.M)
+ r_disabled = re.compile("^.*archive_command.*$", re.M)
+
+ cf_full = "archive_command = '%s'" % cf_val
+
+ if not os.path.isfile(cf_file):
+            self.log.fatal("Config file not found: %s" % cf_file)
+            sys.exit(1)
+ self.log.info("Using config file: %s", cf_file)
+
+ buf = open(cf_file, "r").read()
+ m = r_active.search(buf)
+ if m:
+ old_val = m.group(1)
+ if old_val == cf_val:
+ self.log.debug("postmaster already configured")
+ else:
+ self.log.debug("found active but different conf")
+ newbuf = "%s%s%s" % (buf[:m.start()], cf_full, buf[m.end():])
+ self.change_config(cf_file, newbuf)
+ else:
+ m = r_disabled.search(buf)
+ if m:
+ self.log.debug("found disabled value")
+ newbuf = "%s\n%s%s" % (buf[:m.end()], cf_full, buf[m.end():])
+ self.change_config(cf_file, newbuf)
+ else:
+ self.log.debug("found no value")
+ newbuf = "%s\n%s\n\n" % (buf, cf_full)
+ self.change_config(cf_file, newbuf)
+
+ self.log.info("Sending SIGHUP to postmaster")
+ self.signal_postmaster(data_dir, signal.SIGHUP)
+ self.log.info("Done")
+
+ def change_config(self, cf_file, buf):
+ cf_old = cf_file + ".old"
+ cf_new = cf_file + ".new"
+
+ if self.not_really:
+ cf_new = "/tmp/postgresql.conf.new"
+ open(cf_new, "w").write(buf)
+ self.log.info("Showing diff")
+ os.system("diff -u %s %s" % (cf_file, cf_new))
+ self.log.info("Done diff")
+ os.remove(cf_new)
+ return
+
+        # the polite method does not work, as there are usually not enough perms for it
+ if 0:
+ open(cf_new, "w").write(buf)
+ bak = open(cf_file, "r").read()
+ open(cf_old, "w").write(bak)
+ os.rename(cf_new, cf_file)
+ else:
+ open(cf_file, "w").write(buf)
+
+ def remote_mkdir(self, remdir):
+ tmp = remdir.split(":", 1)
+ if len(tmp) != 2:
+ raise Exception("cannot find hostname")
+ host, path = tmp
+ cmdline = ["ssh", host, "mkdir", "-p", path]
+ self.exec_cmd(cmdline)
+
+ def master_backup(self):
+ """Copy master data directory to slave."""
+
+ data_dir = self.cf.get("master_data")
+ dst_loc = self.cf.get("full_backup")
+ if dst_loc[-1] != "/":
+ dst_loc += "/"
+
+ self.pg_start_backup("FullBackup")
+
+ master_spc_dir = os.path.join(data_dir, "pg_tblspc")
+ slave_spc_dir = dst_loc + "tmpspc"
+
+ # copy data
+ self.chdir(data_dir)
+ cmdline = ["rsync", "-a", "--delete",
+ "--exclude", ".*",
+ "--exclude", "*.pid",
+ "--exclude", "*.opts",
+ "--exclude", "*.conf",
+ "--exclude", "*.conf.*",
+ "--exclude", "pg_xlog",
+ ".", dst_loc]
+ self.exec_big_rsync(cmdline)
+
+ # copy tblspc first, to test
+ if os.path.isdir(master_spc_dir):
+ self.log.info("Checking tablespaces")
+ list = os.listdir(master_spc_dir)
+ if len(list) > 0:
+ self.remote_mkdir(slave_spc_dir)
+ for tblspc in list:
+ if tblspc[0] == ".":
+ continue
+ tfn = os.path.join(master_spc_dir, tblspc)
+ if not os.path.islink(tfn):
+ self.log.info("Suspicious pg_tblspc entry: "+tblspc)
+ continue
+ spc_path = os.path.realpath(tfn)
+ self.log.info("Got tablespace %s: %s" % (tblspc, spc_path))
+ dstfn = slave_spc_dir + "/" + tblspc
+
+ try:
+ os.chdir(spc_path)
+ except Exception, det:
+ self.log.warning("Broken link:" + str(det))
+ continue
+ cmdline = ["rsync", "-a", "--delete",
+ "--exclude", ".*",
+ ".", dstfn]
+ self.exec_big_rsync(cmdline)
+
+ # copy pg_xlog
+ self.chdir(data_dir)
+ cmdline = ["rsync", "-a",
+ "--exclude", "*.done",
+ "--exclude", "*.backup",
+ "--delete", "pg_xlog", dst_loc]
+ self.exec_big_rsync(cmdline)
+
+ self.pg_stop_backup()
+
+ self.log.info("Full backup successful")
+
+ def master_xarchive(self, srcpath, srcname):
+ """Copy a complete WAL segment to slave."""
+
+ start_time = time.time()
+ self.log.debug("%s: start copy", srcname)
+
+ self.set_last_complete(srcname)
+
+ dst_loc = self.cf.get("completed_wals")
+ if dst_loc[-1] != "/":
+ dst_loc += "/"
+
+ # copy data
+ cmdline = ["rsync", "-t", srcpath, dst_loc]
+ self.exec_cmd(cmdline)
+
+ self.log.debug("%s: done", srcname)
+ end_time = time.time()
+ self.stat_add('count', 1)
+ self.stat_add('duration', end_time - start_time)
+
+ def master_sync(self):
+ """Copy partial WAL segments."""
+
+ data_dir = self.cf.get("master_data")
+ xlog_dir = os.path.join(data_dir, "pg_xlog")
+ dst_loc = self.cf.get("partial_wals")
+ if dst_loc[-1] != "/":
+ dst_loc += "/"
+
+ files = os.listdir(xlog_dir)
+ files.sort()
+
+ last = self.get_last_complete()
+ if last:
+ self.log.info("%s: last complete" % last)
+ else:
+ self.log.info("last complete not found, copying all")
+
+ for fn in files:
+ # check if interesting file
+ if len(fn) < 10:
+ continue
+ if fn[0] < "0" or fn[0] > '9':
+ continue
+ if fn.find(".") > 0:
+ continue
+            # check if too old
+ if last:
+ dot = last.find(".")
+ if dot > 0:
+ xlast = last[:dot]
+ if fn < xlast:
+ continue
+ else:
+ if fn <= last:
+ continue
+
+ # got interesting WAL
+ xlog = os.path.join(xlog_dir, fn)
+ # copy data
+ cmdline = ["rsync", "-t", xlog, dst_loc]
+ self.exec_cmd(cmdline)
+
+ self.log.info("Partial copy done")
+
+ def slave_xrestore(self, srcname, dstpath):
+ loop = 1
+ while loop:
+ try:
+ self.slave_xrestore_unsafe(srcname, dstpath)
+ loop = 0
+ except SystemExit, d:
+ sys.exit(1)
+ except Exception, d:
+ exc, msg, tb = sys.exc_info()
+ self.log.fatal("xrestore %s crashed: %s: '%s' (%s: %s)" % (
+ srcname, str(exc), str(msg).rstrip(),
+ str(tb), repr(traceback.format_tb(tb))))
+ time.sleep(10)
+ self.log.info("Re-exec: %s", repr(sys.argv))
+ os.execv(sys.argv[0], sys.argv)
+
+ def slave_xrestore_unsafe(self, srcname, dstpath):
+ srcdir = self.cf.get("completed_wals")
+ partdir = self.cf.get("partial_wals")
+ keep_old_logs = self.cf.getint("keep_old_logs", 0)
+ pausefile = os.path.join(srcdir, "PAUSE")
+ stopfile = os.path.join(srcdir, "STOP")
+ srcfile = os.path.join(srcdir, srcname)
+ partfile = os.path.join(partdir, srcname)
+
+ # loop until srcfile or stopfile appears
+ while 1:
+ if os.path.isfile(pausefile):
+ self.log.info("pause requested, sleeping")
+ time.sleep(20)
+ continue
+
+ if os.path.isfile(srcfile):
+ self.log.info("%s: Found" % srcname)
+ break
+
+ # ignore .history files
+ unused, ext = os.path.splitext(srcname)
+ if ext == ".history":
+ self.log.info("%s: not found, ignoring" % srcname)
+ sys.exit(1)
+
+ # if stopping, include also partial wals
+ if os.path.isfile(stopfile):
+ if os.path.isfile(partfile):
+ self.log.info("%s: found partial" % srcname)
+ srcfile = partfile
+ break
+ else:
+ self.log.info("%s: not found, stopping" % srcname)
+ sys.exit(1)
+
+ # nothing to do, sleep
+ self.log.debug("%s: not found, sleeping" % srcname)
+ time.sleep(20)
+
+ # got one, copy it
+ cmdline = ["cp", srcfile, dstpath]
+ self.exec_cmd(cmdline)
+
+ self.log.debug("%s: copy done, cleanup" % srcname)
+ self.slave_cleanup(srcname)
+
+ # it would be nice to have apply time too
+ self.stat_add('count', 1)
+
+ def slave_startup(self):
+ data_dir = self.cf.get("slave_data")
+ full_dir = self.cf.get("full_backup")
+ stop_cmd = self.cf.get("slave_stop_cmd", "")
+ start_cmd = self.cf.get("slave_start_cmd")
+ pidfile = os.path.join(data_dir, "postmaster.pid")
+
+ # stop postmaster if ordered
+ if stop_cmd and os.path.isfile(pidfile):
+ self.log.info("Stopping postmaster: " + stop_cmd)
+ if not self.not_really:
+ os.system(stop_cmd)
+ time.sleep(3)
+
+ # is it dead?
+ if os.path.isfile(pidfile):
+ self.log.fatal("Postmaster still running. Cannot continue.")
+ sys.exit(1)
+
+ # find name for data backup
+ i = 0
+ while 1:
+ bak = "%s.%d" % (data_dir, i)
+ if not os.path.isdir(bak):
+ break
+ i += 1
+
+ # move old data away
+ if os.path.isdir(data_dir):
+ self.log.info("Move %s to %s" % (data_dir, bak))
+ if not self.not_really:
+ os.rename(data_dir, bak)
+
+ # move new data
+ self.log.info("Move %s to %s" % (full_dir, data_dir))
+ if not self.not_really:
+ os.rename(full_dir, data_dir)
+ else:
+ data_dir = full_dir
+
+ # re-link tablespaces
+ spc_dir = os.path.join(data_dir, "pg_tblspc")
+ tmp_dir = os.path.join(data_dir, "tmpspc")
+ if os.path.isdir(spc_dir) and os.path.isdir(tmp_dir):
+ self.log.info("Linking tablespaces to temporary location")
+
+            # don't look into spc_dir, thus allowing the user to move
+            # tablespaces beforehand.  re-link only those that are still
+            # in tmp_dir
+ list = os.listdir(tmp_dir)
+ list.sort()
+
+ for d in list:
+ if d[0] == ".":
+ continue
+ link_loc = os.path.join(spc_dir, d)
+ link_dst = os.path.join(tmp_dir, d)
+ self.log.info("Linking tablespace %s to %s" % (d, link_dst))
+ if not self.not_really:
+ if os.path.islink(link_loc):
+ os.remove(link_loc)
+ os.symlink(link_dst, link_loc)
+
+ # write recovery.conf
+ rconf = os.path.join(data_dir, "recovery.conf")
+ script = os.path.abspath(sys.argv[0])
+ cf_file = os.path.abspath(self.cf.filename)
+ conf = "\nrestore_command = '%s %s %s'\n" % (
+ script, cf_file, 'xrestore %f "%p"')
+ self.log.info("Write %s" % rconf)
+ if self.not_really:
+ print conf
+ else:
+ f = open(rconf, "w")
+ f.write(conf)
+ f.close()
+
+ # remove stopfile
+ srcdir = self.cf.get("completed_wals")
+ stopfile = os.path.join(srcdir, "STOP")
+ if os.path.isfile(stopfile):
+ self.log.info("Removing stopfile: "+stopfile)
+ if not self.not_really:
+ os.remove(stopfile)
+
+ # run database in recovery mode
+ self.log.info("Starting postmaster: " + start_cmd)
+ if not self.not_really:
+ os.system(start_cmd)
+
+ def slave_boot(self):
+ srcdir = self.cf.get("completed_wals")
+ stopfile = os.path.join(srcdir, "STOP")
+ open(stopfile, "w").write("1")
+ self.log.info("Stopping recovery mode")
+
+ def slave_pause(self):
+ srcdir = self.cf.get("completed_wals")
+ pausefile = os.path.join(srcdir, "PAUSE")
+ open(pausefile, "w").write("1")
+ self.log.info("Pausing recovery mode")
+
+ def slave_continue(self):
+ srcdir = self.cf.get("completed_wals")
+ pausefile = os.path.join(srcdir, "PAUSE")
+ if os.path.isfile(pausefile):
+ os.remove(pausefile)
+ self.log.info("Continuing with recovery")
+ else:
+ self.log.info("Recovery not paused?")
+
+ def slave_cleanup(self, last_applied):
+ completed_wals = self.cf.get("completed_wals")
+ partial_wals = self.cf.get("partial_wals")
+
+ self.log.debug("cleaning completed wals since %s" % last_applied)
+ last = self.del_wals(completed_wals, last_applied)
+ if last:
+ if os.path.isdir(partial_wals):
+ self.log.debug("cleaning partial wals since %s" % last)
+ self.del_wals(partial_wals, last)
+ else:
+ self.log.warning("partial_wals dir does not exist: %s"
+ % partial_wals)
+ self.log.debug("cleaning done")
+
+ def del_wals(self, path, last):
+ dot = last.find(".")
+ if dot > 0:
+ last = last[:dot]
+ list = os.listdir(path)
+ list.sort()
+ cur_last = None
+ n = len(list)
+ for i in range(n):
+ fname = list[i]
+ full = os.path.join(path, fname)
+ if fname[0] < "0" or fname[0] > "9":
+ continue
+
+            if fname < last:
+                self.log.debug("deleting %s" % full)
+                if not self.not_really:
+                    os.remove(full)
+                cur_last = fname
+ return cur_last
+
+ def work(self):
+ self.master_sync()
+
+def main():
+ try:
+ opts, args = getopt.getopt(sys.argv[1:], "nh")
+ except getopt.error, det:
+ print det
+ usage(1)
+ not_really = 0
+ for o, v in opts:
+ if o == "-n":
+ not_really = 1
+ elif o == "-h":
+ usage(0)
+ if len(args) < 2:
+ usage(1)
+ ini = args[0]
+ cmd = args[1]
+
+ if cmd == "setup":
+ script = WalMgr(MASTER, ini, not_really)
+ script.master_setup()
+ elif cmd == "stop":
+ script = WalMgr(MASTER, ini, not_really)
+ script.master_stop()
+ elif cmd == "backup":
+ script = WalMgr(MASTER, ini, not_really)
+ script.master_backup()
+ elif cmd == "xarchive":
+ if len(args) != 4:
+ print >> sys.stderr, "usage: walmgr INI xarchive %p %f"
+ sys.exit(1)
+ script = WalMgr(MASTER, ini, not_really, 1)
+ script.master_xarchive(args[2], args[3])
+ elif cmd == "sync":
+ script = WalMgr(MASTER, ini, not_really)
+ script.master_sync()
+ elif cmd == "syncdaemon":
+ script = WalMgr(MASTER, ini, not_really, go_daemon=1)
+ script.start()
+ elif cmd == "xrestore":
+ if len(args) != 4:
+            print >> sys.stderr, "usage: walmgr INI xrestore %f %p"
+ sys.exit(1)
+ script = WalMgr(SLAVE, ini, not_really, 1)
+ script.slave_xrestore(args[2], args[3])
+ elif cmd == "restore":
+ script = WalMgr(SLAVE, ini, not_really)
+ script.slave_startup()
+ elif cmd == "boot":
+ script = WalMgr(SLAVE, ini, not_really)
+ script.slave_boot()
+ elif cmd == "pause":
+ script = WalMgr(SLAVE, ini, not_really)
+ script.slave_pause()
+ elif cmd == "continue":
+ script = WalMgr(SLAVE, ini, not_really)
+ script.slave_continue()
+ else:
+ usage(1)
+
+if __name__ == '__main__':
+ main()
+