diff options
Diffstat (limited to 'Xtranssock.c')
-rw-r--r-- | Xtranssock.c | 233 |
1 files changed, 216 insertions, 17 deletions
diff --git a/Xtranssock.c b/Xtranssock.c index 24269b2..23150b2 100644 --- a/Xtranssock.c +++ b/Xtranssock.c @@ -2097,47 +2097,176 @@ TRANS(SocketBytesReadable) (XtransConnInfo ciptr, BytesReadable_t *pend) #endif /* WIN32 */ } +#if XTRANS_SEND_FDS + +static void +appendFd(struct _XtransConnFd **prev, int fd, int do_close) +{ + struct _XtransConnFd *cf, *new; + + new = malloc (sizeof (struct _XtransConnFd)); + if (!new) { + /* XXX mark connection as broken */ + close(fd); + return; + } + new->next = 0; + new->fd = fd; + new->do_close = do_close; + /* search to end of list */ + for (; (cf = *prev); prev = &(cf->next)); + *prev = new; +} static int -TRANS(SocketRead) (XtransConnInfo ciptr, char *buf, int size) +removeFd(struct _XtransConnFd **prev) +{ + struct _XtransConnFd *cf; + int fd; + + if ((cf = *prev)) { + *prev = cf->next; + fd = cf->fd; + free(cf); + } else + fd = -1; + return fd; +} +static void +discardFd(struct _XtransConnFd **prev, struct _XtransConnFd *upto, int do_close) { - prmsg (2,"SocketRead(%d,%p,%d)\n", ciptr->fd, buf, size); + struct _XtransConnFd *cf, *next; -#if defined(WIN32) - { - int ret = recv ((SOCKET)ciptr->fd, buf, size, 0); -#ifdef WIN32 - if (ret == SOCKET_ERROR) errno = WSAGetLastError(); -#endif - return ret; + for (cf = *prev; cf != upto; cf = next) { + next = cf->next; + if (do_close || cf->do_close) + close(cf->fd); + free(cf); } -#else - return read (ciptr->fd, buf, size); -#endif /* WIN32 */ + *prev = upto; } +static void +cleanupFds(XtransConnInfo ciptr) +{ + /* Clean up the send list but don't close the fds */ + discardFd(&ciptr->send_fds, NULL, 0); + /* Clean up the recv list and *do* close the fds */ + discardFd(&ciptr->recv_fds, NULL, 1); +} static int -TRANS(SocketWrite) (XtransConnInfo ciptr, char *buf, int size) +nFd(struct _XtransConnFd **prev) +{ + struct _XtransConnFd *cf; + int n = 0; + + for (cf = *prev; cf; cf = cf->next) + n++; + return n; +} + +static int +TRANS(SocketRecvFd) (XtransConnInfo ciptr) +{ + prmsg (2, "SocketRecvFd(%d)\n", ciptr->fd); + return removeFd(&ciptr->recv_fds); +} +static int +TRANS(SocketSendFd) (XtransConnInfo ciptr, int fd, int do_close) { - prmsg (2,"SocketWrite(%d,%p,%d)\n", ciptr->fd, buf, size); + appendFd(&ciptr->send_fds, fd, do_close); + return 0; +} + +static int +TRANS(SocketRecvFdInvalid)(XtransConnInfo ciptr) +{ + errno = EINVAL; + return -1; +} + +static int +TRANS(SocketSendFdInvalid)(XtransConnInfo ciptr, int fd, int do_close) +{ + errno = EINVAL; + return -1; +} + +#define MAX_FDS 128 + +struct fd_pass { + struct cmsghdr cmsghdr; + int fd[MAX_FDS]; +}; + +static inline void init_msg_recv(struct msghdr *msg, struct iovec *iov, int niov, struct fd_pass *pass, int nfd) { + msg->msg_name = NULL; + msg->msg_namelen = 0; + msg->msg_iov = iov; + msg->msg_iovlen = niov; + msg->msg_control = pass; + msg->msg_controllen = sizeof (struct cmsghdr) + nfd * sizeof (int); +} + +static inline void init_msg_send(struct msghdr *msg, struct iovec *iov, int niov, struct fd_pass *pass, int nfd) { + init_msg_recv(msg, iov, niov, pass, nfd); + pass->cmsghdr.cmsg_len = msg->msg_controllen; + pass->cmsghdr.cmsg_level = SOL_SOCKET; + pass->cmsghdr.cmsg_type = SCM_RIGHTS; +} + +#endif /* XTRANS_SEND_FDS */ + +static int +TRANS(SocketRead) (XtransConnInfo ciptr, char *buf, int size) + +{ + prmsg (2,"SocketRead(%d,%p,%d)\n", ciptr->fd, buf, size); #if defined(WIN32) { - int ret = send ((SOCKET)ciptr->fd, buf, size, 0); + int ret = recv ((SOCKET)ciptr->fd, buf, size, 0); #ifdef WIN32 if (ret == SOCKET_ERROR) errno = WSAGetLastError(); #endif return ret; } #else - return write (ciptr->fd, buf, size); +#if XTRANS_SEND_FDS + { + struct msghdr msg; + struct iovec iov; + struct fd_pass pass; + + iov.iov_base = buf; + iov.iov_len = size; + + init_msg_recv(&msg, &iov, 1, &pass, MAX_FDS); + size = recvmsg(ciptr->fd, &msg, 0); + if (size >= 0 && msg.msg_controllen > sizeof (struct cmsghdr)) { + if (pass.cmsghdr.cmsg_level == SOL_SOCKET && + pass.cmsghdr.cmsg_type == SCM_RIGHTS && + !((msg.msg_flags & MSG_TRUNC) || + (msg.msg_flags & MSG_CTRUNC))) + { + int nfd = (msg.msg_controllen - sizeof (struct cmsghdr)) / sizeof (int); + int *fd = (int *) CMSG_DATA(&pass.cmsghdr); + int i; + for (i = 0; i < nfd; i++) + appendFd(&ciptr->recv_fds, fd[i], 0); + } + } + return size; + } +#else + return read(ciptr->fd, buf, size); +#endif /* XTRANS_SEND_FDS */ #endif /* WIN32 */ } - static int TRANS(SocketReadv) (XtransConnInfo ciptr, struct iovec *buf, int size) @@ -2154,11 +2283,65 @@ TRANS(SocketWritev) (XtransConnInfo ciptr, struct iovec *buf, int size) { prmsg (2,"SocketWritev(%d,%p,%d)\n", ciptr->fd, buf, size); +#if XTRANS_SEND_FDS + if (ciptr->send_fds) + { + struct msghdr msg; + struct fd_pass pass; + int nfd; + struct _XtransConnFd *cf; + int i; + + nfd = nFd(&ciptr->send_fds); + cf = ciptr->send_fds; + + /* Set up fds */ + for (i = 0; i < nfd; i++) { + pass.fd[i] = cf->fd; + cf = cf->next; + } + + init_msg_send(&msg, buf, size, &pass, nfd); + i = sendmsg(ciptr->fd, &msg, 0); + if (i > 0) + discardFd(&ciptr->send_fds, cf, 0); + return i; + } +#endif return WRITEV (ciptr, buf, size); } static int +TRANS(SocketWrite) (XtransConnInfo ciptr, char *buf, int size) + +{ + prmsg (2,"SocketWrite(%d,%p,%d)\n", ciptr->fd, buf, size); + +#if defined(WIN32) + { + int ret = send ((SOCKET)ciptr->fd, buf, size, 0); +#ifdef WIN32 + if (ret == SOCKET_ERROR) errno = WSAGetLastError(); +#endif + return ret; + } +#else +#if XTRANS_SEND_FDS + if (ciptr->send_fds) + { + struct iovec iov; + + iov.iov_base = buf; + iov.iov_len = size; + return TRANS(SocketWritev)(ciptr, &iov, 1); + } +#endif /* XTRANS_SEND_FDS */ + return write (ciptr->fd, buf, size); +#endif /* WIN32 */ +} + +static int TRANS(SocketDisconnect) (XtransConnInfo ciptr) { @@ -2211,6 +2394,9 @@ TRANS(SocketUNIXClose) (XtransConnInfo ciptr) prmsg (2,"SocketUNIXClose(%p,%d)\n", ciptr, ciptr->fd); +#if XTRANS_SEND_FDS + cleanupFds(ciptr); +#endif ret = close(ciptr->fd); if (ciptr->flags @@ -2239,6 +2425,9 @@ TRANS(SocketUNIXCloseForCloning) (XtransConnInfo ciptr) prmsg (2,"SocketUNIXCloseForCloning(%p,%d)\n", ciptr, ciptr->fd); +#if XTRANS_SEND_FDS + cleanupFds(ciptr); +#endif ret = close(ciptr->fd); return ret; @@ -2293,6 +2482,8 @@ Xtransport TRANS(SocketTCPFuncs) = { TRANS(SocketWrite), TRANS(SocketReadv), TRANS(SocketWritev), + TRANS(SocketSendFdInvalid), + TRANS(SocketRecvFdInvalid), TRANS(SocketDisconnect), TRANS(SocketINETClose), TRANS(SocketINETClose), @@ -2333,6 +2524,8 @@ Xtransport TRANS(SocketINETFuncs) = { TRANS(SocketWrite), TRANS(SocketReadv), TRANS(SocketWritev), + TRANS(SocketSendFdInvalid), + TRANS(SocketRecvFdInvalid), TRANS(SocketDisconnect), TRANS(SocketINETClose), TRANS(SocketINETClose), @@ -2374,6 +2567,8 @@ Xtransport TRANS(SocketINET6Funcs) = { TRANS(SocketWrite), TRANS(SocketReadv), TRANS(SocketWritev), + TRANS(SocketSendFdInvalid), + TRANS(SocketRecvFdInvalid), TRANS(SocketDisconnect), TRANS(SocketINETClose), TRANS(SocketINETClose), @@ -2422,6 +2617,8 @@ Xtransport TRANS(SocketLocalFuncs) = { TRANS(SocketWrite), TRANS(SocketReadv), TRANS(SocketWritev), + TRANS(SocketSendFd), + TRANS(SocketRecvFd), TRANS(SocketDisconnect), TRANS(SocketUNIXClose), TRANS(SocketUNIXCloseForCloning), @@ -2476,6 +2673,8 @@ Xtransport TRANS(SocketUNIXFuncs) = { TRANS(SocketWrite), TRANS(SocketReadv), TRANS(SocketWritev), + TRANS(SocketSendFd), + TRANS(SocketRecvFd), TRANS(SocketDisconnect), TRANS(SocketUNIXClose), TRANS(SocketUNIXCloseForCloning), |