diff options
Diffstat (limited to 'python/skytools/psycopgwrapper.py')
-rw-r--r-- | python/skytools/psycopgwrapper.py | 21 |
1 files changed, 19 insertions, 2 deletions
diff --git a/python/skytools/psycopgwrapper.py b/python/skytools/psycopgwrapper.py index aeaa7a74..6aa5835b 100644 --- a/python/skytools/psycopgwrapper.py +++ b/python/skytools/psycopgwrapper.py @@ -62,6 +62,7 @@ __all__ = ['connect_database'] # to the point of avoiding optimized access. # only backwards compat thing we need is dict* methods +import socket import psycopg2.extensions, psycopg2.extras from skytools.sqltools import dbdict @@ -102,10 +103,15 @@ class _CompatConnection(psycopg2.extensions.connection): def cursor(self): return psycopg2.extensions.connection.cursor(self, cursor_factory = _CompatCursor) -def connect_database(connstr): - """Create a db connection with connect_timeout option. +def connect_database(connstr, keepalive = True, + tcp_keepidle = 4 * 60, # 7200 + tcp_keepcnt = 4, # 9 + tcp_keepintvl = 15): # 75 + """Create a db connection with connect_timeout and TCP keepalive. Default connect_timeout is 15, to change put it directly into dsn. + + The extra tcp_* options are Linux-specific, see `man 7 tcp` for details. """ # allow override @@ -115,6 +121,16 @@ def connect_database(connstr): # create connection db = _CompatConnection(connstr) + # turn on keepalive on the connection + if keepalive and hasattr(socket, 'SO_KEEPALIVE'): + fd = db.cursor().fileno() + s = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if hasattr(socket, 'TCP_KEEPCNT'): + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, tcp_keepidle) + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, tcp_keepcnt) + s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, tcp_keepintvl) + # fill .server_version on older psycopg if not hasattr(db, 'server_version'): iso = db.isolation_level @@ -123,5 +139,6 @@ def connect_database(connstr): curs.execute('show server_version_num') db.server_version = int(curs.fetchone()[0]) db.set_isolation_level(iso) + return db |