check socket creation errors against PGINVALID_SOCKET
authorBruce Momjian <bruce@momjian.us>
Wed, 16 Apr 2014 14:45:48 +0000 (10:45 -0400)
committerBruce Momjian <bruce@momjian.us>
Wed, 16 Apr 2014 14:45:48 +0000 (10:45 -0400)
Previously, in some places, socket creation errors were checked for
negative values, which is not true for Windows because sockets are
unsigned.  This masked socket creation errors on Windows.

Backpatch through 9.0.  8.4 doesn't have the infrastructure to fix this.

src/backend/libpq/auth.c
src/backend/libpq/ip.c
src/backend/libpq/pqcomm.c
src/backend/port/win32/socket.c
src/backend/postmaster/postmaster.c
src/interfaces/libpq/fe-connect.c
src/interfaces/libpq/libpq-int.h

index 5e13145eff4135f219b3ff6c36fbd617d952aa90..f3f3b71894b9684c7d27b121ee9d4a367d1a6513 100644 (file)
@@ -1676,7 +1676,7 @@ ident_inet(hbaPort *port)
 
    sock_fd = socket(ident_serv->ai_family, ident_serv->ai_socktype,
                     ident_serv->ai_protocol);
-   if (sock_fd < 0)
+   if (sock_fd == PGINVALID_SOCKET)
    {
        ereport(LOG,
                (errcode_for_socket_access(),
@@ -1756,7 +1756,7 @@ ident_inet(hbaPort *port)
                    ident_response)));
 
 ident_inet_done:
-   if (sock_fd >= 0)
+   if (sock_fd != PGINVALID_SOCKET)
        closesocket(sock_fd);
    pg_freeaddrinfo_all(remote_addr.addr.ss_family, ident_serv);
    pg_freeaddrinfo_all(local_addr.addr.ss_family, la);
@@ -2583,7 +2583,7 @@ CheckRADIUSAuth(Port *port)
    packet->length = htons(packet->length);
 
    sock = socket(serveraddrs[0].ai_family, SOCK_DGRAM, 0);
