summaryrefslogtreecommitdiff
path: root/python/skytools/sqltools.py
diff options
context:
space:
mode:
authorMarko Kreen2008-02-28 10:15:30 +0000
committerMarko Kreen2008-02-28 10:15:30 +0000
commit838e65a58084a69104ea58a6be6e26b0c982357e (patch)
tree9faf374988f2b1d33be9b3dc09ed134681c367c9 /python/skytools/sqltools.py
parente1fab145063ec48e0c74d6ff95afd4db3ac5c78d (diff)
installer logging changes from -stable
Diffstat (limited to 'python/skytools/sqltools.py')
-rw-r--r--python/skytools/sqltools.py56
1 files changed, 45 insertions, 11 deletions
diff --git a/python/skytools/sqltools.py b/python/skytools/sqltools.py
index b41b09e0..aef5966b 100644
--- a/python/skytools/sqltools.py
+++ b/python/skytools/sqltools.py
@@ -11,7 +11,7 @@ __all__ = [
"get_table_columns", "exists_schema", "exists_table", "exists_type",
"exists_function", "exists_language", "Snapshot", "magic_insert",
"CopyPipe", "full_copy", "DBObject", "DBSchema", "DBTable", "DBFunction",
- "DBLanguage", "db_install",
+ "DBLanguage", "db_install", "installer_find_file", "installer_apply_file",
]
@@ -101,6 +101,12 @@ def exists_function(curs, function_name, nargs):
and n.nspname = %s and p.proname = %s"""
curs.execute(q, [nargs, schema, name])
res = curs.fetchone()
+
+ # if unqualified function, check builtin functions too
+ if not res[0] and function_name.find('.') < 0:
+ name = "pg_catalog." + function_name
+ return exists_function(curs, name, nargs)
+
return res[0]
def exists_language(curs, lang_name):
@@ -151,14 +157,14 @@ class Snapshot(object):
def _gen_dict_copy(tbl, row, fields):
tmp = []
for f in fields:
- v = row[f]
+ v = row.get(f)
tmp.append(quote_copy(v))
return "\t".join(tmp)
def _gen_dict_insert(tbl, row, fields):
tmp = []
for f in fields:
- v = row[f]
+ v = row.get(f)
tmp.append(quote_literal(v))
fmt = "insert into %s (%s) values (%s);"
return fmt % (tbl, ",".join(fields), ",".join(tmp))
@@ -294,8 +300,6 @@ def full_copy(tablename, src_curs, dst_curs, column_list = []):
# SQL installer
#
-def _nologger(msg): pass
-
class DBObject(object):
"""Base class for installable DB objects."""
name = None
@@ -306,13 +310,15 @@ class DBObject(object):
self.sql = sql
self.sql_file = sql_file
- def create(self, curs, logger = _nologger):
- logger('Installing %s' % self.name)
+ def create(self, curs, log = None):
+ if log:
+ log.info('Installing %s' % self.name)
if self.sql:
sql = self.sql
elif self.sql_file:
fn = self.find_file()
- logger(" Reading from %s" % fn)
+ if log:
+ log.info(" Reading from %s" % fn)
sql = open(fn, "r").read()
else:
raise Exception('object not defined')
@@ -359,11 +365,39 @@ class DBLanguage(DBObject):
def exists(self, curs):
return exists_language(curs, self.name)
-def db_install(curs, list, logger = _nologger):
+def db_install(curs, list, log = None):
"""Installs list of objects into db."""
for obj in list:
if not obj.exists(curs):
- obj.create(curs, logger)
+ obj.create(curs, log)
else:
- logger('%s is installed' % obj.name)
+ if log:
+ log.info('%s is installed' % obj.name)
+
+def installer_find_file(filename):
+ full_fn = None
+ if filename[0] == "/":
+ if os.path.isfile(filename):
+ full_fn = filename
+ else:
+ dir_list = ["."] + skytools.installer_config.sql_locations
+ for dir in dir_list:
+ fn = os.path.join(dir, filename)
+ if os.path.isfile(fn):
+ full_fn = fn
+ break
+
+ if not full_fn:
+ raise Exception('File not found: '+filename)
+ return full_fn
+
+def installer_apply_file(db, filename, log):
+ fn = installer_find_file(filename)
+ sql = open(fn, "r").read()
+ if log:
+ log.info("applying %s" % fn)
+ curs = db.cursor()
+ for stmt in skytools.parse_statements(sql):
+ log.debug(repr(stmt))
+ curs.execute(stmt)