Whamcloud - gitweb
LU-14479 ssk: explicitly set perm on key
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
old mode 100644 (file)
new mode 100755 (executable)
index 2dc453e..fb37758
@@ -22,6 +22,8 @@
 /*
  * Copyright (C) 2015, Trustees of Indiana University
  *
+ * Copyright (c) 2016, 2017, Intel Corporation.
+ *
  * Author: Jeremy Filizetti <jfilizet@iu.edu>
  */
 
 #include <openssl/hmac.h>
 #include <sys/types.h>
 #include <sys/stat.h>
-#include <lnet/nidstr.h>
+#include <libcfs/util/string.h>
+#include <sys/time.h>
 
 #include "sk_utils.h"
 #include "write_bytes.h"
 
-static struct sk_crypt_type sk_crypt_types[] = {
-       [SK_CRYPT_AES256_CTR] = {
-               .sct_name = "ctr(aes)",
-               .sct_bytes = 32,
-       },
-};
-
-static struct sk_hmac_type sk_hmac_types[] = {
-       [SK_HMAC_SHA256] = {
-               .sht_name = "hmac(sha256)",
-               .sht_bytes = 32,
-       },
-       [SK_HMAC_SHA512] = {
-               .sht_name = "hmac(sha512)",
-               .sht_bytes = 64,
-       },
-};
+#define SK_PBKDF2_ITERATIONS 10000
 
 #ifdef _NEW_BUILD_
 # include "lgss_utils.h"
@@ -107,14 +94,22 @@ struct sk_keyfile_config *sk_read_file(char *filename)
 
        /* allow standard input override */
        if (strcmp(filename, "-") == 0)
-               fd = dup(STDIN_FILENO);
+               fd = STDIN_FILENO;
        else
                fd = open(filename, O_RDONLY);
 
        if (fd == -1) {
-               printerr(0, "Error opening file %s: %s\n", filename,
+               printerr(0, "Error opening key file '%s': %s\n", filename,
                         strerror(errno));
                goto out_free;
+       } else if (fd != STDIN_FILENO) {
+               struct stat st;
+
+               rc = fstat(fd, &st);
+               if (rc == 0 && (st.st_mode & ~(S_IFREG | 0600)))
+                       fprintf(stderr, "warning: "
+                               "secret key '%s' has insecure file mode %#o\n",
+                               filename, st.st_mode);
        }
 
        ptr = (char *)config;
@@ -136,7 +131,8 @@ struct sk_keyfile_config *sk_read_file(char *filename)
                remain -= rc;
        }
 
-       close(fd);
+       if (fd != STDIN_FILENO)
+               close(fd);
        sk_config_disk_to_cpu(config);
        return config;
 
@@ -178,11 +174,18 @@ static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
 
        key = add_key("user", description, &payload, sizeof(payload),
                      KEY_SPEC_USER_KEYRING);
-       if (key != -1)
+       if (key != -1) {
+               key_perm_t perm = KEY_POS_ALL | KEY_USR_ALL |
+                       KEY_GRP_ALL | KEY_OTH_ALL;
+
+               if (keyctl_setperm(key, perm) < 0)
+                       printerr(2, "Failed to set perm 0x%x on key %d\n",
+                                perm, key);
                printerr(2, "Added key %d with description %s\n", key,
                         description);
-       else
+       } else {
                printerr(0, "Failed to add key with %s\n", description);
+       }
 
        return key;
 }
@@ -198,7 +201,7 @@ static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
  * \return     0       sucess
  * \return     -1      failure
  */
