1 /*
   2  * This file and its contents are supplied under the terms of the
   3  * Common Development and Distribution License ("CDDL"), version 1.0.
   4  * You may only use this file in accordance with the terms of version
   5  * 1.0 of the CDDL.
   6  *
   7  * A full copy of the text of the CDDL should have accompanied this
   8  * source.  A copy of the CDDL is also available via the Internet at
   9  * http://www.illumos.org/license/CDDL.
  10  */
  11 
  12 /*
  13  * Copyright 2014 Ryan Zezeski
  14  */
  15 
/*
 * Test TCP deferred accept(). Specifically, test the datafilt and httpfilt
 * socket filter modules. Deferred accept() allows the kernel to defer
 * the return of accept() on behalf of the application, to make sure
 * data is ready so that the first read() call does not block or return
 * EAGAIN. This test exercises the code path from the application level,
 * i.e. the entire stack is included.
 *
 * datafilt - This module provides generic TCP deferment. It ensures that
 * at least 1 byte of data is ready before accept() returns.
 *
 * httpfilt - This module provides HTTP/TCP deferment. It guesses whether
 * a request is valid HTTP and, if so, defers accept() until the entire
 * request, excluding the body, is ready.
 */
  31 #include <arpa/inet.h>
  32 #include <errno.h>
  33 #include <kstat.h>
  34 #include <netinet/in.h>
  35 #include <signal.h>
  36 #include <stdarg.h>
  37 #include <stdio.h>
  38 #include <stdlib.h>
  39 #include <string.h>
  40 #include <strings.h>
  41 #include <sys/socket.h>
  42 #include <sys/types.h>
  43 #include <sys/wait.h>
  44 #include <unistd.h>
  45 
  46 #define BUFSIZE                 1024
  47 #define COMMON_HEAD                                    \
  48         "Accept: */*\r\n"                              \
  49         "Host: test.com\r\n"                           \
  50         "User-Agent: test\r\n"                         \
  51         "\r\n"
  52 #define DATAFILT                "datafilt"
  53 #define HTTPFILT                "httpfilt"
  54 #define NUM_TESTS               12
  55 #define PORT                    9876
  56 
  57 /*
  58  * Values used to indicate client/server state to each other.
  59  */
  60 enum State {
  61         /* client states */
  62         CLIENT_CONNECTED,
  63         CLIENT_SENT_ALL,
  64         CLIENT_SENT_SOME,
  65 
  66         /* server states */
  67         SERVER_LISTENING,
  68         SERVER_NO_ACCEPT
  69 };
  70 
  71 /*
  72  * Each test is defined as a tdata structure in the test_data array.
  73  * Adding a new test is a matter of making a new tdata structure and
  74  * bumping NUM_TESTS.
  75  *
  76  * td_type      The filter to use: HTTPFILT | DATAFILT.
  77  *
  78  * td_name      The name of the test, written to stdout.
  79  *
  80  * td_msgs      An array of data msgs that will be sent by the client()
  81  *              process.
  82  *
  83  * td_num_msgs  The number of msgs in td_msgs.
  84  */
  85 struct tdata {
  86         char *td_type;
  87         char *td_name;
  88         char *td_msgs[2];
  89         int td_num_msgs;
  90 };
  91 
  92 static struct tdata test_data[NUM_TESTS] = {
  93         {
  94                 HTTPFILT,
  95                 "httpfilt_GET",
  96                 {"GET /test HTTP/1.1\r\n", COMMON_HEAD},
  97                 2
  98         },
  99         {
 100                 HTTPFILT,
 101                 "httpfilt_GET_LF_only",
 102                 {"GET /test HTTP/1.1\n",
 103                         "Accept: */*\n"                                \
 104                         "Host: test.com\n"                             \
 105                         "User-Agent: test\n"                           \
 106                         "\n"},
 107                 2
 108         },
 109         {
 110                 HTTPFILT,
 111                 "httpfilt_PUT",
 112                 {"PUT /test HTTP/1.1\r\n", COMMON_HEAD},
 113                 2
 114         },
 115         {
 116                 HTTPFILT,
 117                 "httpfilt_POST",
 118                 {"POST /test HTTP/1.1\r\n", COMMON_HEAD},
 119                 2
 120         },
 121         {
 122                 HTTPFILT,
 123                 "httpfilt_HEAD",
 124                 {"HEAD /test HTTP/1.1\r\n", COMMON_HEAD},
 125                 2
 126         },
 127         {
 128                 HTTPFILT,
 129                 "httpfilt_OPTIONS",
 130                 {"OPTIONS /test HTTP/1.1\r\n", COMMON_HEAD},
 131                 2
 132         },
 133         {
 134                 HTTPFILT,
 135                 "httpfilt_TRACE",
 136                 {"TRACE /test HTTP/1.1\r\n", COMMON_HEAD},
 137                 2
 138         },
 139         {
 140                 HTTPFILT,
 141                 "httpfilt_CONNECT",
 142                 {"CONNECT /test HTTP/1.1\r\n", COMMON_HEAD},
 143                 2
 144         },
 145         {
 146                 HTTPFILT,
 147                 "httpfilt_VERSION-CONTROL",
 148                 {"VERSION-CONTROL /test HTTP/1.1\r\n", COMMON_HEAD},
 149                 2
 150         },
 151         {
 152                 /*
 153                  * Verify that a starting space (non-HTTP) is not
 154                  * deferred.
 155                  */
 156                 HTTPFILT,
 157                 "httpfilt_space",
 158                 {" some data..."},
 159                 1
 160         },
 161         {
 162                 /* Verify that a bad method is not deferred. */
 163                 HTTPFILT,
 164                 "httpfilt_bad_method",
 165                 {"badmethod /test HTTP/1.1\r\n"},
 166                 1
 167         },
 168         {
 169                 DATAFILT,
 170                 "datafilt_one_byte",
 171                 {"X"},
 172                 1
 173         }
 174 };
 175 
 176 static int      debug = 0;
 177 
 178 static int      client(int fd, char **msgs, int num_msgs);
 179 static int      server(int fd, char *td_type, char **msgs, int num_msgs);
 180 
 181 static void     dbg(const char *format, ...);
 182 static int      msgscmp(char *data, char **msgs, int num_msgs);
 183 static int      run(struct tdata tdata);
 184 static int      sendstate(int fd, enum State state);
 185 static int      waitforstate(int fd, enum State state);
 186 
 187 int
 188 main(int argc, char **argv)
 189 {
 190         int c, rc = 0;
 191 
 192         while ((c = getopt(argc, argv, "d")) != -1) {
 193                 switch (c) {
 194                 case 'd':
 195                         debug = 1;
 196                         break;
 197                 default:
 198                         break;
 199                 }
 200         }
 201 
 202         for (int i = 0; i < NUM_TESTS; i++)
 203                 rc += run(test_data[i]);
 204 
 205         return (rc);
 206 }
 207 
 208 /*
 209  * This test is made up of two cooperating processes which share state
 210  * updates over a pipe. The sequence diagram below shows the actions
 211  * taken on each side of the socket as well as the state messages. The
 212  * numbers represent the order, two steps with the same number happen
 213  * concurrently. The state messages act as synchronization points; a
 214  * process will block until it receives the expected message. Notice
 215  * that a server may accept multiple times before the socket is
 216  * created, this proves deferment. A non-deferred socket will accept
 217  * on the first try.
 218  *
 219  * CLI SOCK     CLIENT                          SERVER  SERVER SOCK
 220  *                |                                |
 221  *                |                                |    (1) listen()
 222  *                |<--------SERVER_LISTENING-------+
 223  * (2) connect()  |                                |
 224  *                +---------CLIENT_CONNECTED------>|
 225  *                |                                |    (3) accept()
 226  *                |<--------SERVER_NO_ACCEPT-------+
 227  * (4.x) send()   |                                |
 228  *                +---------CLIENT_SENT_SOME------>|
 229  *                |                                |    (5.x) accept()
 230  *                |<--------SERVER_NO_ACCEPT-------+
 231  *                |                                |
 232  *                +---------CLIENT_SENT_ALL------->|
 233  * (6) close()    |                                |    (6) accept()
 234  *                |                                |    (7) recv()
 235  *                |                                |    (8) close()
 236  *                |                                |
 237  *                =                                =
 238  *
 239  */
 240 static int
 241 run(struct tdata tdata)
 242 {
 243         int fd[2], rc;
 244         pid_t child;
 245 
 246         (void) printf("TEST STARTING: %s\n", tdata.td_name);
 247 
 248         if (pipe(fd) != 0) {
 249                 perror("failed to create pipe");
 250                 return (1);
 251         }
 252 
 253         child = fork();
 254         if (child == 0) {
 255                 exit(client(fd[1], tdata.td_msgs, tdata.td_num_msgs));
 256         } else if (child > 0) {
 257                 rc = server(fd[0], tdata.td_type, tdata.td_msgs,
 258                     tdata.td_num_msgs);
 259                 if (rc != 0)
 260                         (void) printf("TEST FAILED: %s\n", tdata.td_name);
 261                 else
 262                         (void) printf("TEST PASSED: %s\n", tdata.td_name);
 263 
 264                 (void) close(fd[0]);
 265                 (void) close(fd[1]);
 266                 (void) kill(child, SIGKILL);
 267         } else {
 268                 perror("problem with fork()\n");
 269                 (void) printf("TEST FAILED: %s\n", tdata.td_name);
 270                 rc = 1;
 271         }
 272 
 273         return (rc);
 274 }
 275 
 276 static int
 277 server(int fd, char *filt, char **msgs, int num_msgs)
 278 {
 279         int                     csock, lsock, ndeferred, status;
 280         char                    buf[BUFSIZE];
 281         kstat_ctl_t             *kc;
 282         kstat_named_t           *kn;
 283         kstat_t                 *ks;
 284         struct sockaddr_in      addr, cliaddr;
 285         socklen_t               clilen;
 286 
 287         lsock = socket(PF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
 288 
 289         if (setsockopt(lsock, SOL_FILTER, FIL_ATTACH, filt, strlen(filt) + 1) <
 290             0) {
 291                 perror("couldn't set filter");
 292                 return (1);
 293         }
 294 
 295         if (lsock == -1) {
 296                 perror("socket");
 297                 return (1);
 298         }
 299 
 300         bzero(&addr, sizeof (addr));
 301         addr.sin_family = AF_INET;
 302         addr.sin_addr.s_addr = htonl(INADDR_ANY);
 303         addr.sin_port = htons(PORT);
 304         if (bind(lsock, (struct sockaddr *)&addr, sizeof (addr)) < 0) {
 305                 perror("server failed to bind");
 306                 return (1);
 307         }
 308 
 309         if (listen(lsock, 2) < 0) {
 310                 perror("listen failed");
 311                 return (1);
 312         }
 313 
 314         dbg("(1) [server] listen()\n");
 315 
 316         char filter_name[1024] = "filter_";
 317         (void) strcat(filter_name, filt);
 318         kc = kstat_open();
 319         ks = kstat_lookup(kc, "sockfs", 0, filter_name);
 320         (void) kstat_read(kc, ks, NULL);
 321         kn = kstat_data_lookup(ks, "ndeferred");
 322         ndeferred = kn->value.ui64;
 323 
 324         if (ndeferred != 0) {
 325                 (void) fprintf(stderr, "expected 0 deferred conns but got %d\n",
 326                     ndeferred);
 327                 return (1);
 328         }
 329 
 330         if (sendstate(fd, SERVER_LISTENING) != 0)
 331                 return (1);
 332 
 333         dbg("---SERVER_LISTENING----->\n");
 334 
 335         if (waitforstate(fd, CLIENT_CONNECTED) != 0)
 336                 return (1);
 337 
 338         clilen = sizeof (cliaddr);
 339         csock = accept(lsock, (struct sockaddr *)&cliaddr, &clilen);
 340         if ((csock == -1) && (errno != EAGAIN)) {
 341                 perror("problem accepting");
 342                 return (1);
 343         } else if (csock != -1) {
 344                 (void) fprintf(stderr,
 345                     "server accpeted before any data sent\n");
 346                 return (1);
 347         }
 348 
 349         dbg("(3) [server] accept()\n");
 350 
 351         (void) kstat_chain_update(kc);
 352         (void) kstat_read(kc, ks, NULL);
 353         kn = kstat_data_lookup(ks, "ndeferred");
 354         ndeferred = kn->value.ui64;
 355         if (ndeferred != 1) {
 356                 (void) fprintf(stderr, "expected 1 deferred conns but got %d\n",
 357                     ndeferred);
 358                 return (1);
 359         }
 360 
 361         if (sendstate(fd, SERVER_NO_ACCEPT) != 0)
 362                 return (1);
 363 
 364         dbg("---SERVER_NO_ACCEPT--->\n");
 365 
 366         /* The last msg is the final msg. */
 367         for (int i = 0; i < num_msgs - 1; i++) {
 368                 if (waitforstate(fd, CLIENT_SENT_SOME) != 0)
 369                         return (1);
 370 
 371                 csock = accept(lsock, (struct sockaddr *)&cliaddr, &clilen);
 372                 if ((csock == -1) && (errno != EAGAIN)) {
 373                         perror("problem accepting");
 374                         return (1);
 375                 } else if (csock != -1) {
 376                         (void) fprintf(stderr,
 377                             "server accpeted after some data sent\n");
 378                         return (1);
 379                 }
 380 
 381                 dbg("(5.%d) [server] accept()\n", i);
 382 
 383                 if (sendstate(fd, SERVER_NO_ACCEPT) != 0)
 384                         return (1);
 385 
 386                 dbg("---SERVER_NO_ACCEPT--->\n");
 387 
 388         }
 389 
 390         if (waitforstate(fd, CLIENT_SENT_ALL) != 0)
 391                 return (1);
 392 
 393         csock = accept(lsock, (struct sockaddr *)&cliaddr, &clilen);
 394         if (csock == -1) {
 395                 perror("problem accepting");
 396                 return (1);
 397         }
 398 
 399         dbg("(6) [server] accept()\n");
 400 
 401         bzero(&buf, BUFSIZE);
 402         if (recv(csock, buf, BUFSIZE, 0) == -1) {
 403                 perror("problemn receiving data");
 404                 return (1);
 405         }
 406 
 407         dbg("(7) [server] recv()\n");
 408 
 409         if (msgscmp(buf, msgs, num_msgs) != 0) {
 410                 (void) fprintf(stderr, "data modified\n");
 411                 return (1);
 412         }
 413 
 414         (void) close(csock);
 415         (void) close(lsock);
 416 
 417         dbg("(8) [server] close()\n");
 418 
 419         (void) wait(&status);
 420         return (status);
 421 }
 422 
 423 static int
 424 client(int fd, char **msgs, int num_msgs)
 425 {
 426         int                     sock;
 427         size_t                  sz;
 428         struct sockaddr_in      addr;
 429 
 430         if (waitforstate(fd, SERVER_LISTENING) != 0)
 431                 return (1);
 432 
 433         bzero(&addr, sizeof (addr));
 434         addr.sin_family = AF_INET;
 435         addr.sin_addr.s_addr = inet_addr("127.0.0.1");
 436         addr.sin_port = htons(PORT);
 437 
 438         sock = socket(AF_INET, SOCK_STREAM, 0);
 439         if (sock < 0) {
 440                 perror("cannot create client socket");
 441                 return (1);
 442         }
 443 
 444         if (connect(sock, (struct sockaddr *)&addr, sizeof (addr)) < 0) {
 445                 (void) printf("failed to connect: %d\n", errno);
 446                 perror("cannot connect to server");
 447                 return (1);
 448         }
 449 
 450         dbg("(2) [client] connect()\n");
 451 
 452         if (sendstate(fd, CLIENT_CONNECTED) != 0)
 453                 return (1);
 454 
 455         dbg("<---CLIENT_CONNECTED-----\n");
 456 
 457         if (waitforstate(fd, SERVER_NO_ACCEPT) != 0)
 458                 return (1);
 459 
 460         for (int i = 0; i < num_msgs; i++) {
 461                 size_t len = strlen(msgs[i]);
 462                 sz = send(sock, msgs[i], len, 0);
 463                 if (sz != len) {
 464                         (void) fprintf(stderr,
 465                             "client sent %zu bytes, but should have sent %zu\n",
 466                             sz, len);
 467                         return (1);
 468                 }
 469 
 470                 dbg("(4.%d) [client] send()\n", i);
 471 
 472                 if (i == num_msgs - 1)
 473                         break;
 474 
 475                 if (sendstate(fd, CLIENT_SENT_SOME) != 0)
 476                         return (1);
 477 
 478                 dbg("<---CLIENT_SENT_SOME-----\n");
 479 
 480                 if (waitforstate(fd, SERVER_NO_ACCEPT) != 0)
 481                         return (1);
 482         }
 483 
 484         if (sendstate(fd, CLIENT_SENT_ALL) != 0)
 485                 return (1);
 486 
 487         dbg("<---CLIENT_SENT_ALL-----\n");
 488         (void) close(sock);
 489         dbg("(6) [client] close()\n");
 490 
 491         return (0);
 492 }
 493 
 494 /* Send a debug msg to stdout. */
 495 static void
 496 dbg(const char *format, ...)
 497 {
 498         va_list args;
 499 
 500         if (!debug) {
 501                 return;
 502         }
 503 
 504         va_start(args, format);
 505         (void) vprintf(format, args);
 506         va_end(args);
 507         (void) fflush(stdout);
 508 }
 509 
 510 /*
 511  * Test for byte-for-byte equality between data and msgs. Return 0 if
 512  * equal, 1 otherwise.
 513  */
 514 static int
 515 msgscmp(char *data, char **msgs, int num_msgs)
 516 {
 517         for (int i = 0, offset = 0; i < num_msgs; i++) {
 518                 int len = strlen(msgs[i]);
 519                 if (memcmp(&data[offset], msgs[i], len) != 0)
 520                         return (1);
 521 
 522                 offset += len;
 523         }
 524 
 525         return (0);
 526 }
 527 
 528 /* Send a state across the pipe. */
 529 static int sendstate(int fd, enum State state)
 530 {
 531         size_t sz;
 532 
 533         sz = write(fd, &state, sizeof (int));
 534         return ((sz == sizeof (int)) ? 0 : 1);
 535 }
 536 
 537 /* Block on the given state. */
 538 static int waitforstate(int fd, enum State state)
 539 {
 540         int i;
 541         size_t sz;
 542 
 543         sz = read(fd, &i, sizeof (int));
 544         if (sz != sizeof (int)) {
 545                 (void) fprintf(stderr, "couldn't read int\n");
 546                 return (1);
 547         }
 548 
 549         if (i != state) {
 550                 (void) fprintf(stderr,
 551                     "saw state %d but expected %d\n", i, state);
 552                 return (1);
 553         }
 554 
 555         return (0);
 556 }