525a57eef905a666f28e43bc088c2275ee3e8720
[tedtools.git] / tcp.c
1 /*
2  * Copyright (c) 2004 Teodor Sigaev <teodor@sigaev.ru>
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *        notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *        notice, this list of conditions and the following disclaimer in the
12  *        documentation and/or other materials provided with the distribution.
13  * 3. Neither the name of the author nor the names of any co-contributors
14  *        may be used to endorse or promote products derived from this software
15  *        without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY CONTRIBUTORS ``AS IS'' AND ANY EXPRESS
18  * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED. IN NO EVENT SHALL CONTRIBUTORS BE LIABLE FOR ANY
21  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
23  * GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
25  * IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
26  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
27  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28  */
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
35
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #else /* HAVE_POLL */
39 #ifdef HAVE_SYS_POLL_H
40 #include <sys/poll.h>
41 #else
42 #error Not defined HAVE_POLL_H or HAVE_SYS_POLL_H
43 #endif /* HAVE_SYS_POLL_H */
44 #endif /* HAVE_POLL */
45
46 #ifdef HAVE_HSTRERROR 
47 #include <netdb.h>
48 #endif
49
50
51 #include "connection.h"
52 #include "tlog.h"
53 #include "tmalloc.h"
54
55 static u_int32_t
56 setlinger( TC_Connection *cs ) {
57         struct linger ling;
58         int     val = 0;
59         socklen_t size = sizeof(val); 
60
61         if (getsockopt(cs->fd, SOL_SOCKET,SO_ERROR,&val,&size) == -1) {
62                 tlog(TL_ALARM,"getsockopt: %s:%d - %s(%d)",inet_ntoa(cs->serv_addr.sin_addr),
63                         ntohs(cs->serv_addr.sin_port), strerror(errno), errno);
64                 shutdown(cs->fd,SHUT_RDWR);
65                 close(cs->fd);
66                 cs->fd = 0;
67                 cs->state = CS_ERROR;
68                 return CS_ERROR;
69         }
70
71         if ( val ) {
72                 tlog(TL_ALARM,"getsockopt return: %s:%d - %s(%d)",inet_ntoa(cs->serv_addr.sin_addr),
73                         ntohs(cs->serv_addr.sin_port), strerror(val), val);
74                 shutdown(cs->fd,SHUT_RDWR);
75                 close(cs->fd);
76                 cs->fd = 0;
77                 cs->state = CS_ERROR;
78                 return CS_ERROR;
79         }
80
81
82         ling.l_onoff = ling.l_linger = 0;
83         if (setsockopt(cs->fd, SOL_SOCKET,SO_LINGER,(char *)&ling,sizeof(ling))==-1) {
84                 tlog(TL_ALARM,"setsockopt: LINGER %s:%d - %s",inet_ntoa(cs->serv_addr.sin_addr), ntohs(cs->serv_addr.sin_port),
85                         strerror(errno));
86                 shutdown(cs->fd,SHUT_RDWR);
87                 close(cs->fd);
88                 cs->fd = 0;
89                 cs->state = CS_ERROR;
90                 return CS_ERROR;
91         }
92         cs->state = CS_CONNECTED;
93         return CS_CONNECTED;
94 }
95
96 u_int32_t
97 TC_ClientInitConnection(TC_Connection *cs, char *name, u_int32_t port) {
98         int flags;
99
100         cs = TC_fillConnection(cs, name, port);
101
102         cs->state = CS_OK;
103         if ((cs->fd= socket(AF_INET, SOCK_STREAM, 0)) < 0)
104                 tlog(TL_CRIT|TL_EXIT,"socket4: %s:%d - %s",inet_ntoa(cs->serv_addr.sin_addr),
105                         ntohs(cs->serv_addr.sin_port),strerror(errno));
106
107         if ((flags=fcntl(cs->fd,F_GETFL,0)) == -1)
108                 tlog(TL_ALARM,"fcntl F_GETFL - %s",strerror(errno));
109         if (fcntl(cs->fd,F_SETFL,flags|O_NDELAY) < 0 )
110                 tlog(TL_ALARM,"fcntl O_NDELAY - %s",strerror(errno));
111
112         if (bind(cs->fd, (struct sockaddr *) &(cs->serv_addr), sizeof(cs->serv_addr)) < 0)
113                 tlog(TL_CRIT|TL_EXIT, "cannot bind to %s address: %s",
114                         inet_ntoa(cs->serv_addr.sin_addr), strerror(errno));
115         
116         if (listen(cs->fd, 0) < 0)
117                 tlog(TL_CRIT|TL_EXIT, "cannot listen to %s address: %s",
118                         inet_ntoa(cs->serv_addr.sin_addr), strerror(errno));
119         
120         return CS_OK;
121 }
122
123 TC_Connection*
124 TC_AcceptTcp(TC_Connection *cs) {
125         TC_Connection *nc;
126         struct sockaddr_in cli_addr;
127         int ret, flags;
128         socklen_t clilen = sizeof(cli_addr);
129
130         cs->state = CS_READ;
131         if ( (ret = accept(cs->fd,(struct sockaddr *)&cli_addr, &clilen)) < 0 ) {
132                 if ( errno == EAGAIN || errno == EWOULDBLOCK )
133                         return NULL;
134                 tlog(TL_ALARM,"TC_AcceptTcp: accept: %s", strerror(errno));
135                 return NULL;
136         }
137         nc = (TC_Connection*)t0malloc(sizeof(TC_Connection));
138
139         nc->fd = ret;
140         if ((flags=fcntl(nc->fd,F_GETFL,0)) == -1)
141                 tlog(TL_ALARM,"fcntl F_GETFL - %s",strerror(errno));
142         if (fcntl(nc->fd,F_SETFL,flags|O_NDELAY) < 0 )
143                 tlog(TL_ALARM,"fcntl O_NDELAY - %s",strerror(errno));
144         memcpy( &(nc->serv_addr), &cli_addr, clilen );
145         nc->state = CS_CONNECTED;
146
147         setlinger(nc);
148         return nc;
149 }
150
151 TC_Connection *
152 TC_fillConnection(TC_Connection *sc, char *name, u_int32_t port) {
153         if ( !sc ) 
154                 sc = (TC_Connection *)t0malloc(sizeof(TC_Connection));
155         sc->serv_addr.sin_family = AF_INET;
156         sc->serv_addr.sin_addr.s_addr = (name) ? inet_addr(name) : htonl(INADDR_ANY);
157         sc->serv_addr.sin_port = htons(port);
158         sc->state = CS_NOTINITED;
159         return sc; 
160 }
161
162
163 u_int32_t
164 TC_ServerInitConnect( TC_Connection     *cs ) {
165         int flags;
166
167         if ( cs->state == CS_ERROR )
168                 return CS_ERROR;
169
170         if ((cs->fd= socket(AF_INET, SOCK_STREAM, 0)) < 0) {
171                 tlog(TL_CRIT,"socket4: %s:%d - %s",inet_ntoa(cs->serv_addr.sin_addr),
172                         ntohs(cs->serv_addr.sin_port),strerror(errno));
173                 cs->state  = CS_ERROR;
174                 return  CS_ERROR;
175         }
176
177         if ((flags=fcntl(cs->fd,F_GETFL,0)) == -1)
178                 tlog(TL_ALARM,"fcntl F_GETFL - %s",strerror(errno));
179         if (fcntl(cs->fd,F_SETFL,flags|O_NDELAY) < 0 )
180                 tlog(TL_ALARM,"fcntl O_NDELAY - %s",strerror(errno));
181
182         if ( connect(cs->fd, (struct sockaddr *) &(cs->serv_addr),
183                 sizeof(struct sockaddr_in)) < 0 ) {
184                 if ( errno == EINPROGRESS || errno == EALREADY ) {
185                         cs->state = CS_INPROCESS;
186                         return CS_INPROCESS; 
187                 } else if (errno != EISCONN && errno != EALREADY &&
188                         errno != EWOULDBLOCK && errno != EAGAIN) {
189                         tlog(TL_DEBUG,"open4: %s:%d - %s",
190                                 inet_ntoa(cs->serv_addr.sin_addr), ntohs(cs->serv_addr.sin_port),
191                                 strerror(errno));
192                         shutdown(cs->fd,SHUT_RDWR);
193                         close(cs->fd);
194                         cs->fd = 0;
195                 } else {
196                         tlog(TL_DEBUG,"nonblock connect: %s:%d - %s [%d]",
197                                 inet_ntoa(cs->serv_addr.sin_addr),
198                                 ntohs(cs->serv_addr.sin_port),
199                                 strerror(errno),errno);
200                 }
201                 cs->state = CS_ERROR;
202                 return CS_ERROR;
203         }
204
205         cs->state = CS_INPROCESS;
206         return TC_ServerConnect( cs );
207 }
208         
209
210 u_int32_t
211 TC_ServerConnect( TC_Connection *cs ) {
212         struct pollfd   pfd;
213         int ret;
214
215         if ( cs->state != CS_INPROCESS )
216                 return cs->state;
217
218         pfd.fd = cs->fd;
219         pfd.events = POLLOUT;
220         pfd.revents = 0;
221         ret = poll( &pfd, 1, 0 );
222         if ( ret<0 ) {
223                 tlog( TL_CRIT, "TC_ServerConnect: poll: %s",
224                         strerror(errno));
225                 cs->state = CS_ERROR;
226                 return CS_ERROR;
227         } else if ( ret == 0 ) 
228                 return CS_INPROCESS;
229
230         if ( (pfd.revents & (POLLHUP | POLLNVAL | POLLERR)) ) {
231                 tlog( TL_CRIT, "TC_ServerConnect: poll return connect error for %s:%d",
232                         inet_ntoa(cs->serv_addr.sin_addr), ntohs(cs->serv_addr.sin_port));
233                 cs->state = CS_ERROR;
234                 return CS_ERROR;
235         }
236
237         if ( ! (pfd.revents & POLLOUT) )
238                 return CS_INPROCESS;
239
240
241         return setlinger( cs );
242 }
243
244 int
245 TC_ReadyIO( TC_Connection **cs, int number, int timeout ) {
246         struct pollfd   *pfd;
247         int ret,i, fdnum=0;
248
249         if ( number==0 || cs ==NULL ) {
250                 usleep( timeout * 1000.0 );
251                 return 0;
252         }
253         pfd = (struct pollfd*) tmalloc( sizeof(struct pollfd) * number );
254
255         for(i=0; i<number;i++) {
256                 if ( cs[i]->fd>0 && (cs[i]->state == CS_READ || cs[i]->state == CS_SEND) ) {
257                         pfd[fdnum].fd = cs[i]->fd;
258                         pfd[fdnum].events = ( cs[i]->state == CS_READ ) ? POLLIN : POLLOUT;
259                         pfd[fdnum].revents = 0;
260                         fdnum++;
261                 }
262                 cs[i]->readyio=0;
263         }
264         ret = poll( pfd, fdnum, timeout );
265         if ( ret<0 ) {
266                 tlog( TL_CRIT, "TC_ReadyIO: poll: %s",
267                         strerror(errno));
268                 tfree(pfd);
269                 return 0;
270         }
271
272         if ( ret == 0 ) {
273                 tfree(pfd);
274                 return 0;
275         }
276
277         fdnum=0; ret=0;
278         for(i=0; i<number;i++) {
279                 if ( cs[i]->fd>0 && (cs[i]->state == CS_READ || cs[i]->state == CS_SEND) ) {
280                         if ( pfd[fdnum].revents & (POLLHUP | POLLNVAL | POLLERR) ) { 
281                                 tlog( TL_ALARM, "TC_ReadyIO: poll return error for %s:%d",
282                                         inet_ntoa(cs[i]->serv_addr.sin_addr), 
283                                         ntohs(cs[i]->serv_addr.sin_port));
284                                 cs[i]->state = CS_ERROR;
285                                 ret = 1;
286                         } else if ( pfd[fdnum].revents & ( ( cs[i]->state == CS_READ ) ? POLLIN : POLLOUT ) ) {
287                                 cs[i]->readyio=1;
288                                 ret = 1;
289                         }
290                         fdnum++;
291                 }
292         }
293
294         tfree(pfd);
295         return ret;
296 }
297
298 u_int32_t
299 TC_Send( TC_Connection *cs ) {
300         int sz;
301         
302         if ( cs->state == CS_ERROR )
303                 return CS_ERROR;
304
305         if ( cs->state != CS_SEND ) {
306                 cs->state = CS_SEND;
307                 cs->ptr = cs->buf;
308         }
309
310         if ( cs->ptr - cs->buf >= cs->len ) {
311                 cs->state = CS_FINISHSEND;
312                 return CS_FINISHSEND;
313         }
314
315         if ((sz=write(cs->fd, cs->ptr, cs->len - (cs->ptr - cs->buf)))==0 ||
316                 (sz < 0 && (errno == EWOULDBLOCK || errno == EAGAIN))) {
317
318                 /* SunOS 4.1.x, are broken and select() says that
319                  * O_NDELAY sockets are always writable even when
320                  * they're actually not.
321                  */
322                 cs->state = CS_SEND;
323                 return CS_SEND;
324         }
325         if ( sz<0 ) {
326                 if (errno != EPIPE && errno != EINVAL)
327                         tlog(TL_ALARM, "write[%s:%d] - %s",
328                                 inet_ntoa(cs->serv_addr.sin_addr),
329                                 ntohs(cs->serv_addr.sin_port), 
330                                 strerror(errno));
331                 cs->state = CS_ERROR;
332                 return CS_ERROR;
333         }
334
335         cs->ptr += sz;
336
337         if ( cs->ptr - cs->buf >= cs->len ) {
338                 cs->state = CS_FINISHSEND;
339                 return CS_FINISHSEND;
340         }
341         
342         return cs->state;
343 }
344
345 static void 
346 resizeCS( TC_Connection *cs, int sz ) {
347         int diff = cs->ptr - cs->buf;
348         if ( cs->len >= sz )
349                 return; 
350         cs->len = sz;
351         cs->buf = (char*)trealloc( (void*)cs->buf, cs->len );
352         cs->ptr = cs->buf + diff;
353 }
354
355 u_int32_t
356 TC_Read( TC_Connection *cs ) {
357         int sz, totalread = -1, toread=0, alreadyread;
358
359         if ( cs->state == CS_ERROR )
360                 return CS_ERROR;
361
362         if (cs->state != CS_READ ) {
363                 cs->state = CS_READ;
364                 cs->ptr = cs->buf;
365         }
366
367         alreadyread = cs->ptr - cs->buf;
368         if ( alreadyread < sizeof(u_int32_t) ) {
369                 toread = sizeof(u_int32_t) - alreadyread;
370                 resizeCS(cs, sizeof(u_int32_t));
371         } else {
372                 totalread = *(u_int32_t*)(cs->buf);
373                 toread = totalread - alreadyread;
374                 if ( toread == 0 ) {
375                         cs->state = CS_FINISHREAD;
376                         return CS_FINISHREAD;
377                 }
378                 resizeCS(cs, totalread);
379         }
380
381         if ((sz=read( cs->fd, cs->ptr, toread))<0) {
382                 if (errno == EAGAIN || errno == EINTR) {
383                         cs->state = CS_READ;
384                         return CS_READ;
385                 }
386                 tlog(TL_ALARM,"read: finish - %s",strerror(errno));
387                 cs->state = CS_ERROR;
388                 return CS_ERROR;
389         }
390         
391
392         cs->ptr += sz;
393         alreadyread += sz;
394         if ( sz == 0 && alreadyread != totalread ) {
395                 tlog(TL_ALARM,"read: disconnecting");
396                 cs->state = CS_ERROR;
397                 return CS_ERROR;
398         }
399         cs->state = ( alreadyread == totalread ) ? CS_FINISHREAD : CS_READ;
400         return cs->state;
401 }
402
403 void
404 TC_FreeConnection( TC_Connection *cs ) {
405         if ( cs->state == CS_CLOSED )
406                 return;
407         if ( cs->buf ) {
408                 tfree(cs->buf);
409                 cs->buf = NULL;
410         }
411         if ( cs->fd && cs->state != CS_NOTINITED ) {
412                 shutdown(cs->fd,SHUT_RDWR);
413                 close(cs->fd);
414         }
415         cs->fd = 0;
416         cs->state = CS_CLOSED;
417 }
418
419 u_int32_t 
420 TC_Talk( TC_Connection *cs ) {
421         u_int32_t ret = TC_ServerInitConnect( cs );
422
423         while( ret == CS_INPROCESS ) {
424                 ret =  TC_ServerConnect(cs);
425         }
426
427         if ( ret != CS_CONNECTED )
428                 return ret;
429         
430         while( ret != CS_FINISHSEND ) {
431                 ret = TC_Send(cs);
432                 if ( ret == CS_ERROR ) return ret;
433         }
434
435         cs->state = CS_READ;
436         cs->ptr = cs->buf;
437         while( cs->state != CS_FINISHREAD ) {
438                 while( !TC_ReadyIO( &cs, 1, 100) );
439                 if ( TC_Read(cs) == CS_ERROR ) return CS_ERROR;
440         }
441
442         return CS_OK; 
443 }
444
445