]> granicus.if.org Git - ejabberd/commitdiff
* src/tls/tls_drv.c: Don't create a SSL context on every
authorAlexey Shchepin <alexey@process-one.net>
Mon, 10 Nov 2008 14:32:51 +0000 (14:32 +0000)
committerAlexey Shchepin <alexey@process-one.net>
Mon, 10 Nov 2008 14:32:51 +0000 (14:32 +0000)
connection and disable SSLv2 on outgoing connections (EJAB-781)

SVN Revision: 1675

ChangeLog
src/tls/tls_drv.c

index fe8a7e3f98709cb25315867079cb69b3df6da12e..d43576dcc31e0cc84ec837eea8b1f9f87319103b 100644 (file)
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,15 +1,3 @@
-2008-11-08  Mickael Remond  <mremond@process-one.net>
-
-       * src/ejabberd_s2s_out.erl: exports the DNS resolution
-       function.
-
-2008-11-06  Badlop  <badlop@process-one.net>
-
-       * src/extauth.erl: When the extauth call fails or timeouts, deny
-       authorization. Use two timeouts: 60s for script initialization and
-       10s for regular calls. (thanks to Kevin Crosbie from
-       Ravenpack) (EJAB-627)
-
 2008-11-03  Alexey Shchepin  <alexey@process-one.net>
 
        * src/ejabberd_c2s.erl: Disable zlib when STARTTLS is required
index 3efe72cf78cf9a21c8b7e154464f95f208441fda..b90cab87c6d6cc66012242461fb0b8153302eb64 100644 (file)
 #include <erl_driver.h>
 #include <openssl/ssl.h>
 #include <openssl/err.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
+#include <stdint.h>
 
 
 #define BUF_SIZE 1024
 
 typedef struct {
       ErlDrvPort port;
-      SSL_CTX *ctx;
       BIO *bio_read;
       BIO *bio_write;
       SSL *ssl;
 } tls_data;
 
