diff options
| author | martinko | 2013-03-04 09:55:03 +0000 |
|---|---|---|
| committer | martinko | 2013-03-04 09:55:03 +0000 |
| commit | 5d960572e1d3dfcb59d0d8529c02b0e1ab9a8461 (patch) | |
| tree | 78c4292fe2d5a9f33c28d71dffe8903b682fdd87 /python | |
| parent | 5d2fbaf63bb0ead9ce3dcc95f4d4dd1173813008 (diff) | |
| parent | 07a2bd9d5f70cce178585aa3b8468ce3870b4b60 (diff) | |
Merge branch 'master' of https://github.com/markokr/skytools
Diffstat (limited to 'python')
| -rw-r--r-- | python/londiste/playback.py | 4 | ||||
| -rw-r--r-- | python/pgq/cascade/admin.py | 92 | ||||
| -rw-r--r-- | python/skytools/__init__.py | 2 | ||||
| -rw-r--r-- | python/skytools/parsing.py | 58 |
4 files changed, 141 insertions, 15 deletions
diff --git a/python/londiste/playback.py b/python/londiste/playback.py index 4f509dcf..4fa87014 100644 --- a/python/londiste/playback.py +++ b/python/londiste/playback.py @@ -277,6 +277,10 @@ class Replicator(CascadedWorker): # target database db = dbname=somedb host=127.0.0.1 + # public connect string for target node, which other nodes use + # to access this one. + #public_node_location = + # how many tables can be copied in parallel #parallel_copies = 1 diff --git a/python/pgq/cascade/admin.py b/python/pgq/cascade/admin.py index ed44dad7..dd9215d8 100644 --- a/python/pgq/cascade/admin.py +++ b/python/pgq/cascade/admin.py @@ -29,9 +29,9 @@ command_usage = """\ %prog [options] INI CMD [subcmd args] Node Initialization: - create-root NAME PUBLIC_CONNSTR - create-branch NAME PUBLIC_CONNSTR --provider=<public_connstr> - create-leaf NAME PUBLIC_CONNSTR --provider=<public_connstr> + create-root NAME [PUBLIC_CONNSTR] + create-branch NAME [PUBLIC_CONNSTR] --provider=<public_connstr> + create-leaf NAME [PUBLIC_CONNSTR] --provider=<public_connstr> Initializes node. Node Administration: @@ -141,24 +141,43 @@ class CascadeAdmin(skytools.AdminScript): db = self.get_database("db") self.install_code(db) - def cmd_create_root(self, node_name, node_location): - return self.create_node('root', node_name, node_location) + def cmd_create_root(self, node_name, *args): + return self.create_node('root', node_name, args) - def cmd_create_branch(self, node_name, node_location): - return self.create_node('branch', node_name, node_location) + def cmd_create_branch(self, node_name, *args): + return self.create_node('branch', node_name, args) - def cmd_create_leaf(self, node_name, node_location): - return self.create_node('leaf', node_name, node_location) + def cmd_create_leaf(self, node_name, *args): + return self.create_node('leaf', node_name, args) - def create_node(self, node_type, node_name, node_location): + def create_node(self, node_type, node_name, args): """Generic node init.""" provider_loc = self.options.provider if node_type not in ('root', 'branch', 'leaf'): raise Exception('unknown node type') + # load public location + if len(args) > 1: + raise UsageError('Too many args, only public connect string allowed') + elif len(args) == 1: + node_location = args[0] + else: + node_location = self.cf.get('public_node_location', '') + if not node_location: + raise UsageError('Node public location must be given either in command line or config') + + # check if sane + ok = 0 + for k, v in skytools.parse_connect_string(node_location): + if k in ('host', 'service'): + ok = 1 + break + if not ok: + raise UsageError('No host= in public connect string, bad idea') + # connect to database - db = self.get_database("new_node", connstr = node_location) + db = self.get_database("db") # check if code is installed self.install_code(db) @@ -170,6 +189,9 @@ class CascadeAdmin(skytools.AdminScript): self.log.info("Node is already initialized as %s", info['node_type']) return + # check if public connstr is sane + self.check_public_connstr(db, node_location) + self.log.info("Initializing node") node_attrs = {} @@ -257,6 +279,43 @@ class CascadeAdmin(skytools.AdminScript): self.log.info("Done") + def check_public_connstr(self, db, pub_connstr): + """Look if public and local connect strings point to same db's. + """ + pub_db = self.get_database("pub_db", connstr = pub_connstr) + curs1 = db.cursor() + curs2 = pub_db.cursor() + q = "select oid, datname, txid_current() as txid, txid_current_snapshot() as snap"\ + " from pg_catalog.pg_database where datname = current_database()" + curs1.execute(q) + res1 = curs1.fetchone() + db.commit() + + curs2.execute(q) + res2 = curs2.fetchone() + pub_db.commit() + + curs1.execute(q) + res3 = curs1.fetchone() + db.commit() + + self.close_database("pub_db") + + failure = 0 + if (res1['oid'], res1['datname']) != (res2['oid'], res2['datname']): + failure += 1 + + sn1 = skytools.Snapshot(res1['snap']) + tx = res2['txid'] + sn2 = skytools.Snapshot(res3['snap']) + if sn1.contains(tx): + failure += 2 + elif not sn2.contains(tx): + failure += 4 + + if failure: + raise UsageError("Public connect string points to different database than local connect string (fail=%d)" % failure) + def extra_init(self, node_type, node_db, provider_db): """Callback to do specific init.""" pass @@ -363,12 +422,17 @@ class CascadeAdmin(skytools.AdminScript): nodes = Queue.Queue() # launch workers and wait - n = max (min (members.qsize() >> 2, 100), 1) - for i in range(n): + num_nodes = len(self.queue_info.member_map) + num_threads = max (min (num_nodes / 4, 100), 1) + tlist = [] + for i in range(num_threads): t = threading.Thread (target = self._cmd_status_worker, args = (members, nodes)) t.daemon = True t.start() - members.join() + tlist.append(t) + #members.join() + for t in tlist: + t.join() while True: try: diff --git a/python/skytools/__init__.py b/python/skytools/__init__.py index 8f2c52a3..048d41fb 100644 --- a/python/skytools/__init__.py +++ b/python/skytools/__init__.py @@ -47,7 +47,9 @@ _symbols = { # skytools.parsing 'dedent': 'skytools.parsing:dedent', 'hsize_to_bytes': 'skytools.parsing:hsize_to_bytes', + 'merge_connect_string': 'skytools.parsing:merge_connect_string', 'parse_acl': 'skytools.parsing:parse_acl', + 'parse_connect_string': 'skytools.parsing:parse_connect_string', 'parse_logtriga_sql': 'skytools.parsing:parse_logtriga_sql', 'parse_pgarray': 'skytools.parsing:parse_pgarray', 'parse_sqltriga_sql': 'skytools.parsing:parse_sqltriga_sql', diff --git a/python/skytools/parsing.py b/python/skytools/parsing.py index decc7e7e..318b1bf9 100644 --- a/python/skytools/parsing.py +++ b/python/skytools/parsing.py @@ -7,7 +7,8 @@ import skytools __all__ = [ "parse_pgarray", "parse_logtriga_sql", "parse_tabbed_table", "parse_statements", 'sql_tokenizer', 'parse_sqltriga_sql', - "parse_acl", "dedent", "hsize_to_bytes"] + "parse_acl", "dedent", "hsize_to_bytes", + "parse_connect_string", "merge_connect_string"] _rc_listelem = re.compile(r'( [^,"}]+ | ["] ( [^"\\]+ | [\\]. )* ["] )', re.X) @@ -445,6 +446,61 @@ def hsize_to_bytes (input): bytes = int(m.group(1)) * 1024 ** units.index(m.group(2).upper()) return bytes +# +# Connect string parsing +# + +_cstr_rx = r""" \s* (\w+) \s* = \s* ( ' ( \\.| [^'\\] )* ' | \S+ ) \s* """ +_cstr_unesc_rx = r"\\(.)" +_cstr_badval_rx = r"[\s'\\]" +_cstr_rc = None +_cstr_unesc_rc = None +_cstr_badval_rc = None + +def parse_connect_string(cstr): + r"""Parse Postgres connect string. + + >>> parse_connect_string("host=foo") + [('host', 'foo')] + >>> parse_connect_string(r" host = foo password = ' f\\\o\'o ' ") + [('host', 'foo'), ('password', "' f\\o'o '")] + """ + global _cstr_rc, _cstr_unesc_rc + if not _cstr_rc: + _cstr_rc = re.compile(_cstr_rx, re.X) + _cstr_unesc_rc = re.compile(_cstr_unesc_rx) + pos = 0 + res = [] + while pos < len(cstr): + m = _cstr_rc.match(cstr, pos) + if not m: + raise ValueError('Invalid connect string') + pos = m.end() + k = m.group(1) + v = m.group(2) + if v[0] == "'": + v = _cstr_unesc_rc.sub(r"\1", v) + res.append( (k,v) ) + return res + +def merge_connect_string(cstr_arg_list): + """Put fragments back together. + + >>> merge_connect_string([('host', 'ip'), ('pass', ''), ('x', ' ')]) + "host=ip pass='' x=' '" + """ + global _cstr_badval_rc + if not _cstr_badval_rc: + _cstr_badval_rc = re.compile(_cstr_badval_rx) + + buf = [] + for k, v in cstr_arg_list: + if not v or _cstr_badval_rc.search(v): + v = v.replace('\\', r'\\') + v = v.replace("'", r"\'") + v = "'" + v + "'" + buf.append("%s=%s" % (k, v)) + return ' '.join(buf) if __name__ == '__main__': import doctest |