-   if (sock < 0)
+   if (sock == PGINVALID_SOCKET)
    {
        ereport(LOG,
                (errmsg("could not create RADIUS socket: %m")));
index 3b827450af4e15772d9e39a204508d9846183b8e..eb249e9df565e0d607c177e093a0b2e673ac5daf 100644 (file)
@@ -547,7 +547,7 @@ pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
    int         error;
 
    sock = WSASocket(AF_INET, SOCK_DGRAM, 0, 0, 0, 0);
-   if (sock == SOCKET_ERROR)
+   if (sock == INVALID_SOCKET)
        return -1;
 
    while (n_ii < 1024)
@@ -670,7 +670,7 @@ pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
                total;
 
    sock = socket(AF_INET, SOCK_DGRAM, 0);
-   if (sock == -1)
+   if (sock == PGINVALID_SOCKET)
        return -1;
 
    while (n_buffer < 1024 * 100)
@@ -711,7 +711,7 @@ pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
 #ifdef HAVE_IPV6
    /* We'll need an IPv6 socket too for the SIOCGLIFNETMASK ioctls */
    sock6 = socket(AF_INET6, SOCK_DGRAM, 0);
-   if (sock6 == -1)
+   if (sock6 == PGINVALID_SOCKET)
    {
        free(buffer);
        close(sock);
@@ -788,10 +788,10 @@ pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
    char       *ptr,
               *buffer = NULL;
    size_t      n_buffer = 1024;
-   int         sock;
+   pgsocket    sock;
 
    sock = socket(AF_INET, SOCK_DGRAM, 0);
-   if (sock == -1)
+   if (sock == PGINVALID_SOCKET)
        return -1;
 
    while (n_buffer < 1024 * 100)
index 76aac975528fe0b045194f928dbb90b56d4ba23d..516c559d9f35996b6c5350d0c147a046fa3bd2d4 100644 (file)
@@ -392,7 +392,7 @@ StreamServerPort(int family, char *hostName, unsigned short portNumber,
                break;
        }
 
-       if ((fd = socket(addr->ai_family, SOCK_STREAM, 0)) < 0)
+       if ((fd = socket(addr->ai_family, SOCK_STREAM, 0)) == PGINVALID_SOCKET)
        {
            ereport(LOG,
                    (errcode_for_socket_access(),
@@ -632,7 +632,7 @@ StreamConnection(pgsocket server_fd, Port *port)
    port->raddr.salen = sizeof(port->raddr.addr);
    if ((port->sock = accept(server_fd,
                             (struct sockaddr *) & port->raddr.addr,
-                            &port->raddr.salen)) < 0)
+                            &port->raddr.salen)) == PGINVALID_SOCKET)
    {
        ereport(LOG,
                (errcode_for_socket_access(),
index 0a132c4b82ed6be82da2fe18a3ca5d6c03289ac9..e349511fedfd152eff5570db5e269eeb1d264d03 100644 (file)
@@ -132,7 +132,7 @@ int
 pgwin32_waitforsinglesocket(SOCKET s, int what, int timeout)
 {
    static HANDLE waitevent = INVALID_HANDLE_VALUE;
-   static SOCKET current_socket = -1;
+   static SOCKET current_socket = INVALID_SOCKET;
    static int  isUDP = 0;
    HANDLE      events[2];
    int         r;
index 99f98205dd7f7f046cd0b988c335a184403e849b..6977e43a8f690e27eddf9844de16d339b56ebe7e 100644 (file)
@@ -2146,7 +2146,7 @@ ConnCreate(int serverFd)
 
    if (StreamConnection(serverFd, port) != STATUS_OK)
    {
-       if (port->sock >= 0)
+       if (port->sock != PGINVALID_SOCKET)
            StreamClose(port->sock);
        ConnFree(port);
        return NULL;
index b2f547a8d645c6bd4a0d2786d0a7aa030085e102..5f1bfb801ebb9ac2c5a1854129958fdffd8356ad 100644 (file)
@@ -1631,8 +1631,23 @@ keep_going:                      /* We will come back to here until there is
                    conn->raddr.salen = addr_cur->ai_addrlen;
 
                    /* Open a socket */
-                   conn->sock = socket(addr_cur->ai_family, SOCK_STREAM, 0);
-                   if (conn->sock < 0)
+                   {
+                       /*
+                        * While we use 'pgsocket' as the socket type in the
+                        * backend, we use 'int' for libpq socket values.
+                        * This requires us to map PGINVALID_SOCKET to -1
+                        * on Windows.
+                        * See http://msdn.microsoft.com/en-us/library/windows/desktop/ms740516%28v=vs.85%29.aspx
+                        */
+                       pgsocket sock = socket(addr_cur->ai_family, SOCK_STREAM, 0);
+#ifdef WIN32
+                       if (sock == PGINVALID_SOCKET)
+                           conn->sock = -1;
+                       else
+#endif
+                           conn->sock = sock;
+                   }
+                   if (conn->sock == -1)
                    {
                        /*
                         * ignore socket() failure if we have more addresses
@@ -3136,7 +3151,7 @@ internal_cancel(SockAddr *raddr, int be_pid, int be_key,
                char *errbuf, int errbufsize)
 {
    int         save_errno = SOCK_ERRNO;
-   int         tmpsock = -1;
+   pgsocket    tmpsock = PGINVALID_SOCKET;
    char        sebuf[256];
    int         maxlen;
    struct
@@ -3149,7 +3164,7 @@ internal_cancel(SockAddr *raddr, int be_pid, int be_key,
     * We need to open a temporary connection to the postmaster. Do this with
     * only kernel calls.
     */
-   if ((tmpsock = socket(raddr->addr.ss_family, SOCK_STREAM, 0)) < 0)
+   if ((tmpsock = socket(raddr->addr.ss_family, SOCK_STREAM, 0)) == PGINVALID_SOCKET)
    {
        strlcpy(errbuf, "PQcancel() -- socket() failed: ", errbufsize);
        goto cancel_errReturn;
@@ -3220,7 +3235,7 @@ cancel_errReturn:
                maxlen);
        strcat(errbuf, "\n");
    }
-   if (tmpsock >= 0)
+   if (tmpsock != PGINVALID_SOCKET)
        closesocket(tmpsock);
    SOCK_ERRNO_SET(save_errno);
    return FALSE;
@@ -5281,6 +5296,15 @@ PQerrorMessage(const PGconn *conn)
    return conn->errorMessage.data;
 }
 
+/*
+ * In Windows, socket values are unsigned, and an invalid socket value
+ * (INVALID_SOCKET) is ~0, which equals -1 in comparisons (with no compiler
+ * warning). Ideally we would return an unsigned value for PQsocket() on
+ * Windows, but that would cause the function's return value to differ from
+ * Unix, so we just return -1 for invalid sockets.
+ * http://msdn.microsoft.com/en-us/library/windows/desktop/cc507522%28v=vs.85%29.aspx
+ * http://stackoverflow.com/questions/10817252/why-is-invalid-socket-defined-as-0-in-winsock2-h-c
+ */
 int
 PQsocket(const PGconn *conn)
 {
index 408aeb136b6a5914bac851de29081ae74ab06aa6..11f66e1b4d73ab4b775701c9c05d0b30449473b8 100644 (file)
@@ -364,6 +364,7 @@ struct pg_conn
    PGnotify   *notifyTail;     /* newest unreported Notify msg */
 
    /* Connection data */
+   /* See PQconnectPoll() for how we use 'int' and not 'pgsocket'. */
    int         sock;           /* Unix FD for socket, -1 if not connected */
    SockAddr    laddr;          /* Local address */
    SockAddr    raddr;          /* Remote address */