+#ifdef _WIN32
+typedef unsigned __int32 uint32_t;
+#endif
+
+/*
+ * str_hash is based on the public domain code from
+ * http://www.burtleburtle.net/bob/hash/doobs.html
+ */
+static uint32_t str_hash(char *s)
+{
+   unsigned char *key = (unsigned char *)s;
+   uint32_t hash = 0;
+   size_t i;
+
+   for (i = 0; key[i] != 0; i++) {
+      hash += key[i];
+      hash += (hash << 10);
+      hash ^= (hash >> 6);
+   }
+   hash += (hash << 3);
+   hash ^= (hash >> 11);
+   hash += (hash << 15);
+   return hash;
+}
+
+/* Linear hashing */
+
+#define MIN_LEVEL 8
+#define MAX_LEVEL 20
+
+struct bucket {
+      uint32_t hash;
+      char *key_file;
+      time_t mtime;
+      SSL_CTX *ssl_ctx;
+      struct bucket *next;
+};
+
+struct hash_table {
+      int split;
+      int level;
+      struct bucket **buckets;
+      int size;
+};
+
+struct hash_table ht;
+
+static void init_hash_table()
+{
+   size_t size = 1 << (MIN_LEVEL + 1);
+   size_t i;
+   ht.buckets = (struct bucket **)driver_alloc(sizeof(struct bucket *) * size);
+   ht.split = 0;
+   ht.level = MIN_LEVEL;
+   for (i = 0; i < size; i++)
+      ht.buckets[i] = NULL;
+   
+}
+
+static void hash_table_insert(char *key_file, time_t mtime,
+                            SSL_CTX *ssl_ctx)
+{
+   int level, split;
+   uint32_t hash = str_hash(key_file);
+   size_t bucket;
+   int do_split = 0;
+   struct bucket *el;
+   struct bucket *new_bucket_el;
+
+   split = ht.split;
+   level = ht.level;
+
+   bucket = hash & ((1 << level) - 1);
+   if (bucket < split)
+      bucket = hash & ((1 << (level + 1)) - 1);
+
+   el = ht.buckets[bucket];
+   while (el != NULL) {
+      if (el->hash == hash && strcmp(el->key_file, key_file) == 0) {
+        el->mtime = mtime;
+        if (el->ssl_ctx != NULL)
+           SSL_CTX_free(el->ssl_ctx);
+        el->ssl_ctx = ssl_ctx;
+        break;
+      }
+      el = el->next;
+   }
+
+   if (el == NULL) {
+      if (ht.buckets[bucket] != NULL)
+        do_split = !0;
+
+      new_bucket_el = (struct bucket *)driver_alloc(sizeof(struct bucket));
+      new_bucket_el->hash = hash;
+      new_bucket_el->key_file = (char *)driver_alloc(strlen(key_file) + 1);
+      strcpy(new_bucket_el->key_file, key_file);
+      new_bucket_el->mtime = mtime;
+      new_bucket_el->ssl_ctx = ssl_ctx;
+      new_bucket_el->next = ht.buckets[bucket];
+      ht.buckets[bucket] = new_bucket_el;
+   }
+
+   if (do_split) {
+      struct bucket **el_ptr = &ht.buckets[split];
+      size_t new_bucket = split + (1 << level);
+      while (*el_ptr != NULL) {
+        uint32_t hash = (*el_ptr)->hash;
+        if ((hash & ((1 << (level + 1)) - 1)) == new_bucket) {
+           struct bucket *moved_el = *el_ptr;
+           *el_ptr = (*el_ptr)->next;
+           moved_el->next = ht.buckets[new_bucket];
+           ht.buckets[new_bucket] = moved_el;
+        } else
+           el_ptr = &(*el_ptr)->next;
+      }
+      split++;
+      if (split == 1 << level) {
+        size_t size;
+        size_t i;
+        split = 0;
+        level++;
+        size = 1 << (level + 1);
+        ht.split = split;
+        ht.level = level;
+        ht.buckets = (struct bucket **)
+           driver_realloc(ht.buckets, sizeof(struct bucket *) * size);
+        for (i = 1 << level; i < size; i++)
+           ht.buckets[i] = NULL;
+      } else
+        ht.split = split;
+   }
+}
+
+static SSL_CTX *hash_table_lookup(char *key_file, time_t *pmtime)
+{
+   int level, split;
+   uint32_t hash = str_hash(key_file);
+   size_t bucket;
+   struct bucket *el;
+
+   split = ht.split;
+   level = ht.level;
+
+   bucket = hash & ((1 << level) - 1);
+   if (bucket < split)
+      bucket = hash & ((1 << (level + 1)) - 1);
+
+   el = ht.buckets[bucket];
+   while (el != NULL) {
+      if (el->hash == hash && strcmp(el->key_file, key_file) == 0) {
+        *pmtime = el->mtime;
+        return el->ssl_ctx;
+      }
+      el = el->next;
+   }
+
+   return NULL;
+}
+
 
 static ErlDrvData tls_drv_start(ErlDrvPort port, char *buff)
 {
    tls_data *d = (tls_data *)driver_alloc(sizeof(tls_data));
    d->port = port;
-   d->ctx = NULL;
    d->bio_read = NULL;
    d->bio_write = NULL;
    d->ssl = NULL;
@@ -57,12 +218,46 @@ static void tls_drv_stop(ErlDrvData handle)
    if (d->ssl != NULL)
       SSL_free(d->ssl);
 
-   if (d->ctx != NULL)
-      SSL_CTX_free(d->ctx);
-
    driver_free((char *)handle);
 }
 
