summaryrefslogtreecommitdiff
path: root/python/skytools/psycopgwrapper.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/skytools/psycopgwrapper.py')
-rw-r--r--python/skytools/psycopgwrapper.py21
1 files changed, 19 insertions, 2 deletions
diff --git a/python/skytools/psycopgwrapper.py b/python/skytools/psycopgwrapper.py
index aeaa7a74..6aa5835b 100644
--- a/python/skytools/psycopgwrapper.py
+++ b/python/skytools/psycopgwrapper.py
@@ -62,6 +62,7 @@ __all__ = ['connect_database']
# to the point of avoiding optimized access.
# only backwards compat thing we need is dict* methods
+import socket
import psycopg2.extensions, psycopg2.extras
from skytools.sqltools import dbdict
@@ -102,10 +103,15 @@ class _CompatConnection(psycopg2.extensions.connection):
def cursor(self):
return psycopg2.extensions.connection.cursor(self, cursor_factory = _CompatCursor)
-def connect_database(connstr):
- """Create a db connection with connect_timeout option.
+def connect_database(connstr, keepalive = True,
+ tcp_keepidle = 4 * 60, # 7200
+ tcp_keepcnt = 4, # 9
+ tcp_keepintvl = 15): # 75
+ """Create a db connection with connect_timeout and TCP keepalive.
Default connect_timeout is 15, to change put it directly into dsn.
+
+ The extra tcp_* options are Linux-specific, see `man 7 tcp` for details.
"""
# allow override
@@ -115,6 +121,16 @@ def connect_database(connstr):
# create connection
db = _CompatConnection(connstr)
+ # turn on keepalive on the connection
+ if keepalive and hasattr(socket, 'SO_KEEPALIVE'):
+ fd = db.cursor().fileno()
+ s = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ if hasattr(socket, 'TCP_KEEPCNT'):
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, tcp_keepidle)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, tcp_keepcnt)
+ s.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, tcp_keepintvl)
+
# fill .server_version on older psycopg
if not hasattr(db, 'server_version'):
iso = db.isolation_level
@@ -123,5 +139,6 @@ def connect_database(connstr):
curs.execute('show server_version_num')
db.server_version = int(curs.fetchone()[0])
db.set_isolation_level(iso)
+
return db