Improve test_qdsocket
[fwd.git] / prod / fwd.c
1 /*
2  * Quick-and-dirty userspace port forwarder.
3  *
4  * There are a bunch of things that should be cleaned up here:
5  *    - the (limited) user-interface strings are hardcoded; should use a look-up table.
6  *    - should have some administrative logging (perhaps syslog)
7  *    - code to handle certain plausible (but unlikely) conditions remains a "TODO".  
8  *    - if there are a lot of simultaneous connections, the linear search after a select() could be a problem
9  *    - the qdSocket library functions need some performance work
10  *
11  * On the plus side, the code is small, and reasonably efficient on a 
12  * light-duty server.  I.e., "it works for me."
13  */
14
15 #include <assert.h>
16 #include <arpa/inet.h>
17 #include <errno.h>
18 #include <fcntl.h>
19 #include <netdb.h>
20 #include <netinet/in.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string.h>
24 #include <sys/socket.h>
25 #include <sys/time.h>
26 #include <sys/types.h>
27 #include <unistd.h>
28
29 #include "qderrhandler.h"
30 #include "qdsocket.h"
31
32 #define MAX_ACCEPTED    50      // maximal number of concurrent connections
33
34 void usage(const char * myName) 
35 {
36     printf("Usage:  %s myPort host:otherPort\n", myName);
37     printf("        Forwards tcp received on myPort to host:otherPort.\n");
38 }
39
40 void fillFds(int * highest_sd, struct qdSocket incoming[], struct qdSocket outgoing[], fd_set * readFds, fd_set * writeFds)
41 {
42     int idx;
43
44     FD_ZERO(readFds);
45     FD_ZERO(writeFds);
46
47     // Special case:  the listening server socket, on which we read but never write
48     FD_SET(incoming[0].sd, readFds);
49     *highest_sd = incoming[0].sd;
50
51     for (idx = 1; idx < MAX_ACCEPTED; ++idx) {
52         int sd;
53
54         qdSockCheckClose(&(incoming[idx]));
55         sd = incoming[idx].sd;
56         if (sd > *highest_sd) {
57             *highest_sd = sd;
58         }
59         if (0 != sd) {
60             FD_SET(sd, readFds);
61             FD_SET(sd, writeFds);
62         }
63
64         qdSockCheckClose(&(outgoing[idx]));
65         sd = outgoing[idx].sd;
66         if (sd > *highest_sd) {
67             *highest_sd = sd;
68         }
69         if (0 != sd) {
70             FD_SET(sd, readFds);
71             FD_SET(sd, writeFds);
72         }
73     }
74 }
75
76 void writeError(struct qdSocket *to, char * buf)
77 {
78     ssize_t len = strlen(buf);
79     ssize_t written = 0;
80
81     written = qdSockWrite(to, buf, len);
82     if (written != len) {
83         printf("WARNING:  Tried to report the following on sd %ld, but could not:  %s\n",
84                (long)(to->sd), (buf ? buf : "<NULL>"));
85     }
86 }
87
88 void forwardFromTo(struct qdSocket *from, struct qdSocket *to)
89 {
90     char buf[QD_SOCK_BUF_SIZE];
91     ssize_t n;
92     ssize_t maxToRead;
93     ssize_t writeBufAvail;
94
95     maxToRead = QD_SOCK_BUF_SIZE - (to->toWrite);
96     
97     if (maxToRead <= 0) {
98         //printf("WARNING:  write buffer full on sd %ld.  Pending read on sd %ld deferred.\n",
99         //       (long)(to->sd), (long)(from->sd));
100         return;
101     }
102
103     memset(buf, 0x00, QD_SOCK_BUF_SIZE);
104
105     n = read(from->sd, buf, maxToRead);
106     if (n <= 0) {
107         qdSockClose(from);
108         qdSockClose(to);
109     }
110     else {
111         qdSockWrite(to, buf, n);
112     }
113
114     // printf("\nRead-and-write:  \"%s\"\n", buf);
115 }
116
117 void readSockets(struct qdSocket incoming[], struct qdSocket outgoing[], fd_set * readFds, const char *otherHost, int otherPort) 
118 {
119     int idx;
120
121     // Special case:  the listening server socket
122     if (FD_ISSET(incoming[0].sd, readFds)) {
123         doAccept(incoming, outgoing, otherHost, otherPort);
124     }
125
126     for (idx = 1; idx < MAX_ACCEPTED; ++idx) {
127         if (FD_ISSET(incoming[idx].sd, readFds)) {
128             forwardFromTo(&(incoming[idx]), &(outgoing[idx]));
129         }
130         if (FD_ISSET(outgoing[idx].sd, readFds)) {
131             forwardFromTo(&(outgoing[idx]), &(incoming[idx]));
132         }
133     }
134 }
135
136 void writeSockets(struct qdSocket incoming[], struct qdSocket outgoing[], fd_set * writeFds)
137 {
138     int idx;
139
140     assert(! FD_ISSET(incoming[0].sd, writeFds));
141
142     for (idx = 1; idx < MAX_ACCEPTED; ++idx) {
143         if (FD_ISSET(incoming[idx].sd, writeFds)) {
144             qdSockFlush(incoming + idx);
145         }
146         if (FD_ISSET(outgoing[idx].sd, writeFds)) {
147             qdSockFlush(outgoing + idx);
148         }
149     }
150 }
151
152 void setNonBlocking(int sd) 
153 {
154     int opt;
155
156     opt = fcntl(sd, F_GETFL);
157     if (opt < 0) {
158         qdAbort("fcntl(sd, F_GETFL) failed", opt);
159     }
160     opt |= O_NONBLOCK;
161     opt = fcntl(sd, F_SETFL, opt);
162     if (opt < 0) {
163         qdAbort("fcntl(sd, F_SETFL) failed", opt);
164     }
165 }
166
167 int doAccept(struct qdSocket incoming[], struct qdSocket outgoing[], const char *otherHost, int otherPort ) 
168 {
169     int i;
170     int new_sd;
171     int ret;
172     int sin_size = sizeof(struct sockaddr_in);
173     struct sockaddr_in theirAddr;
174     char * p;
175
176     new_sd = accept(incoming[0].sd, (struct sockaddr *)&theirAddr, &sin_size);
177     if (new_sd <= 0) {
178         qdAbort("Could not accept() incoming connection.", new_sd);
179     }
180
181     p = inet_ntoa(theirAddr.sin_addr);
182
183     printf("Connection received from %s:%d, using sd %d.\n", 
184             (p ? p : "<NULL>"),
185             theirAddr.sin_port, 
186             new_sd);
187
188     setNonBlocking(new_sd);
189
190     for (i = 0; i < MAX_ACCEPTED; ++i) {
191         if (0 == incoming[i].sd) {
192             incoming[i].sd = new_sd;
193             new_sd = (-1);
194             break;
195         }
196     }
197
198     if (new_sd >= 0) {
199         struct qdSocket temp;
200         qdSockInit(&temp);
201         temp.sd = new_sd;
202         writeError(&temp, "Sorry, the server is too busy at the moment (too many open connections).  Please wait a while and then try again.");
203         qdSockClose(&temp);
204         return (-1);
205     }
206
207     assert (0 == outgoing[i].sd);
208
209     ret = qdSockConnect(&(outgoing[i]), otherHost, otherPort);
210     if (0 != ret) {
211         // Could not connect.  Most likely, otherHost:otherPort is (temporarily?) unavailable
212         // We intentionally do not explain which host and port are unavailable, to avoid leaking 
213         // information about the internal network structure to outside parties.
214         writeError(&(incoming[i]), "Sorry, the server is currently down.  Please wait a while and then try again.");
215         qdSockClose(&(incoming[i]));
216         return (-1);
217     }
218
219     return i;
220 }
221
222
223 void forward(int myPort, const char *otherHost, int otherPort)
224 {
225     char * p = NULL;
226     int on = 0;
227     int server_sd = 0;
228     int idx = (-1);
229     int numSignalled = 0;
230     int ret = 0;
231     struct sockaddr_in myAddr;
232     struct qdSocket * incoming;
233     struct qdSocket * outgoing;
234     int highest_sd = 0;
235     fd_set readFds;
236     fd_set writeFds;
237     struct timeval timeout;
238
239     incoming = (struct qdSocket *)malloc(MAX_ACCEPTED * sizeof(struct qdSocket));
240     outgoing = (struct qdSocket *)malloc(MAX_ACCEPTED * sizeof(struct qdSocket));
241
242     if (NULL == incoming) {
243         qdAbort("Failed to calloc() incoming.  Out of memory?", 0);
244     }
245     if (NULL == outgoing) {
246         qdAbort("Failed to calloc() outgoing.  Out of memory?", 0);
247     }
248
249     for (idx = 0; idx < MAX_ACCEPTED; ++idx) {
250         qdSockInit(&(incoming[idx]));
251         qdSockInit(&(outgoing[idx]));
252     }
253
254     server_sd = socket(AF_INET, SOCK_STREAM, 0);
255     if (server_sd < 0) {
256         qdAbort("Could not allocate socket descriptor.", server_sd);
257     }
258
259     ret = setsockopt(server_sd, SOL_SOCKET, SO_REUSEADDR, (char *)&on, sizeof(on));
260     if (ret < 0) {
261         qdAbort("Could not set socket options.", ret);
262         close(server_sd);
263     }
264
265     incoming[0].sd = server_sd;
266
267     memset(&myAddr, 0x00, sizeof(struct sockaddr_in));
268     myAddr.sin_family = AF_INET;
269     myAddr.sin_port = htons(myPort);
270     myAddr.sin_addr.s_addr = htonl(INADDR_ANY);
271
272     ret = bind(server_sd, (struct sockaddr *)&myAddr, sizeof(myAddr));
273     if (ret < 0) {
274         qdAbort("Could not bind() socket.", ret);
275     }
276
277     printf("Listening on %d...\n", myPort);
278
279     ret = listen(server_sd, MAX_ACCEPTED);      // allow a backlog of up to MAX_ACCEPTED pending connections
280     if (ret < 0) {
281         qdAbort("Could not listen() on socket.", ret);
282     }
283
284     highest_sd = server_sd;
285     incoming[0].sd = server_sd;
286
287     do { 
288         fillFds(&highest_sd, incoming, outgoing, &readFds, &writeFds);
289
290         timeout.tv_sec = 1;
291         timeout.tv_usec = 0;
292
293         numSignalled = select(highest_sd + 1, &readFds, &writeFds, NULL, &timeout);
294         if (0 == numSignalled) {
295             // printf("."); fflush(stdout);
296         }
297         else {
298             readSockets(incoming, outgoing, &readFds, otherHost, otherPort);
299             writeSockets(incoming, outgoing, &writeFds);
300         }
301     } while ( 1 );
302
303     free(incoming);
304     free(outgoing);
305 }
306
307 int main(int argc, char **argv) 
308 {
309     int myPort = 0;
310     int otherPort = 0;
311     char * otherHost = "";
312     char * p = NULL;
313
314     if (3 != argc) {
315         usage(argv[0]);
316         return 1;
317     }
318
319     myPort = atoi(argv[1]);
320
321     p = strrchr(argv[2], ':');
322     if (NULL == p) {
323         usage(argv[0]);
324         return 2;
325     }
326     *p = 0;
327     ++p;
328     otherPort = atoi(p);
329
330     otherHost = argv[2];        // note that :port has been truncated off
331
332     forward(myPort, otherHost, otherPort);
333     return 0;   // unreachable
334 }
335