+static void tls_drv_finish()
+{
+   int level;
+   struct bucket *el;
+   int i;
+
+   level = ht.level;
+   for (i = 0; i < 1 << (level + 1); i++) {
+      el = ht.buckets[i];
+      while (el != NULL) {
+        if (el->ssl_ctx != NULL)
+           SSL_CTX_free(el->ssl_ctx);
+        driver_free(el->key_file);
+        el = el->next;
+      }
+   }
+
+   driver_free(ht.buckets);
+}
+
+static int is_key_file_modified(char *file, time_t *key_file_mtime)
+{
+   struct stat file_stat;
+
+   if (stat(file, &file_stat))
+   {
+      *key_file_mtime = 0;
+      return 1;
+   } else {
+      if (*key_file_mtime != file_stat.st_mtime)
+      {
+        *key_file_mtime = file_stat.st_mtime;
+        return 1;
+      } else
+        return 0;
+   }
+}
 
 static int verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
 {
@@ -122,29 +317,41 @@ static int tls_drv_control(ErlDrvData handle,
    switch (command)
    {
       case SET_CERTIFICATE_FILE_ACCEPT:
-      case SET_CERTIFICATE_FILE_CONNECT:
-        d->ctx = SSL_CTX_new(SSLv23_method());
-        die_unless(d->ctx, "SSL_CTX_new failed");
+      case SET_CERTIFICATE_FILE_CONNECT: {
+        time_t mtime = 0;
+        SSL_CTX *ssl_ctx = hash_table_lookup(buf, &mtime);
+        if (is_key_file_modified(buf, &mtime) || ssl_ctx == NULL)
+        {
+           SSL_CTX *ctx;
 
-        res = SSL_CTX_use_certificate_chain_file(d->ctx, buf);
-        die_unless(res > 0, "SSL_CTX_use_certificate_file failed");
+           hash_table_insert(buf, mtime, NULL);
 
-        res = SSL_CTX_use_PrivateKey_file(d->ctx, buf, SSL_FILETYPE_PEM);
-        die_unless(res > 0, "SSL_CTX_use_PrivateKey_file failed");
+           ctx = SSL_CTX_new(SSLv23_method());
+           die_unless(ctx, "SSL_CTX_new failed");
 
-        res = SSL_CTX_check_private_key(d->ctx);
-        die_unless(res > 0, "SSL_CTX_check_private_key failed");
+           res = SSL_CTX_use_certificate_chain_file(ctx, buf);
+           die_unless(res > 0, "SSL_CTX_use_certificate_file failed");
 
-        SSL_CTX_set_default_verify_paths(d->ctx);
+           res = SSL_CTX_use_PrivateKey_file(ctx, buf, SSL_FILETYPE_PEM);
+           die_unless(res > 0, "SSL_CTX_use_PrivateKey_file failed");
 
-        if (command == SET_CERTIFICATE_FILE_ACCEPT)
-        {
-           SSL_CTX_set_verify(d->ctx,
-                              SSL_VERIFY_PEER|SSL_VERIFY_CLIENT_ONCE,
-                              verify_callback);
+           res = SSL_CTX_check_private_key(ctx);
+           die_unless(res > 0, "SSL_CTX_check_private_key failed");
+
+           SSL_CTX_set_default_verify_paths(ctx);
+
+           if (command == SET_CERTIFICATE_FILE_ACCEPT)
+           {
+              SSL_CTX_set_verify(ctx,
+                                 SSL_VERIFY_PEER|SSL_VERIFY_CLIENT_ONCE,
+                                 verify_callback);
+           }
+
+           ssl_ctx = ctx;
+           hash_table_insert(buf, mtime, ssl_ctx);
         }
-        
-        d->ssl = SSL_new(d->ctx);
+
+        d->ssl = SSL_new(ssl_ctx);
         die_unless(d->ssl, "SSL_new failed");
 
         d->bio_read = BIO_new(BIO_s_mem());
@@ -154,9 +361,12 @@ static int tls_drv_control(ErlDrvData handle,
 
         if (command == SET_CERTIFICATE_FILE_ACCEPT)
            SSL_set_accept_state(d->ssl);
-        else
+        else {
+           SSL_set_options(d->ssl, SSL_OP_NO_SSLv2);
            SSL_set_connect_state(d->ssl);
+        }
         break;
+      }
       case SET_ENCRYPTED_INPUT:
         die_unless(d->ssl, "SSL not initialized");
         BIO_write(d->bio_read, buf, len);
@@ -282,7 +492,7 @@ ErlDrvEntry tls_driver_entry = {
    NULL,                       /* F_PTR ready_input, called when input descriptor ready */
    NULL,                       /* F_PTR ready_output, called when output descriptor ready */
    "tls_drv",                  /* char *driver_name, the argument to open_port */
-   NULL,                       /* F_PTR finish, called when unloaded */
+   tls_drv_finish,             /* F_PTR finish, called when unloaded */
    NULL,                       /* handle */
    tls_drv_control,            /* F_PTR control, port_command callback */
    NULL,                       /* F_PTR timeout, reserved */
@@ -293,6 +503,7 @@ DRIVER_INIT(tls_drv) /* must match name in driver_entry */
 {
    OpenSSL_add_ssl_algorithms();
    SSL_load_error_strings();
+   init_hash_table();
    return &tls_driver_entry;
 }