#! /usr/bin/env python

"""Upgrade script for versioned schemas."""

usage = """
    %prog [--user=U] [--host=H] [--port=P] --all
    %prog [--user=U] [--host=H] [--port=P] DB1 [ DB2 ... ]\
"""

import sys, os, re, optparse

import pkgloader
pkgloader.require('skytools', '3.0')
import skytools

from skytools.natsort import natsort_key


# schemas, where .upgrade.sql is enough
AUTO_UPGRADE = ('pgq', 'pgq_node', 'pgq_coop', 'londiste', 'pgq_ext')

# fetch list of databases
DB_LIST = "select datname from pg_database "\
          " where not datistemplate and datallowconn "\
          " order by 1"

# don't support upgrade from 2.x (yet?)
#
# Entries with filename None are minimum-version gates: a schema older than
# the listed version is reported as "too old" and skipped by upgrade().
# load_cur_versions() appends one real upgrade target per AUTO_UPGRADE schema.
version_list = [
    # schema, ver, filename, recheck_func
    ['pgq', '3.0', None, None],
    ['londiste', '3.0', None, None],
    ['pgq_ext', '2.1', None, None],
]


def is_version_ge(a, b):
    """Return True if a is greater or equal than b."""
    va = natsort_key(a)
    vb = natsort_key(b)
    return va >= vb


def is_version_gt(a, b):
    """Return True if a is greater than b."""
    va = natsort_key(a)
    vb = natsort_key(b)
    return va > vb


def _quote_connstr_value(value):
    """Quote a single value for a libpq keyword/value connect string.

    libpq accepts single-quoted values in which backslash and single quote
    must be escaped with a backslash. Without this, a value containing a
    quote would terminate early and corrupt the connect string.
    """
    value = value.replace('\\', '\\\\').replace("'", "\\'")
    return "'%s'" % value


def check_version(curs, schema, new_ver_str, recheck_func=None, force_gt=False):
    """Check whether the installed version of `schema` satisfies new_ver_str.

    Calls `<schema>.version()` and compares its result with new_ver_str
    using natural-sort ordering.

    :param curs: DB cursor used to query the schema.
    :param schema: schema name whose version() function is checked.
    :param new_ver_str: version string to compare against.
    :param recheck_func: optional callable(curs) used when the schema has
        no version() function; its result becomes the 'ok' value.
    :param force_gt: when true, require a strictly greater installed version.
    :returns: (ok, old_version_str); old_version_str is 'NULL' when the
        schema has no version() function.
    """
    funcname = "%s.version" % schema
    if not skytools.exists_function(curs, funcname, 0):
        if recheck_func is not None:
            return recheck_func(curs), 'NULL'
        else:
            return 0, 'NULL'
    q = "select %s()" % funcname
    curs.execute(q)
    old_ver_str = curs.fetchone()[0]
    if force_gt:
        ok = is_version_gt(old_ver_str, new_ver_str)
    else:
        ok = is_version_ge(old_ver_str, new_ver_str)
    return ok, old_ver_str


