DataMaintainer can use csv file as source.
authorPetr Jelinek <git@pjmodos.net>
Sun, 1 Dec 2013 23:55:15 +0000 (00:55 +0100)
committerPetr Jelinek <git@pjmodos.net>
Sun, 1 Dec 2013 23:55:15 +0000 (00:55 +0100)
scripts/data_maintainer.py

index c1ecd42a096167ee4afd383924d9dbbacb6c158e..b78c4ea70cf2f9adbc9376287eb57a7d440a4749 100755 (executable)
@@ -10,18 +10,24 @@ Config template::
     [data_maintainer3]
     job_name        = dm_remove_expired_services
 
+    # if source is database, you need to specify dbread and sql_get_pk_list
     dbread          = dbname=sourcedb_test
+    sql_get_pk_list =
+        select username
+        from user_service
+        where expire_date < now();
+
+    # if source is csv file you need to specify fileread and optionally csv_delimiter and csv_quotechar
+    #fileread       = data.csv
+    #csv_delimiter  = ,
+    #csv_quotechar  = "
+
     dbwrite         = dbname=destdb port=1234 host=dbhost.com user=guest password=secret
     dbbefore        = dbname=destdb_test
     dbafter         = dbname=destdb_test
     dbcrash         = dbname=destdb_test
     dbthrottle      = dbname=queuedb_test
 
-    sql_get_pk_list =
-        select username
-        from user_service
-        where expire_date < now();
-
     # It is a good practice to include same where condition on target side as on read side,
     # to ensure that you are actually changing the same data you think you are,
     # especially when reading from replica database or when processing takes days.
@@ -49,7 +55,7 @@ Config template::
     #sql_throttle =
     #    select lag>'5 minutes'::interval from pgq.get_consumer_info('failoverconsumer');
 
-    # materialize query so that transaction should not be open while processing it
+    # materialize query so that transaction should not be open while processing it (only used when source is a database)
     #with_hold       = 1
 
     # how many records process to fetch at once and if batch processing is used then
@@ -75,12 +81,75 @@ Config template::
 import datetime
 import sys
 import time
+import csv
+import os.path
 
 import pkgloader
 pkgloader.require('skytools', '3.0')
 import skytools
 
 
+class DataSource (object):
+    def open(self):
+        raise NotImplementedError()
+
+    def close(self):
+        raise NotImplementedError()
+
+    def fetch(self, fetchcnt=0):
+        raise NotImplementedError()
+
+class DBDataSource (object):
+    def __init__(self, log, db, query, bres = None, with_hold = False):
+        self.log = log
+        self.db = db
+        self.query = "DECLARE data_maint_cur NO SCROLL CURSOR WITH HOLD FOR %s"\
+            if with_hold else "DECLARE data_maint_cur NO SCROLL CURSOR FOR %s"
+        self.query = self.query % query
+        self.bres = bres
+
+    def _run_query(self, query, params = None):
+        self.cur.execute(query, params)
+        self.log.debug(self.cur.query)
+        self.log.debug(self.cur.statusmessage)
+
+    def open(self):
+        self.cur = self.db.cursor()
+        self._run_query(self.query, self.bres)
+
+    def close(self):
+        self.cur.execute("CLOSE data_maint_cur")
+        if not self.withhold:
+            self.db.rollback()
+
+    def fetch(self, fetchcnt=0):
+        self._run_query("FETCH FORWARD %s FROM data_maint_cur" % fetchcnt)
+        return self.cur.fetchall()
+
+class CSVDataSource (object):
+    def __init__(self, log, filename, delimiter, quotechar):
+        self.log = log
+        self.filename = filename
+        self.delimiter = delimiter
+        self.quotechar = quotechar
+
+    def open(self):
+        self.fp = open(self.filename, 'rb')
+        self.reader = csv.DictReader(self.fp, delimiter = self.delimiter, quotechar = self.quotechar)
+
+    def close(self):
+        self.fp.close()
+
+    def fetch(self, fetchcnt=1):
+        ret = []
+        for row in self.reader:
+            ret.append(row)
+            fetchcnt = fetchcnt - 1
+            if fetchcnt <= 0:
+                break
+        return ret
+
+
 class DataMaintainer (skytools.DBScript):
     __doc__ = __doc__
     loop_delay = -1
@@ -88,8 +157,21 @@ class DataMaintainer (skytools.DBScript):
     def __init__(self, args):
         super(DataMaintainer, self).__init__("data_maintainer3", args)
 
+        # source file
+        self.fileread = self.cf.get("fileread", "")
+        if self.fileread:
+            self.fileread = os.path.expanduser(self.fileread)
+            # force single run if source is file
+            self.loop_delay = -1
+
+        self.csv_delimiter = self.cf.get("csv_delimiter", ",")
+        self.csv_quotechar = self.cf.get("csv_quotechar", '"')
+
         # query for fetching the PK-s of the data set to be maintained
-        self.sql_pk = self.cf.get("sql_get_pk_list")
+        self.sql_pk = self.cf.get("sql_get_pk_list", "")
+
+        if not self.sql_pk and not self.fileread:
+            raise ValueError("Either fileread or sql_get_pk_list must be specified in the configuration file")
 
         # query for changing data tuple ( autocommit )
         self.sql_modify = self.cf.get("sql_modify")
@@ -148,24 +230,22 @@ class DataMaintainer (skytools.DBScript):
         else:
             self.log.info("Commit in %i record batches", self.fetchcnt)
             dbw = self.get_database("dbwrite", autocommit=0)
-        if self.withhold:
-            dbr = self.get_database("dbread", autocommit=1)
-            sql = "DECLARE data_maint_cur NO SCROLL CURSOR WITH HOLD FOR %s"
+
+        if self.fileread:
+            self.datasource = CSVDataSource(self.log, self.fileread, self.csv_delimiter, self.csv_quotechar)
         else:
-            dbr = self.get_database("dbread", autocommit=0)
-            sql = "DECLARE data_maint_cur NO SCROLL CURSOR FOR %s"
-        rcur = dbr.cursor()
+            if self.withhold:
+                dbr = self.get_database("dbread", autocommit=1)
+            else:
+                dbr = self.get_database("dbread", autocommit=0)
+            self.datasource = DBDataSource(self.log, dbr, self.sql_pk, bres, self.withhold)
+
+        self.datasource.open()
         mcur = dbw.cursor()
-        rcur.execute(sql % self.sql_pk, bres) # pass results from before_query into sql_pk
-        self.log.debug(rcur.query)
-        self.log.debug(rcur.statusmessage)
 
         while True: # loop while fetch returns fetch_count rows
             self.fetch_started = time.time()
-            rcur.execute("FETCH FORWARD %s FROM data_maint_cur" % self.fetchcnt)
-            self.log.debug(rcur.query)
-            self.log.debug(rcur.statusmessage)
-            res = rcur.fetchall()
+            res = self.datasource.fetch(self.fetchcnt)
             count, lastitem = self.process_batch(res, mcur, bres)
             self.total_count += count
             if not self.autocommit:
@@ -183,9 +263,7 @@ class DataMaintainer (skytools.DBScript):
         if not self.looping:
             self.log.info("Exiting on user request")
 
-        rcur.execute("CLOSE data_maint_cur")
-        if not self.withhold:
-            dbr.rollback()
+        self.datasource.close()
         self.log.info("--- Total count: %s duration: %s ---",
                 self.total_count, datetime.timedelta(0, round(time.time() - self.started)))