-int sk_load_keyfile(char *path, int type)
+int sk_load_keyfile(char *path)
 {
        struct sk_keyfile_config *config;
        char description[SK_DESCRIPTION_SIZE + 1];
@@ -231,7 +234,7 @@ int sk_load_keyfile(char *path, int type)
        /* The server side can have multiple key files per file system so
         * the nodemap name is appended to the key description to uniquely
         * identify it */
-       if (type & SK_TYPE_MGS) {
+       if (config->skc_type & SK_TYPE_MGS) {
                /* Any key can be an MGS key as long as we are told to use it */
                rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
                              config->skc_nodemap);
@@ -240,7 +243,7 @@ int sk_load_keyfile(char *path, int type)
                if (sk_load_key(config, description) == -1)
                        goto out;
        }
-       if (type & SK_TYPE_SERVER) {
+       if (config->skc_type & SK_TYPE_SERVER) {
                /* Server keys need to have the file system name in the key */
                if (!config->skc_fsname) {
                        printerr(0, "Key configuration has no file system "
@@ -254,7 +257,7 @@ int sk_load_keyfile(char *path, int type)
                if (sk_load_key(config, description) == -1)
                        goto out;
        }
-       if (type & SK_TYPE_CLIENT) {
+       if (config->skc_type & SK_TYPE_CLIENT) {
                /* Load client file system key */
                if (config->skc_fsname) {
                        rc = snprintf(description, SK_DESCRIPTION_SIZE,
@@ -303,12 +306,10 @@ void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
        config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
        config->skc_expire = htobe32(config->skc_expire);
        config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
-       config->skc_session_keylen = htobe32(config->skc_session_keylen);
+       config->skc_prime_bits = htobe32(config->skc_prime_bits);
 
        for (i = 0; i < MAX_MGSNIDS; i++)
                config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
-
-       return;
 }
 
 /**
@@ -328,12 +329,10 @@ void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
        config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
        config->skc_expire = be32toh(config->skc_expire);
        config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
-       config->skc_session_keylen = be32toh(config->skc_session_keylen);
+       config->skc_prime_bits = be32toh(config->skc_prime_bits);
 
        for (i = 0; i < MAX_MGSNIDS; i++)
                config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
-
-       return;
 }
 
 /**
@@ -352,18 +351,22 @@ int sk_validate_config(const struct sk_keyfile_config *config)
                printerr(0, "Null configuration passed\n");
                return -1;
        }
+
        if (config->skc_version != SK_CONF_VERSION) {
                printerr(0, "Invalid version\n");
                return -1;
        }
-       if (config->skc_hmac_alg >= SK_HMAC_MAX) {
+
+       if (config->skc_hmac_alg == SK_HMAC_INVALID) {
                printerr(0, "Invalid HMAC algorithm\n");
                return -1;
        }
-       if (config->skc_crypt_alg >= SK_CRYPT_MAX) {
+
+       if (config->skc_crypt_alg == SK_CRYPT_INVALID) {
                printerr(0, "Invalid crypt algorithm\n");
                return -1;
        }
+
        if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
                /* Try to limit key expiration to some reasonable minimum and
                 * also prevent values over INT_MAX because there appears
@@ -372,11 +375,11 @@ int sk_validate_config(const struct sk_keyfile_config *config)
                         "and %d\n", 60, INT_MAX);
                return -1;
        }
-       if (config->skc_session_keylen % 8 != 0 ||
-           config->skc_session_keylen > SK_SESSION_MAX_KEYLEN_BYTES * 8) {
+       if (config->skc_prime_bits % 8 != 0 ||
+           config->skc_prime_bits > SK_MAX_P_BYTES * 8) {
                printerr(0, "Invalid session key length must be a multiple of 8"
                         " and less then %d bits\n",
-                        SK_SESSION_MAX_KEYLEN_BYTES * 8);
+                        SK_MAX_P_BYTES * 8);
                return -1;
        }
        if (config->skc_shared_keylen % 8 != 0 ||
@@ -404,6 +407,11 @@ int sk_validate_config(const struct sk_keyfile_config *config)
                return -1;
        }
 
+       if (config->skc_type == SK_TYPE_INVALID) {
+               printerr(0, "Invalid key type\n");
+               return -1;
+       }
+
        return 0;
 }
 
@@ -643,12 +651,11 @@ struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
 
        kctx = &skc->sc_kctx;
        kctx->skc_version = config->skc_version;
-       kctx->skc_hmac_alg = config->skc_hmac_alg;
-       kctx->skc_crypt_alg = config->skc_crypt_alg;
+       strcpy(kctx->skc_hmac_alg, sk_hmac2name(config->skc_hmac_alg));
+       strcpy(kctx->skc_crypt_alg, sk_crypt2name(config->skc_crypt_alg));
        kctx->skc_expire = config->skc_expire;
 
        /* key payload format is in bits, convert to bytes */
-       skc->sc_session_keylen = config->skc_session_keylen / 8;
        kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
        kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
        if (!kctx->skc_shared_key.value) {
@@ -658,214 +665,312 @@ struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
        memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
               kctx->skc_shared_key.length);
 
+       skc->sc_p.length = config->skc_prime_bits / 8;
+       skc->sc_p.value = malloc(skc->sc_p.length);
+       if (!skc->sc_p.value) {
+               printerr(0, "Failed to allocate p\n");
+               goto out_err;
+       }
+       memcpy(skc->sc_p.value, config->skc_p, skc->sc_p.length);
+
        free(config);
 
        return skc;
 
 out_err:
-       if (skc)
-               sk_free_cred(skc);
+       sk_free_cred(skc);
 
        free(config);
        return NULL;
 }
 
-/**
- * Generates a public key and computes the private key for the DH key exchange.
- * The parameters must be populated with the p and g from the peer.
+#define SK_GENERATOR 2
+#define DH_NUMBER_ITERATIONS_FOR_PRIME 64
+
+/* OpenSSL 1.1.1c increased the number of rounds used for Miller-Rabin testing
+ * of the prime provided as input parameter to DH_check(). This makes the check
+ * roughly x10 longer, and causes request timeouts when an SSK flavor is being
+ * used.
  *
- * \param[in,out]      skc     Shared key credentials structure to populate
- *                             with DH parameters
+ * Instead, use a dynamic number Miller-Rabin rounds based on the speed of the
+ * check on the current system, evaluated when the lsvcgssd daemon starts, but
+ * at least as many as OpenSSL 1.1.1b used for the same key size. If default
+ * DH_check() duration is OK, use it directly instead of limiting the rounds.
  *
- * \retval     GSS_S_COMPLETE  success
- * \retval     GSS_S_FAILURE   failure
+ * If \a num_rounds == 0, we just call original DH_check() directly.
  */
-static uint32_t sk_gen_responder_params(struct sk_cred *skc)
+static bool sk_is_dh_valid(const DH *dh, int num_rounds)
 {
+       const BIGNUM *p, *g;
+       BN_ULONG word;
+       BN_CTX *ctx;
+       BIGNUM *r;
+       bool valid = false;
        int rc;
 
-       /* No keys to generate without privacy mode */
-       if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
-               return GSS_S_COMPLETE;
+       if (num_rounds == 0) {
+               int codes = 0;
 
-       skc->sc_params = DH_new();
-       if (!skc->sc_params) {
-               printerr(0, "Failed to allocate DH\n");
-               return GSS_S_FAILURE;
+               rc = DH_check(dh, &codes);
+               if (rc != 1 || codes) {
+                       printerr(0, "DH_check(0) failed: codes=%#x: rc=%d\n",
+                                codes, rc);
+                       return false;
+               }
+               return true;
        }
 
-       /* responder should already have sc_p populated */
-       skc->sc_params->p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
-       if (!skc->sc_params->p) {
-               printerr(0, "Failed to convert binary to BIGNUM\n");
-               return GSS_S_FAILURE;
-       }
+       DH_get0_pqg(dh, &p, NULL, &g);
 
-       /* and we use a static generator for shared key */
-       skc->sc_params->g = BN_new();
-       if (!skc->sc_params->g) {
-               printerr(0, "Failed to allocate new BIGNUM\n");
-               return GSS_S_FAILURE;
-       }
-       if (BN_set_word(skc->sc_params->g, SK_GENERATOR) != 1) {
-               printerr(0, "Failed to set g value for DH params\n");
-               return GSS_S_FAILURE;
+       if (!BN_is_word(g, SK_GENERATOR)) {
+               printerr(0, "%s: Diffie-Hellman generator is not %u\n",
+                        program_invocation_short_name, SK_GENERATOR);
+               return false;
        }
 
-       /* verify that we have a safe prime and valid generator */
-       if (DH_check(skc->sc_params, &rc) != 1) {
-               printerr(0, "DH_check() failed: %d\n", rc);
-               return GSS_S_FAILURE;
-       } else if (rc) {
-               printerr(0, "DH_check() returned error codes: 0x%x\n", rc);
-               return GSS_S_FAILURE;
+       word = BN_mod_word(p, 24);
+       if (word != 11) {
+               printerr(0, "%s: Diffie-Hellman prime modulo=%lu unsuitable\n",
+                        program_invocation_short_name, word);
+               return false;
        }
 
-       if (DH_generate_key(skc->sc_params) != 1) {
-               printerr(0, "Failed to generate public DH key: %s\n",
-                        ERR_error_string(ERR_get_error(), NULL));
-               return GSS_S_FAILURE;
+       ctx = BN_CTX_new();
+       if (ctx == NULL) {
+               printerr(0, "%s: Diffie-Hellman error allocating context\n",
+                        program_invocation_short_name);
+               return false;
        }
+       BN_CTX_start(ctx);
+       r = BN_CTX_get(ctx); /* must be called before "ctx" used elsewhere */
 
-       skc->sc_pub_key.length = BN_num_bytes(skc->sc_params->pub_key);
-       skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
-       if (!skc->sc_pub_key.value) {
-               printerr(0, "Failed to allocate memory for public key\n");
-               return GSS_S_FAILURE;
+       rc = BN_is_prime_ex(p, num_rounds, ctx, NULL);
+       if (rc == 0)
+               printerr(0, "%s: Diffie-Hellman 'p' not prime in %u rounds\n",
+                        program_invocation_short_name, num_rounds);
+       if (rc <= 0)
+               goto out_free;
+
+       if (!BN_rshift1(r, p)) {
+               printerr(0, "%s: error shifting BigNum 'r' by 'p'\n",
+                        program_invocation_short_name);
+               goto out_free;
        }
+       rc = BN_is_prime_ex(r, num_rounds, ctx, NULL);
+       if (rc == 0)
+               printerr(0, "%s: Diffie-Hellman 'r' not prime in %u rounds\n",
+                        program_invocation_short_name, num_rounds);
+       if (rc <= 0)
+               goto out_free;
 
-       BN_bn2bin(skc->sc_params->pub_key, skc->sc_pub_key.value);
+       valid = true;
 
-       return GSS_S_COMPLETE;
+out_free:
+       BN_CTX_end(ctx);
+       BN_CTX_free(ctx);
+
+       return valid;
 }
 
-static void sk_free_parameters(struct sk_cred *skc)
+#define VALUE_LENGTH 256
+static unsigned char test_prime[VALUE_LENGTH] =
+       "\xf7\xfa\x49\xd8\xec\xb1\x3b\xff\x26\x10\x3f\xc5\x3a\xc5\xcc\x40"
+       "\x4f\xbf\x92\xe1\x8b\x83\xe7\xa2\xba\x0f\x51\x5a\x91\x48\xe0\xa3"
+       "\xf1\x4d\xbc\xbb\x8a\x28\x14\xac\x02\x23\x76\x42\x17\x4d\x3c\xdc"
+       "\x5e\x4f\x80\x1f\xd7\x54\x1c\x50\xac\x3b\x28\x68\x8d\x71\x41\x7f"
+       "\xa7\x1c\x2f\x22\xd3\xa8\x91\xb2\x64\xb6\x84\xa6\xcf\x06\x16\x91"
+       "\x2f\xb8\xb4\x42\x1d\x3a\x4e\x3a\x0c\x7f\x04\x69\x78\xb5\x8f\x92"
+       "\x07\x89\xac\x24\x06\x53\x2c\x23\xec\xaa\x5c\xb4\x7b\x49\xbc\xf4"
+       "\x90\x67\x71\x9c\x24\x2c\x1d\x8d\x76\xc8\x85\x4e\x19\xf1\xf9\x33"
+       "\x45\xbd\x9f\x7d\x0a\x08\x8c\x22\xcc\x35\xf3\x5b\xab\x3f\x24\x9d"
+       "\x61\x70\x86\xbb\xbe\xd8\xb0\xf8\x34\xfa\xeb\x5b\x8e\xf2\x62\x23"
+       "\xd1\xfb\xbb\xb8\x21\x71\x1e\x39\x39\x59\xe0\x82\x98\x41\x84\x40"
+       "\x1f\xd3\x9b\xa3\x73\xdb\xec\xe0\xc0\xde\x2d\x1c\xea\x43\x40\x93"
+       "\x98\x38\x03\x36\x1e\xe1\xe7\x39\x7b\x35\x92\x4a\x51\xa5\x91\x63"
+       "\xd5\x31\x98\x3d\x89\x27\x6f\xcc\x69\xff\xbe\x31\x13\xdc\x2f\x72"
+       "\x2d\xab\x6a\xb7\x13\xd3\x47\xda\xaa\xf3\x3c\xa0\xfd\xaa\x0f\x02"
+       "\x96\x81\x1a\x26\xe8\xf7\x25\x65\x33\x78\xd9\x6b\x6d\xb0\xd9\xfb";
+
+/**
+ * Measure time taken by prime testing routine for a 2048 bit long prime,
+ * depending on the number of check rounds.
+ *
+ * \param[in]  usec_check_max    max time allowed for DH_check completion
+ *
+ * \retval     max number of rounds to keep prime testing under usec_check_max
+ *             return 0 if we should use the default DH_check rounds
+ */
+int sk_speedtest_dh_valid(unsigned int usec_check_max)
 {
-       if (skc->sc_params)
-               DH_free(skc->sc_params);
-       if (skc->sc_p.value)
-               free(skc->sc_p.value);
-       if (skc->sc_pub_key.value)
-               free(skc->sc_pub_key.value);
+       DH *dh;
+       BIGNUM *p, *g;
+       int num_rounds, prev_rounds = 0;
+
+       dh = DH_new();
+       if (!dh)
+               return 0;
+
+       p = BN_bin2bn(test_prime, VALUE_LENGTH, NULL);
+       if (!p)
+               goto free_dh;
+
+       g = BN_new();
+       if (!g)
+               goto free_p;
+
+       if (!BN_set_word(g, SK_GENERATOR))
+               goto free_g;
+
+       /* "dh" takes over freeing of 'p' and 'g' if this succeeds */
+       if (!DH_set0_pqg(dh, p, NULL, g)) {
+       free_g:
+               BN_free(g);
+       free_p:
+               BN_free(p);
+               goto free_dh;
+       }
+
+       for (num_rounds = 0;
+            num_rounds <= DH_NUMBER_ITERATIONS_FOR_PRIME;
+            num_rounds += (num_rounds <= 4 ? 4 : 8)) {
+               unsigned int usec_this;
+               int j;
+
+               /* get max duration of 4 runs at current number of rounds */
+               usec_this = 0;
+               for (j = 0; j < 4; j++) {
+                       struct timeval now, prev;
+                       unsigned int usec_curr;
+
+                       gettimeofday(&prev, NULL);
+                       if (!sk_is_dh_valid(dh, num_rounds)) {
+                               /* if test_prime is found bad, use default */
+                               prev_rounds = 0;
+                               goto free_dh;
+                       }
+                       gettimeofday(&now, NULL);
+                       usec_curr = (now.tv_sec - prev.tv_sec) * 1000000 +
+                                   now.tv_usec - prev.tv_usec;
+                       if (usec_curr > usec_this)
+                               usec_this = usec_curr;
+               }
+               printerr(2, "%s: %d rounds: %d usec\n",
+                        program_invocation_short_name, num_rounds, usec_this);
+               if (num_rounds == 0) {
+                       if (usec_this <= usec_check_max)
+                       /* using original check rounds as implemented in
+                        * DH_check() took less time than the max allowed,
+                        * so just use original DH_check()
+                        */
+                               break;
+               } else if (usec_this >= usec_check_max) {
+                       break;
+               }
+               prev_rounds = num_rounds;
+       }
+
+free_dh:
+       DH_free(dh);
 
-       skc->sc_p.value = NULL;
-       skc->sc_p.length = 0;
-       skc->sc_pub_key.value = NULL;
-       skc->sc_pub_key.length = 0;
+       return prev_rounds;
 }
 
 /**
- * Generates shared key Diffie Hellman parameters used for the DH key exchange
- * between host and peer if privacy mode is enabled
+ * Populates the DH parameters for the DHKE
  *
- * \param[in,out]      skc     Shared key credentials structure to populate
- *                             with DH parameters
+ * \param[in,out]      skc             Shared key credentials structure to
+ *                                     populate with DH parameters
  *
  * \retval     GSS_S_COMPLETE  success
  * \retval     GSS_S_FAILURE   failure
  */
-static uint32_t sk_gen_initiator_params(struct sk_cred *skc)
+uint32_t sk_gen_params(struct sk_cred *skc, int num_rounds)
 {
-       gss_buffer_desc *iv = &skc->sc_kctx.skc_iv;
-       int rc;
-
-       /* The credential could be used so free existing parameters */
-       sk_free_parameters(skc);
-
-       /* Pseudo random should be sufficient here because the IV will be used
-        * with a key that is used only once.  This also should ensure we have
-        * unqiue tokens that are sent to the remote server which is important
-        * because the token is hashed for the sunrpc cache lookups and a
-        * failure there would cause connection attempts to fail indefinitely
-        * due to the large timeout value on the server side sunrpc cache
-        * (INT_MAX) */
-       iv->length = SK_IV_SIZE;
-       iv->value = malloc(iv->length);
-       if (!iv->value) {
-               printerr(0, "Failed to allocate memory for IV\n");
+       uint32_t random;
+       BIGNUM *p, *g;
+       const BIGNUM *pub_key;
+
+       /* Random value used by both the request and response as part of the
+        * key binding material.  This also should ensure we have unqiue
+        * tokens that are sent to the remote server which is important because
+        * the token is hashed for the sunrpc cache lookups and a failure there
+        * would cause connection attempts to fail indefinitely due to the large
+        * timeout value on the server side */
+       if (RAND_bytes((unsigned char *)&random, sizeof(random)) != 1) {
+               printerr(0, "Failed to get data for random parameter: %s\n",
+                        ERR_error_string(ERR_get_error(), NULL));
                return GSS_S_FAILURE;
        }
-       memset(iv->value, 0, iv->length);
-       if (RAND_bytes(iv->value, iv->length) != 1) {
-               printerr(0, "Failed to get data for IV\n");
+
+       /* The random value will always be used in byte range operations
+        * so we keep it as big endian from this point on */
+       skc->sc_kctx.skc_host_random = random;
+
+       /* Populate DH parameters */
+       skc->sc_params = DH_new();
+       if (!skc->sc_params) {
+               printerr(0, "Failed to allocate DH\n");
                return GSS_S_FAILURE;
        }
 
-       /* Only privacy mode needs the rest of the parameter generation
-        * but we use IV in other modes as well so tokens should be
-        * unique */
-       if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
-               return GSS_S_COMPLETE;
-
-       skc->sc_params = DH_generate_parameters(skc->sc_session_keylen * 8,
-                                               SK_GENERATOR, NULL, NULL);
-       if (skc->sc_params == NULL) {
-               printerr(0, "Failed to generate diffie-hellman parameters: %s",
-                        ERR_error_string(ERR_get_error(), NULL));
+       p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
+       if (!p) {
+               printerr(0, "Failed to convert binary to BIGNUM\n");
                return GSS_S_FAILURE;
        }
 
-       if (DH_check(skc->sc_params, &rc) != 1) {
-               printerr(0, "DH_check() failed: %d\n", rc);
+       /* We use a static generator for shared key */
+       g = BN_new();
+       if (!g) {
+               printerr(0, "Failed to allocate new BIGNUM\n");
+               return GSS_S_FAILURE;
+       }
+       if (BN_set_word(g, SK_GENERATOR) != 1) {
+               printerr(0, "Failed to set g value for DH params\n");
                return GSS_S_FAILURE;
-       } else if (rc) {
-               printerr(0, "DH_check() returned error codes: 0x%x\n", rc);
+       }
+
+       if (!DH_set0_pqg(skc->sc_params, p, NULL, g)) {
+               printerr(0, "Failed to set pqg\n");
                return GSS_S_FAILURE;
        }
 
+       /* Verify that we have a safe prime and valid generator */
+       if (!sk_is_dh_valid(skc->sc_params, num_rounds))
+               return GSS_S_FAILURE;
+
        if (DH_generate_key(skc->sc_params) != 1) {
                printerr(0, "Failed to generate public DH key: %s\n",
                         ERR_error_string(ERR_get_error(), NULL));
                return GSS_S_FAILURE;
        }
 
-       skc->sc_p.length = BN_num_bytes(skc->sc_params->p);
-       skc->sc_pub_key.length = BN_num_bytes(skc->sc_params->pub_key);
-       skc->sc_p.value = malloc(skc->sc_p.length);
+       DH_get0_key(skc->sc_params, &pub_key, NULL);
+       skc->sc_pub_key.length = BN_num_bytes(pub_key);
        skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
-       if (!skc->sc_p.value || !skc->sc_pub_key.value) {
-               printerr(0, "Failed to allocate memory for params\n");
+       if (!skc->sc_pub_key.value) {
+               printerr(0, "Failed to allocate memory for public key\n");
                return GSS_S_FAILURE;
        }
 
-       BN_bn2bin(skc->sc_params->pub_key, skc->sc_pub_key.value);
-       BN_bn2bin(skc->sc_params->p, skc->sc_p.value);
+       BN_bn2bin(pub_key, skc->sc_pub_key.value);
 
        return GSS_S_COMPLETE;
 }
 
 /**
- * Generates or populates the DH parameters depending on whether the system is
- * the initiator or responder for the connection
- *
- * \param[in,out]      skc             Shared key credentials structure to
- *                                     populate with DH parameters
- * \param[in]          initiator       Boolean whether to initiate parameters
- *
- * \retval     GSS_S_COMPLETE  success
- * \retval     GSS_S_FAILURE   failure
- */
-uint32_t sk_gen_params(struct sk_cred *skc, const bool initiator)
-{
-       if (initiator)
-               return sk_gen_initiator_params(skc);
-
-       return sk_gen_responder_params(skc);
-}
-
-/**
  * Convert SK hash algorithm into openssl message digest
  *
  * \param[in,out]      alg             SK hash algorithm
  *
  * \retval             EVP_MD
  */
-static inline const EVP_MD *sk_hash_to_evp_md(enum sk_hmac_alg alg)
+static inline const EVP_MD *sk_hash_to_evp_md(enum cfs_crypto_hash_alg alg)
 {
        switch (alg) {
-       case SK_HMAC_SHA256:
+       case CFS_HASH_ALG_SHA256:
                return EVP_sha256();
-       case SK_HMAC_SHA512:
+       case CFS_HASH_ALG_SHA512:
                return EVP_sha512();
        default:
                return EVP_md_null();
@@ -890,7 +995,7 @@ static inline const EVP_MD *sk_hash_to_evp_md(enum sk_hmac_alg alg)
 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
                 const EVP_MD *hash_alg, gss_buffer_desc *hmac)
 {
-       HMAC_CTX hctx;
+       HMAC_CTX *hctx;
        unsigned int hashlen = EVP_MD_size(hash_alg);
        int i;
        int rc = -1;
@@ -900,7 +1005,7 @@ int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
                return -1;
        }
 
-       HMAC_CTX_init(&hctx);
+       hctx = HMAC_CTX_new();
 
        hmac->length = hashlen;
        hmac->value = malloc(hashlen);
@@ -909,30 +1014,23 @@ int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
                goto out;
        }
 
-#ifdef HAVE_VOID_OPENSSL_HMAC_FUNCS
-       HMAC_Init_ex(&hctx, key->value, key->length, hash_alg, NULL);
-       for (i = 0; i < numbufs; i++)
-               HMAC_Update(&hctx, bufs[i].value, bufs[i].length);
-       HMAC_Final(&hctx, hmac->value, &hashlen);
-#else
-       if (HMAC_Init_ex(&hctx, key->value, key->length, hash_alg, NULL) != 1) {
+       if (HMAC_Init_ex(hctx, key->value, key->length, hash_alg, NULL) != 1) {
                printerr(0, "Failed to init HMAC\n");
                goto out;
        }
 
        for (i = 0; i < numbufs; i++) {
-               if (HMAC_Update(&hctx, bufs[i].value, bufs[i].length) != 1) {
+               if (HMAC_Update(hctx, bufs[i].value, bufs[i].length) != 1) {
                        printerr(0, "Failed to update HMAC\n");
                        goto out;
                }
        }
 
        /* The result gets populated in hmac */
-       if (HMAC_Final(&hctx, hmac->value, &hashlen) != 1) {
+       if (HMAC_Final(hctx, hmac->value, &hashlen) != 1) {
                printerr(0, "Failed to finalize HMAC\n");
                goto out;
        }
-#endif
 
        if (hmac->length != hashlen) {
                printerr(0, "HMAC size does not match expected\n");
@@ -941,7 +1039,7 @@ int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
 
        rc = 0;
 out:
-       HMAC_CTX_cleanup(&hctx);
+       HMAC_CTX_free(hctx);
        return rc;
 }
 
@@ -995,6 +1093,9 @@ uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
  */
 void sk_free_cred(struct sk_cred *skc)
 {
+       if (!skc)
+               return;
+
        if (skc->sc_p.value)
                free(skc->sc_p.value);
        if (skc->sc_pub_key.value)
@@ -1012,16 +1113,21 @@ void sk_free_cred(struct sk_cred *skc)
                       skc->sc_dh_shared_key.length);
                free(skc->sc_dh_shared_key.value);
        }
+       if (skc->sc_kctx.skc_hmac_key.value) {
+               memset(skc->sc_kctx.skc_hmac_key.value, 0,
+                      skc->sc_kctx.skc_hmac_key.length);
+               free(skc->sc_kctx.skc_hmac_key.value);
+       }
+       if (skc->sc_kctx.skc_encrypt_key.value) {
+               memset(skc->sc_kctx.skc_encrypt_key.value, 0,
+                      skc->sc_kctx.skc_encrypt_key.length);
+               free(skc->sc_kctx.skc_encrypt_key.value);
+       }
        if (skc->sc_kctx.skc_shared_key.value) {
                memset(skc->sc_kctx.skc_shared_key.value, 0,
                       skc->sc_kctx.skc_shared_key.length);
                free(skc->sc_kctx.skc_shared_key.value);
        }
-       if (skc->sc_kctx.skc_iv.value) {
-               memset(skc->sc_kctx.skc_iv.value, 0,
-                      skc->sc_kctx.skc_iv.length);
-               free(skc->sc_kctx.skc_iv.value);
-       }
        if (skc->sc_kctx.skc_session_key.value) {
                memset(skc->sc_kctx.skc_session_key.value, 0,
                       skc->sc_kctx.skc_session_key.length);
@@ -1032,6 +1138,69 @@ void sk_free_cred(struct sk_cred *skc)
                DH_free(skc->sc_params);
 
        free(skc);
+       skc = NULL;
+}
+
+/* This function handles key derivation using the hash algorithm specified in
+ * \a hash_alg, buffers in \a key_binding_bufs, and original key in
+ * \a origin_key to produce a \a derived_key.  The first element of the
+ * key_binding_bufs array is reserved for the counter used in the KDF.  The
+ * derived key in \a derived_key could differ in size from \a origin_key and
+ * must be populated with the expected size and a valid buffer to hold the
+ * contents.
+ *
+ * If the derived key size is greater than the HMAC algorithm size it will be
+ * a done using several iterations of a counter and the key binding bufs.
+ *
+ * If the size is smaller it will take copy the first N bytes necessary to
+ * fill the derived key. */
+int sk_kdf(gss_buffer_desc *derived_key , gss_buffer_desc *origin_key,
+          gss_buffer_desc *key_binding_bufs, int numbufs,
+          enum cfs_crypto_hash_alg hmac_alg)
+{
+       size_t remain;
+       size_t bytes;
+       uint32_t counter;
+       char *keydata;
+       gss_buffer_desc tmp_hash;
+       int i;
+       int rc;
+
+       if (numbufs < 1)
+               return -EINVAL;
+
+       /* Use a counter as the first buffer followed by the key binding
+        * buffers in the event we need more than one a single cycle to
+        * produced a symmetric key large enough in size */
+       key_binding_bufs[0].value = &counter;
+       key_binding_bufs[0].length = sizeof(counter);
+
+       remain = derived_key->length;
+       keydata = derived_key->value;
+       i = 0;
+       while (remain > 0) {
+               counter = htobe32(i++);
+               rc = sk_sign_bufs(origin_key, key_binding_bufs, numbufs,
+                                 sk_hash_to_evp_md(hmac_alg), &tmp_hash);
+               if (rc) {
+                       if (tmp_hash.value)
+                               free(tmp_hash.value);
+                       return rc;
+               }
+
+               if (cfs_crypto_hash_digestsize(hmac_alg) != tmp_hash.length) {
+                       free(tmp_hash.value);
+                       return -EINVAL;
+               }
+
+               bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
+               memcpy(keydata, tmp_hash.value, bytes);
+               free(tmp_hash.value);
+               remain -= bytes;
+               keydata += bytes;
+       }
+
+       return 0;
 }
 
 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
@@ -1043,67 +1212,105 @@ void sk_free_cred(struct sk_cred *skc)
  * \return     -1              failure
  * \return     0               success
  */
-int sk_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
-          gss_buffer_desc *key_binding_input)
+int sk_session_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
+                  gss_buffer_desc *client_token, gss_buffer_desc *server_token)
 {
        struct sk_kernel_ctx *kctx = &skc->sc_kctx;
        gss_buffer_desc *session_key = &kctx->skc_session_key;
-       gss_buffer_desc bufs[4];
-       gss_buffer_desc tmp_hash;
-       char *skp;
-       size_t remain;
-       size_t bytes;
-       uint32_t counter;
-       int i;
+       gss_buffer_desc bufs[5];
+       enum cfs_crypto_crypt_alg crypt_alg;
        int rc = -1;
 
-       /* No keys computed unless privacy mode is in use */
-       if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
-               return 0;
-
-       session_key->length = sk_crypt_types[kctx->skc_crypt_alg].sct_bytes;
+       crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
+       session_key->length = cfs_crypto_crypt_keysize(crypt_alg);
        session_key->value = malloc(session_key->length);
        if (!session_key->value) {
                printerr(0, "Failed to allocate memory for session key\n");
                return rc;
        }
 
-       /* Use the HMAC algorithm provided by in the shared key file to derive
-        * a session key.  eg: HMAC(key, msg)
-        * key: the shared key provided in the shared key file
-        * msg is the bytes in the following order:
-        * 1. big_endian(counter)
-        * 2. DH shared key
-        * 3. Clients NIDs
-        * 4. key_binding_input */
-       bufs[0].value = &counter;
-       bufs[0].length = sizeof(counter);
+       /* Key binding info ordering
+        * 1. Reserved for counter
+        * 1. DH shared key
+        * 2. Client's NIDs
+        * 3. Client's token
+        * 4. Server's token */
+       bufs[0].value = NULL;
+       bufs[0].length = 0;
        bufs[1] = skc->sc_dh_shared_key;
        bufs[2].value = &client_nid;
        bufs[2].length = sizeof(client_nid);
-       bufs[3] = *key_binding_input;
+       bufs[3] = *client_token;
+       bufs[4] = *server_token;
 
-       remain = session_key->length;
-       skp = session_key->value;
-       i = 0;
-       while (remain > 0) {
-               counter = be32toh(i++);
-               rc = sk_sign_bufs(&kctx->skc_shared_key, bufs, 4,
-                            sk_hash_to_evp_md(kctx->skc_hmac_alg), &tmp_hash);
-               if (rc) {
-                       free(tmp_hash.value);
-                       return rc;
-               }
+       return sk_kdf(&kctx->skc_session_key, &kctx->skc_shared_key, bufs,
+                     5, cfs_crypto_hash_alg(kctx->skc_hmac_alg));
+}
+
+/* Uses the session key to create an HMAC key and encryption key.  In
+ * integrity mode the session key used to generate the HMAC key uses
+ * session information which is available on the wire but by creating
+ * a session based HMAC key we can prevent potential replay as both the
+ * client and server have random numbers used as part of the key creation.
+ *
+ * The keys used for integrity and privacy are formulated as below using
+ * the session key that is the output of the key derivation function.  The
+ * HMAC algorithm is determined by the shared key algorithm selected in the
+ * key file.
+ *
+ * For ski mode:
+ * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
+ *
+ * For skpi mode:
+ * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
+ * Session Encryption Key = PBKDF2("Encrypt", KDF derived Session Key)
+ *
+ * \param[in,out]      skc             Shared key credentials structure with
+ *
+ * \return     -1              failure
+ * \return     0               success
+ */
+int sk_compute_keys(struct sk_cred *skc)
+{
+       struct sk_kernel_ctx *kctx = &skc->sc_kctx;
+       gss_buffer_desc *session_key = &kctx->skc_session_key;
+       gss_buffer_desc *hmac_key = &kctx->skc_hmac_key;
+       gss_buffer_desc *encrypt_key = &kctx->skc_encrypt_key;
+       enum cfs_crypto_hash_alg hmac_alg;
+       enum cfs_crypto_crypt_alg crypt_alg;
+       char *encrypt = "Encrypt";
+       char *integrity = "Integrity";
+       int rc;
 
-               LASSERT(sk_hmac_types[kctx->skc_hmac_alg].sht_bytes ==
-                       tmp_hash.length);
+       hmac_alg = cfs_crypto_hash_alg(kctx->skc_hmac_alg);
+       hmac_key->length = cfs_crypto_hash_digestsize(hmac_alg);
+       hmac_key->value = malloc(hmac_key->length);
+       if (!hmac_key->value)
+               return -ENOMEM;
 
-               bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
-               memcpy(skp, tmp_hash.value, bytes);
-               free(tmp_hash.value);
-               remain -= bytes;
-               skp += bytes;
-       }
+       rc = PKCS5_PBKDF2_HMAC(integrity, -1, session_key->value,
+                              session_key->length, SK_PBKDF2_ITERATIONS,
+                              sk_hash_to_evp_md(hmac_alg),
+                              hmac_key->length, hmac_key->value);
+       if (rc == 0)
+               return -EINVAL;
+
+       /* Encryption key is only populated in privacy mode */
+       if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
+               return 0;
+
+       crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
+       encrypt_key->length = cfs_crypto_crypt_keysize(crypt_alg);
+       encrypt_key->value = malloc(encrypt_key->length);
+       if (!encrypt_key->value)
+               return -ENOMEM;
+
+       rc = PKCS5_PBKDF2_HMAC(encrypt, -1, session_key->value,
+                              session_key->length, SK_PBKDF2_ITERATIONS,
+                              sk_hash_to_evp_md(hmac_alg),
+                              encrypt_key->length, encrypt_key->value);
+       if (rc == 0)
+               return -EINVAL;
 
        return 0;
 }
@@ -1119,17 +1326,13 @@ int sk_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
  * \return     gss error               failure
  * \return     GSS_S_COMPLETE          success
  */
-uint32_t sk_compute_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
+uint32_t sk_compute_dh_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
 {
        gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
        BIGNUM *remote_pub_key;
        int status;
        uint32_t rc = GSS_S_FAILURE;
 
-       /* No keys computed unless privacy mode is in use */
-       if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
-               return GSS_S_COMPLETE;
-
        remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
        if (!remote_pub_key) {
                printerr(0, "Failed to convert binary to BIGNUM\n");
@@ -1152,10 +1355,26 @@ uint32_t sk_compute_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
                         ERR_error_string(ERR_get_error(), NULL));
                goto out_err;
        } else if (status < dh_shared->length) {
-               printerr(0, "DH_compute_key() returned a short key of %d "
-                        "bytes, expected: %zu\n", status, dh_shared->length);
-               rc = GSS_S_DEFECTIVE_TOKEN;
-               goto out_err;
+               /* there is around 1 chance out of 256 that the returned
+                * shared key is shorter than expected
+                */
+               if (status >= dh_shared->length - 2) {
+                       int shift = dh_shared->length - status;
+                       /* if the key is short by only 1 or 2 bytes, just
+                        * prepend it with 0s
+                        */
+                       memmove((void *)(dh_shared->value + shift),
+                               dh_shared->value, status);
+                       memset(dh_shared->value, 0, shift);
+               } else {
+                       /* if the key is really too short, return GSS_S_BAD_QOP
+                        * so that the caller can retry to generate
+                        */
+                       printerr(0, "DH_compute_key() returned a short key of %d bytes, expected: %zu\n",
+                                status, dh_shared->length);
+                       rc = GSS_S_BAD_QOP;
+                       goto out_err;
+               }
        }
 
        rc = GSS_S_COMPLETE;
@@ -1182,8 +1401,8 @@ int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
        char *p, *end;
        size_t bufsize;
 
-       bufsize = sizeof(*kctx) + kctx->skc_session_key.length +
-                 kctx->skc_iv.length + kctx->skc_shared_key.length;
+       bufsize = sizeof(*kctx) + kctx->skc_hmac_key.length +
+                 kctx->skc_encrypt_key.length;
 
        ctx_token->value = malloc(bufsize);
        if (!ctx_token->value)
@@ -1201,11 +1420,13 @@ int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
                return -1;
        if (WRITE_BYTES(&p, end, kctx->skc_expire))
                return -1;
-       if (write_buffer(&p, end, &kctx->skc_shared_key))
+       if (WRITE_BYTES(&p, end, kctx->skc_host_random))
+               return -1;
+       if (WRITE_BYTES(&p, end, kctx->skc_peer_random))
                return -1;
-       if (write_buffer(&p, end, &kctx->skc_iv))
+       if (write_buffer(&p, end, &kctx->skc_hmac_key))
                return -1;
-       if (write_buffer(&p, end, &kctx->skc_session_key))
+       if (write_buffer(&p, end, &kctx->skc_encrypt_key))
                return -1;
 
        printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
@@ -1317,7 +1538,7 @@ int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
        ptr = ns->value;
        for (i = 0; i < numbufs; i++) {
                /* size */
-               rc = snprintf((char *) ptr, size, "%zu:", bufs[i].length);
+               rc = scnprintf((char *) ptr, size, "%zu:", bufs[i].length);
                ptr += rc;
 
                /* contents */