class DbUpgrade(skytools.DBScript):
    """Upgrade all Skytools schemas in Postgres cluster."""

    def upgrade(self, dbname, db):
        """Upgrade all schemas in single db.

        Walks version_list in order. A schema that is missing, or that fails
        a minimum-version gate, is added to `ignore` so later entries for
        that schema are skipped.
        """
        curs = db.cursor()
        ignore = set()
        for schema, ver, fn, recheck_func in version_list:
            # skip schema?
            if schema in ignore:
                continue
            if not skytools.exists_schema(curs, schema):
                ignore.add(schema)
                continue

            # new enough?  --force only makes sense for real upgrade targets.
            # Applying it to minimum-version gates (fn is None) would report
            # a schema sitting exactly at the gate version as too old.
            force_gt = self.options.force and fn is not None
            ok, oldver = check_version(curs, schema, ver, recheck_func, force_gt)
            if ok:
                continue

            # too old schema, no way to upgrade
            if fn is None:
                self.log.info('%s: Cannot upgrade %s, too old version', dbname, schema)
                ignore.add(schema)
                continue

            if self.options.not_really:
                self.log.info("%s: Would upgrade '%s' version %s to %s", dbname, schema, oldver, ver)
                continue

            # connection is in autocommit mode, so wrap the upgrade file in
            # an explicit transaction
            curs = db.cursor()
            curs.execute('begin')
            self.log.info("%s: Upgrading '%s' version %s to %s", dbname, schema, oldver, ver)
            skytools.installer_apply_file(db, fn, self.log)
            curs.execute('commit')

    def work(self):
        """Loop over databases."""
        self.set_single_loop(1)

        self.load_cur_versions()

        # loop over all dbs
        dblst = self.args
        if self.options.all:
            db = self.connect_db('postgres')
            try:
                curs = db.cursor()
                curs.execute(DB_LIST)
                dblst = [row[0] for row in curs.fetchall()]
            finally:
                self.close_database('db')
        elif not dblst:
            raise skytools.UsageError('Give --all or list of database names on command line')

        # loop over connstrs
        for dbname in dblst:
            if self.last_sigint:
                break
            self.log.info("%s: connecting", dbname)
            db = self.connect_db(dbname)
            try:
                self.upgrade(dbname, db)
            finally:
                # close even on failure so an open transaction is not left behind
                self.close_database('db')

    def load_cur_versions(self):
        """Load current version numbers from .upgrade.sql files.

        For each schema in AUTO_UPGRADE, locates `<schema>.upgrade.sql`,
        extracts the version from its `return '<ver>';` line and appends
        [schema, ver, filename, None] to the module-level version_list.

        :raises skytools.UsageError: if a file cannot be opened or no
            version line is found.
        """
        vrc = re.compile(r"^ \s+ return \s+ '([0-9.]+)';", re.X | re.I | re.M)
        for s in AUTO_UPGRADE:
            fn = '%s.upgrade.sql' % s
            fqfn = skytools.installer_find_file(fn)
            try:
                f = open(fqfn, 'r')
            except IOError as d:
                raise skytools.UsageError('%s: cannot find upgrade file: %s [%s]' % (s, fqfn, str(d)))
            # `with` guarantees the file is closed even if read() fails
            with f:
                sql = f.read()
            m = vrc.search(sql)
            if not m:
                raise skytools.UsageError('%s: failed to detect version' % fqfn)

            ver = m.group(1)
            cur = [s, ver, fn, None]
            self.log.info("Loaded %s %s from %s", s, ver, fqfn)
            version_list.append(cur)

    def connect_db(self, dbname):
        """Create connect string, then connect.

        Every value is quoted for libpq, so database names containing quotes
        or backslashes (which --all may return from pg_database) still
        produce a valid connect string.
        """
        elems = ["dbname=%s" % _quote_connstr_value(dbname)]
        if self.options.host:
            elems.append("host=%s" % _quote_connstr_value(self.options.host))
        if self.options.port:
            elems.append("port=%s" % _quote_connstr_value(self.options.port))
        if self.options.user:
            elems.append("user=%s" % _quote_connstr_value(self.options.user))
        cstr = ' '.join(elems)
        return self.get_database('db', connstr = cstr, autocommit = 1)

    def init_optparse(self, parser=None):
        """Setup command-line flags."""
        p = skytools.DBScript.init_optparse(self, parser)
        p.set_usage(usage)
        g = optparse.OptionGroup(p, "options for skytools_upgrade")
        g.add_option("--all", action="store_true", help = 'upgrade all databases')
        g.add_option("--not-really", action = "store_true", dest = "not_really",
                     default = False, help = "don't actually do anything")
        g.add_option("--user", help = 'username to use')
        g.add_option("--host", help = 'hostname to use')
        g.add_option("--port", help = 'port to use')
        g.add_option("--force", action = "store_true",
                     help = 'upgrade even if schema versions are new enough')
        p.add_option_group(g)
        return p

    def load_config(self):
        """Disable config file."""
        return skytools.Config(self.service_name, None,
                user_defs = {'use_skylog': '0', 'job_name': 'db_upgrade'})


if __name__ == '__main__':
    script = DbUpgrade('skytools_upgrade', sys.argv[1:])
    script.start()