Whamcloud - gitweb
2e356a12757353ff9a566b93b2f900f5ecbc4e0c
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Copyright (c) 2016, 2017, Intel Corporation.
26  *
27  * Author: Jeremy Filizetti <jfilizet@iu.edu>
28  */
29
30 #include <fcntl.h>
31 #include <limits.h>
32 #include <math.h>
33 #include <string.h>
34 #include <stdbool.h>
35 #include <unistd.h>
36 #include <openssl/dh.h>
37 #include <openssl/engine.h>
38 #include <openssl/err.h>
39 #include <openssl/hmac.h>
40 #include <sys/types.h>
41 #include <sys/stat.h>
42 #include <libcfs/util/string.h>
43 #include <sys/time.h>
44
45 #include "sk_utils.h"
46 #include "write_bytes.h"
47
48 #define SK_PBKDF2_ITERATIONS 10000
49
50 #ifdef _NEW_BUILD_
51 # include "lgss_utils.h"
52 #else
53 # include "gss_util.h"
54 # include "gss_oids.h"
55 # include "err_util.h"
56 #endif
57
58 #ifdef _ERR_UTIL_H_
59 /**
60  * Initializes logging
61  * \param[in]   program         Program name to output
62  * \param[in]   verbose         Verbose flag
63  * \param[in]   fg              Whether or not to run in foreground
64  *
65  */
66 void sk_init_logging(char *program, int verbose, int fg)
67 {
        /* pure pass-through: all three arguments go to initerr() unchanged */
68         initerr(program, verbose, fg);
69 }
70 #endif
71
72 /**
73  * Loads the key from \a filename and returns the struct sk_keyfile_config.
74  * It should be freed by the caller.
75  *
76  * \param[in]   filename                Disk or key payload data
77  *
78  * \return      sk_keyfile_config       sucess
79  * \return      NULL                    failure
80  */
81 struct sk_keyfile_config *sk_read_file(char *filename)
82 {
83         struct sk_keyfile_config *config;
84         char *ptr;
85         size_t rc;
86         size_t remain;
87         int fd;
88
89         config = malloc(sizeof(*config));
90         if (!config) {
91                 printerr(0, "Failed to allocate memory for config\n");
92                 return NULL;
93         }
94
95         /* allow standard input override */
96         if (strcmp(filename, "-") == 0)
97                 fd = STDIN_FILENO;
98         else
99                 fd = open(filename, O_RDONLY);
100
101         if (fd == -1) {
102                 printerr(0, "Error opening key file '%s': %s\n", filename,
103                          strerror(errno));
104                 goto out_free;
105         } else if (fd != STDIN_FILENO) {
106                 struct stat st;
107
108                 rc = fstat(fd, &st);
109                 if (rc == 0 && (st.st_mode & ~(S_IFREG | 0600)))
110                         fprintf(stderr, "warning: "
111                                 "secret key '%s' has insecure file mode %#o\n",
112                                 filename, st.st_mode);
113         }
114
115         ptr = (char *)config;
116         remain = sizeof(*config);
117         while (remain > 0) {
118                 rc = read(fd, ptr, remain);
119                 if (rc == -1) {
120                         if (errno == EINTR)
121                                 continue;
122                         printerr(0, "read() failed on %s: %s\n", filename,
123                                  strerror(errno));
124                         goto out_close;
125                 } else if (rc == 0) {
126                         printerr(0, "File %s does not have a complete key\n",
127                                  filename);
128                         goto out_close;
129                 }
130                 ptr += rc;
131                 remain -= rc;
132         }
133
134         if (fd != STDIN_FILENO)
135                 close(fd);
136         sk_config_disk_to_cpu(config);
137         return config;
138
139 out_close:
140         close(fd);
141 out_free:
142         free(config);
143         return NULL;
144 }
145
146 /**
147  * Checks if a key matching \a description is found in the keyring for
148  * logging purposes and then attempts to load \a payload of \a psize into a key
149  * with \a description.
150  *
151  * \param[in]   payload         Key payload
152  * \param[in]   psize           Payload size
153  * \param[in]   description     Description used for key in keyring
154  *
155  * \return      0       sucess
156  * \return      -1      failure
157  */
158 static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
159                                 const char *description)
160 {
161         struct sk_keyfile_config payload;
162         key_serial_t key;
163
164         memcpy(&payload, skc, sizeof(*skc));
165
166         /* In the keyring use the disk layout so keyctl pipe can be used */
167         sk_config_cpu_to_disk(&payload);
168
169         /* Check to see if a key is already loaded matching description */
170         key = keyctl_search(KEY_SPEC_USER_KEYRING, "user", description, 0);
171         if (key != -1)
172                 printerr(2, "Key %d found in session keyring, replacing\n",
173                          key);
174
175         key = add_key("user", description, &payload, sizeof(payload),
176                       KEY_SPEC_USER_KEYRING);
177         if (key != -1)
178                 printerr(2, "Added key %d with description %s\n", key,
179                          description);
180         else
181                 printerr(0, "Failed to add key with %s\n", description);
182
183         return key;
184 }
185
186 /**
187  * Reads the key from \a path, verifies it and loads into the session keyring
188  * using a description determined by the the \a type.  Existing keys with the
189  * same description are replaced.
190  *
191  * \param[in]   path    Path to key file
192  * \param[in]   type    Type of key to load which determines the description
193  *
194  * \return      0       sucess
195  * \return      -1      failure
196  */
197 int sk_load_keyfile(char *path)
198 {
199         struct sk_keyfile_config *config;
200         char description[SK_DESCRIPTION_SIZE + 1];
201         struct stat buf;
202         int i;
203         int rc;
204         int rc2 = -1;
205
206         rc = stat(path, &buf);
207         if (rc == -1) {
208                 printerr(0, "stat() failed for file %s: %s\n", path,
209                          strerror(errno));
210                 return rc2;
211         }
212
213         config = sk_read_file(path);
214         if (!config)
215                 return rc2;
216
217         /* Similar to ssh, require adequate care of key files */
218         if (buf.st_mode & (S_IRGRP | S_IWGRP | S_IWOTH | S_IXOTH)) {
219                 printerr(0, "Shared key files must be read/writeable only by "
220                          "owner\n");
221                 return -1;
222         }
223
224         if (sk_validate_config(config))
225                 goto out;
226
227         /* The server side can have multiple key files per file system so
228          * the nodemap name is appended to the key description to uniquely
229          * identify it */
230         if (config->skc_type & SK_TYPE_MGS) {
231                 /* Any key can be an MGS key as long as we are told to use it */
232                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
233                               config->skc_nodemap);
234                 if (rc >= SK_DESCRIPTION_SIZE)
235                         goto out;
236                 if (sk_load_key(config, description) == -1)
237                         goto out;
238         }
239         if (config->skc_type & SK_TYPE_SERVER) {
240                 /* Server keys need to have the file system name in the key */
241                 if (!config->skc_fsname) {
242                         printerr(0, "Key configuration has no file system "
243                                  "attribute.  Can't load as server type\n");
244                         goto out;
245                 }
246                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s:%s",
247                               config->skc_fsname, config->skc_nodemap);
248                 if (rc >= SK_DESCRIPTION_SIZE)
249                         goto out;
250                 if (sk_load_key(config, description) == -1)
251                         goto out;
252         }
253         if (config->skc_type & SK_TYPE_CLIENT) {
254                 /* Load client file system key */
255                 if (config->skc_fsname) {
256                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
257                                       "lustre:%s", config->skc_fsname);
258                         if (rc >= SK_DESCRIPTION_SIZE)
259                                 goto out;
260                         if (sk_load_key(config, description) == -1)
261                                 goto out;
262                 }
263
264                 /* Load client MGC keys */
265                 for (i = 0; i < MAX_MGSNIDS; i++) {
266                         if (config->skc_mgsnids[i] == LNET_NID_ANY)
267                                 continue;
268                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
269                                       "lustre:MGC%s",
270                                       libcfs_nid2str(config->skc_mgsnids[i]));
271                         if (rc >= SK_DESCRIPTION_SIZE)
272                                 goto out;
273                         if (sk_load_key(config, description) == -1)
274                                 goto out;
275                 }
276         }
277
278         rc2 = 0;
279
280 out:
281         free(config);
282         return rc2;
283 }
284
285 /**
286  * Byte swaps config from cpu format to disk
287  *
288  * \param[in,out]       config          sk_keyfile_config to swap
289  */
290 void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
291 {
292         int i;
293
294         if (!config)
295                 return;
296
297         config->skc_version = htobe32(config->skc_version);
298         config->skc_hmac_alg = htobe16(config->skc_hmac_alg);
299         config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
300         config->skc_expire = htobe32(config->skc_expire);
301         config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
302         config->skc_prime_bits = htobe32(config->skc_prime_bits);
303
304         for (i = 0; i < MAX_MGSNIDS; i++)
305                 config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
306 }
307
308 /**
309  * Byte swaps config from disk format to cpu
310  *
311  * \param[in,out]       config          sk_keyfile_config to swap
312  */
313 void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
314 {
315         int i;
316
317         if (!config)
318                 return;
319
320         config->skc_version = be32toh(config->skc_version);
321         config->skc_hmac_alg = be16toh(config->skc_hmac_alg);
322         config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
323         config->skc_expire = be32toh(config->skc_expire);
324         config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
325         config->skc_prime_bits = be32toh(config->skc_prime_bits);
326
327         for (i = 0; i < MAX_MGSNIDS; i++)
328                 config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
329 }
330
331 /**
332  * Verifies the on key payload format is valid
333  *
334  * \param[in]   config          sk_keyfile_config
335  *
336  * \return      -1      failure
337  * \return      0       success
338  */
339 int sk_validate_config(const struct sk_keyfile_config *config)
340 {
341         int i;
342
343         if (!config) {
344                 printerr(0, "Null configuration passed\n");
345                 return -1;
346         }
347
348         if (config->skc_version != SK_CONF_VERSION) {
349                 printerr(0, "Invalid version\n");
350                 return -1;
351         }
352
353         if (config->skc_hmac_alg == SK_HMAC_INVALID) {
354                 printerr(0, "Invalid HMAC algorithm\n");
355                 return -1;
356         }
357
358         if (config->skc_crypt_alg == SK_CRYPT_INVALID) {
359                 printerr(0, "Invalid crypt algorithm\n");
360                 return -1;
361         }
362
363         if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
364                 /* Try to limit key expiration to some reasonable minimum and
365                  * also prevent values over INT_MAX because there appears
366                  * to be a type conversion issue */
367                 printerr(0, "Invalid expiration time should be between %d "
368                          "and %d\n", 60, INT_MAX);
369                 return -1;
370         }
371         if (config->skc_prime_bits % 8 != 0 ||
372             config->skc_prime_bits > SK_MAX_P_BYTES * 8) {
373                 printerr(0, "Invalid session key length must be a multiple of 8"
374                          " and less then %d bits\n",
375                          SK_MAX_P_BYTES * 8);
376                 return -1;
377         }
378         if (config->skc_shared_keylen % 8 != 0 ||
379             config->skc_shared_keylen > SK_MAX_KEYLEN_BYTES * 8){
380                 printerr(0, "Invalid shared key max length must be a multiple "
381                          "of 8 and less then %d bits\n",
382                          SK_MAX_KEYLEN_BYTES * 8);
383                 return -1;
384         }
385
386         /* Check for terminating nulls on strings */
387         for (i = 0; i < sizeof(config->skc_fsname) &&
388              config->skc_fsname[i] != '\0';  i++)
389                 ; /* empty loop */
390         if (i == sizeof(config->skc_fsname)) {
391                 printerr(0, "File system name not null terminated\n");
392                 return -1;
393         }
394
395         for (i = 0; i < sizeof(config->skc_nodemap) &&
396              config->skc_nodemap[i] != '\0';  i++)
397                 ; /* empty loop */
398         if (i == sizeof(config->skc_nodemap)) {
399                 printerr(0, "Nodemap name not null terminated\n");
400                 return -1;
401         }
402
403         if (config->skc_type == SK_TYPE_INVALID) {
404                 printerr(0, "Invalid key type\n");
405                 return -1;
406         }
407
408         return 0;
409 }
410
411 /**
412  * Hashes \a string and places the hash in \a hash
413  * at \a hash
414  *
415  * \param[in]           string          Null terminated string to hash
416  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
417  * \param[in,out]       hash            gss_buffer_desc to hold the result
418  *
419  * \return      -1      failure
420  * \return      0       success
421  */
422 static int sk_hash_string(const char *string, const EVP_MD *hash_alg,
423                           gss_buffer_desc *hash)
424 {
425         EVP_MD_CTX *ctx = EVP_MD_CTX_create();
426         size_t len = strlen(string);
427         unsigned int hashlen;
428
429         if (!hash->value || hash->length < EVP_MD_size(hash_alg))
430                 goto out_err;
431         if (!EVP_DigestInit_ex(ctx, hash_alg, NULL))
432                 goto out_err;
433         if (!EVP_DigestUpdate(ctx, string, len))
434                 goto out_err;
435         if (!EVP_DigestFinal_ex(ctx, hash->value, &hashlen))
436                 goto out_err;
437
438         EVP_MD_CTX_destroy(ctx);
439         hash->length = hashlen;
440         return 0;
441
442 out_err:
443         EVP_MD_CTX_destroy(ctx);
444         return -1;
445 }
446
447 /**
448  * Hashes \a string and verifies the resulting hash matches the value
449  * in \a current_hash
450  *
451  * \param[in]           string          Null terminated string to hash
452  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
453  * \param[in,out]       current_hash    gss_buffer_desc to compare to
454  *
455  * \return      gss error       failure
456  * \return      GSS_S_COMPLETE  success
457  */
458 uint32_t sk_verify_hash(const char *string, const EVP_MD *hash_alg,
459                         const gss_buffer_desc *current_hash)
460 {
461         gss_buffer_desc hash;
462         unsigned char hashbuf[EVP_MAX_MD_SIZE];
463
464         hash.value = hashbuf;
465         hash.length = sizeof(hashbuf);
466
467         if (sk_hash_string(string, hash_alg, &hash))
468                 return GSS_S_FAILURE;
469         if (current_hash->length != hash.length)
470                 return GSS_S_DEFECTIVE_TOKEN;
471         if (memcmp(current_hash->value, hash.value, hash.length))
472                 return GSS_S_BAD_SIG;
473
474         return GSS_S_COMPLETE;
475 }
476
477 static inline int sk_config_has_mgsnid(struct sk_keyfile_config *config,
478                                        const char *mgsnid)
479 {
480         lnet_nid_t nid;
481         int i;
482
483         nid = libcfs_str2nid(mgsnid);
484         if (nid == LNET_NID_ANY)
485                 return 0;
486
487         for (i = 0; i < MAX_MGSNIDS; i++)
488                 if  (config->skc_mgsnids[i] == nid)
489                         return 1;
490         return 0;
491 }
492
493 /**
494  * Create an sk_cred structure populated with initial configuration info and the
495  * key.  \a tgt and \a nodemap are used in determining the expected key
496  * description so the key can be found by searching the keyring.
497  * This is done because there is no easy way to pass keys from the mount command
498  * all the way to the request_key call.  In addition any keys can be dynamically
499  * added to the keyrings and still found.  The keyring that needs to be used
500  * must be the session keyring.
501  *
502  * \param[in]   tgt             Target file system
503  * \param[in]   nodemap         Cluster name for the key.  This correlates to
504  *                              the nodemap name and is used by the server side.
505  *                              For the client this will be NULL.
506  * \param[in]   flags           Flags for the credentials
507  *
508  * \return      sk_cred Allocated struct sk_cred on success
509  * \return      NULL    failure
510  */
511 struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
512                                const uint32_t flags)
513 {
514         struct sk_keyfile_config *config;
515         struct sk_kernel_ctx *kctx;
516         struct sk_cred *skc = NULL;
517         char description[SK_DESCRIPTION_SIZE + 1];
518         char fsname[MTI_NAME_MAXLEN + 1];
519         const char *mgsnid = NULL;
520         char *ptr;
521         long sk_key;
522         int keylen;
523         int len;
524         int rc;
525
526         printerr(2, "Creating credentials for target: %s with nodemap: %s\n",
527                  tgt, nodemap);
528
529         memset(description, 0, sizeof(description));
530         memset(fsname, 0, sizeof(fsname));
531
532         /* extract the file system name from target */
533         ptr = index(tgt, '-');
534         if (!ptr) {
535                 len = strlen(tgt);
536
537                 /* This must be an MGC target */
538                 if (strncmp(tgt, "MGC", 3) || len <= 3) {
539                         printerr(0, "Invalid target name\n");
540                         return NULL;
541                 }
542                 mgsnid = tgt + 3;
543         } else {
544                 len = ptr - tgt;
545         }
546
547         if (len > MTI_NAME_MAXLEN) {
548                 printerr(0, "Invalid target name\n");
549                 return NULL;
550         }
551         memcpy(fsname, tgt, len);
552
553         if (nodemap) {
554                 if (mgsnid)
555                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
556                                       "lustre:MGS:%s", nodemap);
557                 else
558                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
559                                       "lustre:%s:%s", fsname, nodemap);
560         } else {
561                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s",
562                               fsname);
563         }
564
565         if (rc >= SK_DESCRIPTION_SIZE) {
566                 printerr(0, "Invalid key description\n");
567                 return NULL;
568         }
569
570         /* It may be a good idea to move Lustre keys to the gss_keyring
571          * (lgssc) type so that they expire when Lustre modules are removed.
572          * Unfortunately it can't be done at mount time because the mount
573          * syscall could trigger the Lustre modules to load and until that
574          * point we don't have a lgssc key type.
575          *
576          * TODO: Query the community for a consensus here  */
577         printerr(2, "Searching for key with description: %s\n", description);
578         sk_key = keyctl_search(KEY_SPEC_USER_KEYRING, "user",
579                                description, 0);
580         if (sk_key == -1) {
581                 printerr(1, "No key found for %s\n", description);
582                 return NULL;
583         }
584
585         keylen = keyctl_read_alloc(sk_key, (void **)&config);
586         if (keylen == -1) {
587                 printerr(0, "keyctl_read() failed for key %ld: %s\n", sk_key,
588                          strerror(errno));
589                 return NULL;
590         } else if (keylen != sizeof(*config)) {
591                 printerr(0, "Unexpected key size: %d returned for key %ld, "
592                          "expected %zu bytes\n",
593                          keylen, sk_key, sizeof(*config));
594                 goto out_err;
595         }
596
597         sk_config_disk_to_cpu(config);
598
599         if (sk_validate_config(config)) {
600                 printerr(0, "Invalid key configuration for key: %ld\n", sk_key);
601                 goto out_err;
602         }
603
604         if (mgsnid && !sk_config_has_mgsnid(config, mgsnid)) {
605                 printerr(0, "Target name does not match key's MGS NIDs\n");
606                 goto out_err;
607         }
608
609         if (!mgsnid && strcmp(fsname, config->skc_fsname)) {
610                 printerr(0, "Target name does not match key's file system\n");
611                 goto out_err;
612         }
613
614         skc = malloc(sizeof(*skc));
615         if (!skc) {
616                 printerr(0, "Failed to allocate memory for sk_cred\n");
617                 goto out_err;
618         }
619
620         /* this initializes all gss_buffer_desc to empty as well */
621         memset(skc, 0, sizeof(*skc));
622
623         skc->sc_flags = flags;
624         skc->sc_tgt.length = strlen(tgt) + 1;
625         skc->sc_tgt.value = malloc(skc->sc_tgt.length);
626         if (!skc->sc_tgt.value) {
627                 printerr(0, "Failed to allocate memory for target\n");
628                 goto out_err;
629         }
630         memcpy(skc->sc_tgt.value, tgt, skc->sc_tgt.length);
631
632         skc->sc_nodemap_hash.length = EVP_MD_size(EVP_sha256());
633         skc->sc_nodemap_hash.value = malloc(skc->sc_nodemap_hash.length);
634         if (!skc->sc_nodemap_hash.value) {
635                 printerr(0, "Failed to allocate memory for nodemap hash\n");
636                 goto out_err;
637         }
638
639         if (sk_hash_string(config->skc_nodemap, EVP_sha256(),
640                            &skc->sc_nodemap_hash)) {
641                 printerr(0, "Failed to generate hash for nodemap name\n");
642                 goto out_err;
643         }
644
645         kctx = &skc->sc_kctx;
646         kctx->skc_version = config->skc_version;
647         strcpy(kctx->skc_hmac_alg, sk_hmac2name(config->skc_hmac_alg));
648         strcpy(kctx->skc_crypt_alg, sk_crypt2name(config->skc_crypt_alg));
649         kctx->skc_expire = config->skc_expire;
650
651         /* key payload format is in bits, convert to bytes */
652         kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
653         kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
654         if (!kctx->skc_shared_key.value) {
655                 printerr(0, "Failed to allocate memory for shared key\n");
656                 goto out_err;
657         }
658         memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
659                kctx->skc_shared_key.length);
660
661         skc->sc_p.length = config->skc_prime_bits / 8;
662         skc->sc_p.value = malloc(skc->sc_p.length);
663         if (!skc->sc_p.value) {
664                 printerr(0, "Failed to allocate p\n");
665                 goto out_err;
666         }
667         memcpy(skc->sc_p.value, config->skc_p, skc->sc_p.length);
668
669         free(config);
670
671         return skc;
672
673 out_err:
674         sk_free_cred(skc);
675
676         free(config);
677         return NULL;
678 }
679
680 #define SK_GENERATOR 2
681 #define DH_NUMBER_ITERATIONS_FOR_PRIME 64
682
683 /* OpenSSL 1.1.1c increased the number of rounds used for Miller-Rabin testing
684  * of the prime provided as input parameter to DH_check(). This makes the check
685  * roughly x10 longer, and causes request timeouts when an SSK flavor is being
686  * used.
687  *
688  * Instead, use a dynamic number Miller-Rabin rounds based on the speed of the
689  * check on the current system, evaluated when the lsvcgssd daemon starts, but
690  * at least as many as OpenSSL 1.1.1b used for the same key size. If default
691  * DH_check() duration is OK, use it directly instead of limiting the rounds.
692  *
693  * If \a num_rounds == 0, we just call original DH_check() directly.
694  */
695 static bool sk_is_dh_valid(const DH *dh, int num_rounds)
696 {
697         const BIGNUM *p, *g;
698         BN_ULONG word;
699         BN_CTX *ctx;
700         BIGNUM *r;
701         bool valid = false;
702         int rc;
703
704         if (num_rounds == 0) {
705                 int codes = 0;
706
707                 rc = DH_check(dh, &codes);
708                 if (rc != 1 || codes) {
709                         printerr(0, "DH_check(0) failed: codes=%#x: rc=%d\n",
710                                  codes, rc);
711                         return false;
712                 }
713                 return true;
714         }
715
716         DH_get0_pqg(dh, &p, NULL, &g);
717
718         if (!BN_is_word(g, SK_GENERATOR)) {
719                 printerr(0, "%s: Diffie-Hellman generator is not %u\n",
720                          program_invocation_short_name, SK_GENERATOR);
721                 return false;
722         }
723
724         word = BN_mod_word(p, 24);
725         if (word != 11) {
726                 printerr(0, "%s: Diffie-Hellman prime modulo=%lu unsuitable\n",
727                          program_invocation_short_name, word);
728                 return false;
729         }
730
731         ctx = BN_CTX_new();
732         if (ctx == NULL) {
733                 printerr(0, "%s: Diffie-Hellman error allocating context\n",
734                          program_invocation_short_name);
735                 return false;
736         }
737         BN_CTX_start(ctx);
738         r = BN_CTX_get(ctx); /* must be called before "ctx" used elsewhere */
739
740         rc = BN_is_prime_ex(p, num_rounds, ctx, NULL);
741         if (rc == 0)
742                 printerr(0, "%s: Diffie-Hellman 'p' not prime in %u rounds\n",
743                          program_invocation_short_name, num_rounds);
744         if (rc <= 0)
745                 goto out_free;
746
747         if (!BN_rshift1(r, p)) {
748                 printerr(0, "%s: error shifting BigNum 'r' by 'p'\n",
749                          program_invocation_short_name);
750                 goto out_free;
751         }
752         rc = BN_is_prime_ex(r, num_rounds, ctx, NULL);
753         if (rc == 0)
754                 printerr(0, "%s: Diffie-Hellman 'r' not prime in %u rounds\n",
755                          program_invocation_short_name, num_rounds);
756         if (rc <= 0)
757                 goto out_free;
758
759         valid = true;
760
761 out_free:
762         BN_CTX_end(ctx);
763         BN_CTX_free(ctx);
764
765         return valid;
766 }
767
/* Fixed VALUE_LENGTH-byte (256 * 8 = 2048 bit) value that
 * sk_speedtest_dh_valid() converts with BN_bin2bn() and installs as the
 * Diffie-Hellman 'p' when timing the validity check.  The literal supplies
 * exactly 16 rows of 16 bytes, so the array holds no terminating NUL. */
768 #define VALUE_LENGTH 256
769 static unsigned char test_prime[VALUE_LENGTH] =
770         "\xf7\xfa\x49\xd8\xec\xb1\x3b\xff\x26\x10\x3f\xc5\x3a\xc5\xcc\x40"
771         "\x4f\xbf\x92\xe1\x8b\x83\xe7\xa2\xba\x0f\x51\x5a\x91\x48\xe0\xa3"
772         "\xf1\x4d\xbc\xbb\x8a\x28\x14\xac\x02\x23\x76\x42\x17\x4d\x3c\xdc"
773         "\x5e\x4f\x80\x1f\xd7\x54\x1c\x50\xac\x3b\x28\x68\x8d\x71\x41\x7f"
774         "\xa7\x1c\x2f\x22\xd3\xa8\x91\xb2\x64\xb6\x84\xa6\xcf\x06\x16\x91"
775         "\x2f\xb8\xb4\x42\x1d\x3a\x4e\x3a\x0c\x7f\x04\x69\x78\xb5\x8f\x92"
776         "\x07\x89\xac\x24\x06\x53\x2c\x23\xec\xaa\x5c\xb4\x7b\x49\xbc\xf4"
777         "\x90\x67\x71\x9c\x24\x2c\x1d\x8d\x76\xc8\x85\x4e\x19\xf1\xf9\x33"
778         "\x45\xbd\x9f\x7d\x0a\x08\x8c\x22\xcc\x35\xf3\x5b\xab\x3f\x24\x9d"
779         "\x61\x70\x86\xbb\xbe\xd8\xb0\xf8\x34\xfa\xeb\x5b\x8e\xf2\x62\x23"
780         "\xd1\xfb\xbb\xb8\x21\x71\x1e\x39\x39\x59\xe0\x82\x98\x41\x84\x40"
781         "\x1f\xd3\x9b\xa3\x73\xdb\xec\xe0\xc0\xde\x2d\x1c\xea\x43\x40\x93"
782         "\x98\x38\x03\x36\x1e\xe1\xe7\x39\x7b\x35\x92\x4a\x51\xa5\x91\x63"
783         "\xd5\x31\x98\x3d\x89\x27\x6f\xcc\x69\xff\xbe\x31\x13\xdc\x2f\x72"
784         "\x2d\xab\x6a\xb7\x13\xd3\x47\xda\xaa\xf3\x3c\xa0\xfd\xaa\x0f\x02"
785         "\x96\x81\x1a\x26\xe8\xf7\x25\x65\x33\x78\xd9\x6b\x6d\xb0\xd9\xfb";
786
787 /**
788  * Measure time taken by prime testing routine for a 2048 bit long prime,
789  * depending on the number of check rounds.
790  *
791  * \param[in]   usec_check_max    max time allowed for DH_check completion
792  *
793  * \retval      max number of rounds to keep prime testing under usec_check_max
794  *              return 0 if we should use the default DH_check rounds
795  */
796 int sk_speedtest_dh_valid(unsigned int usec_check_max)
797 {
798         DH *dh;
799         BIGNUM *p, *g;
800         int num_rounds, prev_rounds = 0;
801
802         dh = DH_new();
803         if (!dh)
804                 return 0;
805
806         p = BN_bin2bn(test_prime, VALUE_LENGTH, NULL);
807         if (!p)
808                 goto free_dh;
809
810         g = BN_new();
811         if (!g)
812                 goto free_p;
813
814         if (!BN_set_word(g, SK_GENERATOR))
815                 goto free_g;
816
817         /* "dh" takes over freeing of 'p' and 'g' if this succeeds */
818         if (!DH_set0_pqg(dh, p, NULL, g)) {
819         free_g:
820                 BN_free(g);
821         free_p:
822                 BN_free(p);
823                 goto free_dh;
824         }
825
826         for (num_rounds = 0;
827              num_rounds <= DH_NUMBER_ITERATIONS_FOR_PRIME;
828              num_rounds += (num_rounds <= 4 ? 4 : 8)) {
829                 unsigned int usec_this;
830                 int j;
831
832                 /* get max duration of 4 runs at current number of rounds */
833                 usec_this = 0;
834                 for (j = 0; j < 4; j++) {
835                         struct timeval now, prev;
836                         unsigned int usec_curr;
837
838                         gettimeofday(&prev, NULL);
839                         if (!sk_is_dh_valid(dh, num_rounds)) {
840                                 /* if test_prime is found bad, use default */
841                                 prev_rounds = 0;
842                                 goto free_dh;
843                         }
844                         gettimeofday(&now, NULL);
845                         usec_curr = (now.tv_sec - prev.tv_sec) * 1000000 +
846                                     now.tv_usec - prev.tv_usec;
847                         if (usec_curr > usec_this)
848                                 usec_this = usec_curr;
849                 }
850                 printerr(2, "%s: %d rounds: %d usec\n",
851                          program_invocation_short_name, num_rounds, usec_this);
852                 if (num_rounds == 0) {
853                         if (usec_this <= usec_check_max)
854                         /* using original check rounds as implemented in
855                          * DH_check() took less time than the max allowed,
856                          * so just use original DH_check()
857                          */
858                                 break;
859                 } else if (usec_this >= usec_check_max) {
860                         break;
861                 }
862                 prev_rounds = num_rounds;
863         }
864
865 free_dh:
866         DH_free(dh);
867
868         return prev_rounds;
869 }
870
871 /**
872  * Populates the DH parameters for the DHKE
873  *
874  * \param[in,out]       skc             Shared key credentials structure to
875  *                                      populate with DH parameters
876  *
877  * \retval      GSS_S_COMPLETE  success
878  * \retval      GSS_S_FAILURE   failure
879  */
880 uint32_t sk_gen_params(struct sk_cred *skc, int num_rounds)
881 {
882         uint32_t random;
883         BIGNUM *p, *g;
884         const BIGNUM *pub_key;
885
886         /* Random value used by both the request and response as part of the
887          * key binding material.  This also should ensure we have unqiue
888          * tokens that are sent to the remote server which is important because
889          * the token is hashed for the sunrpc cache lookups and a failure there
890          * would cause connection attempts to fail indefinitely due to the large
891          * timeout value on the server side */
892         if (RAND_bytes((unsigned char *)&random, sizeof(random)) != 1) {
893                 printerr(0, "Failed to get data for random parameter: %s\n",
894                          ERR_error_string(ERR_get_error(), NULL));
895                 return GSS_S_FAILURE;
896         }
897
898         /* The random value will always be used in byte range operations
899          * so we keep it as big endian from this point on */
900         skc->sc_kctx.skc_host_random = random;
901
902         /* Populate DH parameters */
903         skc->sc_params = DH_new();
904         if (!skc->sc_params) {
905                 printerr(0, "Failed to allocate DH\n");
906                 return GSS_S_FAILURE;
907         }
908
909         p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
910         if (!p) {
911                 printerr(0, "Failed to convert binary to BIGNUM\n");
912                 return GSS_S_FAILURE;
913         }
914
915         /* We use a static generator for shared key */
916         g = BN_new();
917         if (!g) {
918                 printerr(0, "Failed to allocate new BIGNUM\n");
919                 return GSS_S_FAILURE;
920         }
921         if (BN_set_word(g, SK_GENERATOR) != 1) {
922                 printerr(0, "Failed to set g value for DH params\n");
923                 return GSS_S_FAILURE;
924         }
925
926         if (!DH_set0_pqg(skc->sc_params, p, NULL, g)) {
927                 printerr(0, "Failed to set pqg\n");
928                 return GSS_S_FAILURE;
929         }
930
931         /* Verify that we have a safe prime and valid generator */
932         if (!sk_is_dh_valid(skc->sc_params, num_rounds))
933                 return GSS_S_FAILURE;
934
935         if (DH_generate_key(skc->sc_params) != 1) {
936                 printerr(0, "Failed to generate public DH key: %s\n",
937                          ERR_error_string(ERR_get_error(), NULL));
938                 return GSS_S_FAILURE;
939         }
940
941         DH_get0_key(skc->sc_params, &pub_key, NULL);
942         skc->sc_pub_key.length = BN_num_bytes(pub_key);
943         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
944         if (!skc->sc_pub_key.value) {
945                 printerr(0, "Failed to allocate memory for public key\n");
946                 return GSS_S_FAILURE;
947         }
948
949         BN_bn2bin(pub_key, skc->sc_pub_key.value);
950
951         return GSS_S_COMPLETE;
952 }
953
954 /**
955  * Convert SK hash algorithm into openssl message digest
956  *
957  * \param[in,out]       alg             SK hash algorithm
958  *
959  * \retval              EVP_MD
960  */
961 static inline const EVP_MD *sk_hash_to_evp_md(enum cfs_crypto_hash_alg alg)
962 {
963         switch (alg) {
964         case CFS_HASH_ALG_SHA256:
965                 return EVP_sha256();
966         case CFS_HASH_ALG_SHA512:
967                 return EVP_sha512();
968         default:
969                 return EVP_md_null();
970         }
971 }
972
973 /**
974  * Signs (via HMAC) the parameters used only in the key initialization protocol.
975  *
976  * \param[in]           key             Key to use for HMAC
977  * \param[in]           bufs            Array of gss_buffer_desc to generate
978  *                                      HMAC for
979  * \param[in]           numbufs         Number of buffers in array
980  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
981  * \param[in,out]       hmac            HMAC of buffers is allocated and placed
982  *                                      in this gss_buffer_desc.  Caller must
983  *                                      free this.
984  *
985  * \retval      0       success
986  * \retval      -1      failure
987  */
988 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
989                  const EVP_MD *hash_alg, gss_buffer_desc *hmac)
990 {
991         HMAC_CTX *hctx;
992         unsigned int hashlen = EVP_MD_size(hash_alg);
993         int i;
994         int rc = -1;
995
996         if (hash_alg == EVP_md_null()) {
997                 printerr(0, "Invalid hash algorithm\n");
998                 return -1;
999         }
1000
1001         hctx = HMAC_CTX_new();
1002
1003         hmac->length = hashlen;
1004         hmac->value = malloc(hashlen);
1005         if (!hmac->value) {
1006                 printerr(0, "Failed to allocate memory for HMAC\n");
1007                 goto out;
1008         }
1009
1010         if (HMAC_Init_ex(hctx, key->value, key->length, hash_alg, NULL) != 1) {
1011                 printerr(0, "Failed to init HMAC\n");
1012                 goto out;
1013         }
1014
1015         for (i = 0; i < numbufs; i++) {
1016                 if (HMAC_Update(hctx, bufs[i].value, bufs[i].length) != 1) {
1017                         printerr(0, "Failed to update HMAC\n");
1018                         goto out;
1019                 }
1020         }
1021
1022         /* The result gets populated in hmac */
1023         if (HMAC_Final(hctx, hmac->value, &hashlen) != 1) {
1024                 printerr(0, "Failed to finalize HMAC\n");
1025                 goto out;
1026         }
1027
1028         if (hmac->length != hashlen) {
1029                 printerr(0, "HMAC size does not match expected\n");
1030                 goto out;
1031         }
1032
1033         rc = 0;
1034 out:
1035         HMAC_CTX_free(hctx);
1036         return rc;
1037 }
1038
1039 /**
1040  * Generates an HMAC for gss_buffer_desc array in \a bufs of \a numbufs
1041  * and verifies against \a hmac.
1042  *
1043  * \param[in]   skc             Shared key credentials
1044  * \param[in]   bufs            Array of gss_buffer_desc to generate HMAC for
1045  * \param[in]   numbufs         Number of buffers in array
1046  * \param[in]   hash_alg        OpenSSL EVP_MD to use for hash
1047  * \param[in]   hmac            HMAC to verify against
1048  *
1049  * \retval      GSS_S_COMPLETE  success (match)
1050  * \retval      gss error       failure
1051  */
1052 uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
1053                         const int numbufs, const EVP_MD *hash_alg,
1054                         gss_buffer_desc *hmac)
1055 {
1056         gss_buffer_desc bufs_hmac;
1057         int rc;
1058
1059         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, numbufs, hash_alg,
1060                          &bufs_hmac)) {
1061                 printerr(0, "Failed to sign buffers to verify HMAC\n");
1062                 if (bufs_hmac.value)
1063                         free(bufs_hmac.value);
1064                 return GSS_S_FAILURE;
1065         }
1066
1067         if (hmac->length != bufs_hmac.length) {
1068                 printerr(0, "Invalid HMAC size\n");
1069                 free(bufs_hmac.value);
1070                 return GSS_S_BAD_SIG;
1071         }
1072
1073         rc = memcmp(hmac->value, bufs_hmac.value, bufs_hmac.length);
1074         free(bufs_hmac.value);
1075
1076         if (rc)
1077                 return GSS_S_BAD_SIG;
1078
1079         return GSS_S_COMPLETE;
1080 }
1081
1082 /**
1083  * Cleanup an sk_cred freeing any resources
1084  *
1085  * \param[in,out]       skc     Shared key credentials to free
1086  */
1087 void sk_free_cred(struct sk_cred *skc)
1088 {
1089         if (!skc)
1090                 return;
1091
1092         if (skc->sc_p.value)
1093                 free(skc->sc_p.value);
1094         if (skc->sc_pub_key.value)
1095                 free(skc->sc_pub_key.value);
1096         if (skc->sc_tgt.value)
1097                 free(skc->sc_tgt.value);
1098         if (skc->sc_nodemap_hash.value)
1099                 free(skc->sc_nodemap_hash.value);
1100         if (skc->sc_hmac.value)
1101                 free(skc->sc_hmac.value);
1102
1103         /* Overwrite keys and IV before freeing */
1104         if (skc->sc_dh_shared_key.value) {
1105                 memset(skc->sc_dh_shared_key.value, 0,
1106                        skc->sc_dh_shared_key.length);
1107                 free(skc->sc_dh_shared_key.value);
1108         }
1109         if (skc->sc_kctx.skc_hmac_key.value) {
1110                 memset(skc->sc_kctx.skc_hmac_key.value, 0,
1111                        skc->sc_kctx.skc_hmac_key.length);
1112                 free(skc->sc_kctx.skc_hmac_key.value);
1113         }
1114         if (skc->sc_kctx.skc_encrypt_key.value) {
1115                 memset(skc->sc_kctx.skc_encrypt_key.value, 0,
1116                        skc->sc_kctx.skc_encrypt_key.length);
1117                 free(skc->sc_kctx.skc_encrypt_key.value);
1118         }
1119         if (skc->sc_kctx.skc_shared_key.value) {
1120                 memset(skc->sc_kctx.skc_shared_key.value, 0,
1121                        skc->sc_kctx.skc_shared_key.length);
1122                 free(skc->sc_kctx.skc_shared_key.value);
1123         }
1124         if (skc->sc_kctx.skc_session_key.value) {
1125                 memset(skc->sc_kctx.skc_session_key.value, 0,
1126                        skc->sc_kctx.skc_session_key.length);
1127                 free(skc->sc_kctx.skc_session_key.value);
1128         }
1129
1130         if (skc->sc_params)
1131                 DH_free(skc->sc_params);
1132
1133         free(skc);
1134         skc = NULL;
1135 }
1136
1137 /* This function handles key derivation using the hash algorithm specified in
1138  * \a hash_alg, buffers in \a key_binding_bufs, and original key in
1139  * \a origin_key to produce a \a derived_key.  The first element of the
1140  * key_binding_bufs array is reserved for the counter used in the KDF.  The
1141  * derived key in \a derived_key could differ in size from \a origin_key and
1142  * must be populated with the expected size and a valid buffer to hold the
1143  * contents.
1144  *
1145  * If the derived key size is greater than the HMAC algorithm size it will be
1146  * a done using several iterations of a counter and the key binding bufs.
1147  *
1148  * If the size is smaller it will take copy the first N bytes necessary to
1149  * fill the derived key. */
1150 int sk_kdf(gss_buffer_desc *derived_key , gss_buffer_desc *origin_key,
1151            gss_buffer_desc *key_binding_bufs, int numbufs,
1152            enum cfs_crypto_hash_alg hmac_alg)
1153 {
1154         size_t remain;
1155         size_t bytes;
1156         uint32_t counter;
1157         char *keydata;
1158         gss_buffer_desc tmp_hash;
1159         int i;
1160         int rc;
1161
1162         if (numbufs < 1)
1163                 return -EINVAL;
1164
1165         /* Use a counter as the first buffer followed by the key binding
1166          * buffers in the event we need more than one a single cycle to
1167          * produced a symmetric key large enough in size */
1168         key_binding_bufs[0].value = &counter;
1169         key_binding_bufs[0].length = sizeof(counter);
1170
1171         remain = derived_key->length;
1172         keydata = derived_key->value;
1173         i = 0;
1174         while (remain > 0) {
1175                 counter = htobe32(i++);
1176                 rc = sk_sign_bufs(origin_key, key_binding_bufs, numbufs,
1177                                   sk_hash_to_evp_md(hmac_alg), &tmp_hash);
1178                 if (rc) {
1179                         if (tmp_hash.value)
1180                                 free(tmp_hash.value);
1181                         return rc;
1182                 }
1183
1184                 if (cfs_crypto_hash_digestsize(hmac_alg) != tmp_hash.length) {
1185                         free(tmp_hash.value);
1186                         return -EINVAL;
1187                 }
1188
1189                 bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
1190                 memcpy(keydata, tmp_hash.value, bytes);
1191                 free(tmp_hash.value);
1192                 remain -= bytes;
1193                 keydata += bytes;
1194         }
1195
1196         return 0;
1197 }
1198
1199 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
1200  * based on the recommendations in NIST Special Publication SP 800-56B Rev 1
1201  * (Sep 2014) Section 5.5.1
1202  *
1203  * \param[in,out]       skc             Shared key credentials structure with
1204  *
1205  * \return      -1              failure
1206  * \return      0               success
1207  */
1208 int sk_session_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
1209                    gss_buffer_desc *client_token, gss_buffer_desc *server_token)
1210 {
1211         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1212         gss_buffer_desc *session_key = &kctx->skc_session_key;
1213         gss_buffer_desc bufs[5];
1214         enum cfs_crypto_crypt_alg crypt_alg;
1215         int rc = -1;
1216
1217         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1218         session_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1219         session_key->value = malloc(session_key->length);
1220         if (!session_key->value) {
1221                 printerr(0, "Failed to allocate memory for session key\n");
1222                 return rc;
1223         }
1224
1225         /* Key binding info ordering
1226          * 1. Reserved for counter
1227          * 1. DH shared key
1228          * 2. Client's NIDs
1229          * 3. Client's token
1230          * 4. Server's token */
1231         bufs[0].value = NULL;
1232         bufs[0].length = 0;
1233         bufs[1] = skc->sc_dh_shared_key;
1234         bufs[2].value = &client_nid;
1235         bufs[2].length = sizeof(client_nid);
1236         bufs[3] = *client_token;
1237         bufs[4] = *server_token;
1238
1239         return sk_kdf(&kctx->skc_session_key, &kctx->skc_shared_key, bufs,
1240                       5, cfs_crypto_hash_alg(kctx->skc_hmac_alg));
1241 }
1242
1243 /* Uses the session key to create an HMAC key and encryption key.  In
1244  * integrity mode the session key used to generate the HMAC key uses
1245  * session information which is available on the wire but by creating
1246  * a session based HMAC key we can prevent potential replay as both the
1247  * client and server have random numbers used as part of the key creation.
1248  *
1249  * The keys used for integrity and privacy are formulated as below using
1250  * the session key that is the output of the key derivation function.  The
1251  * HMAC algorithm is determined by the shared key algorithm selected in the
1252  * key file.
1253  *
1254  * For ski mode:
1255  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1256  *
1257  * For skpi mode:
1258  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1259  * Session Encryption Key = PBKDF2("Encrypt", KDF derived Session Key)
1260  *
1261  * \param[in,out]       skc             Shared key credentials structure with
1262  *
1263  * \return      -1              failure
1264  * \return      0               success
1265  */
1266 int sk_compute_keys(struct sk_cred *skc)
1267 {
1268         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1269         gss_buffer_desc *session_key = &kctx->skc_session_key;
1270         gss_buffer_desc *hmac_key = &kctx->skc_hmac_key;
1271         gss_buffer_desc *encrypt_key = &kctx->skc_encrypt_key;
1272         enum cfs_crypto_hash_alg hmac_alg;
1273         enum cfs_crypto_crypt_alg crypt_alg;
1274         char *encrypt = "Encrypt";
1275         char *integrity = "Integrity";
1276         int rc;
1277
1278         hmac_alg = cfs_crypto_hash_alg(kctx->skc_hmac_alg);
1279         hmac_key->length = cfs_crypto_hash_digestsize(hmac_alg);
1280         hmac_key->value = malloc(hmac_key->length);
1281         if (!hmac_key->value)
1282                 return -ENOMEM;
1283
1284         rc = PKCS5_PBKDF2_HMAC(integrity, -1, session_key->value,
1285                                session_key->length, SK_PBKDF2_ITERATIONS,
1286                                sk_hash_to_evp_md(hmac_alg),
1287                                hmac_key->length, hmac_key->value);
1288         if (rc == 0)
1289                 return -EINVAL;
1290
1291         /* Encryption key is only populated in privacy mode */
1292         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1293                 return 0;
1294
1295         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1296         encrypt_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1297         encrypt_key->value = malloc(encrypt_key->length);
1298         if (!encrypt_key->value)
1299                 return -ENOMEM;
1300
1301         rc = PKCS5_PBKDF2_HMAC(encrypt, -1, session_key->value,
1302                                session_key->length, SK_PBKDF2_ITERATIONS,
1303                                sk_hash_to_evp_md(hmac_alg),
1304                                encrypt_key->length, encrypt_key->value);
1305         if (rc == 0)
1306                 return -EINVAL;
1307
1308         return 0;
1309 }
1310
1311 /**
1312  * Computes a session key based on the DH parameters from the host and its peer
1313  *
1314  * \param[in,out]       skc             Shared key credentials structure with
1315  *                                      the session key populated with the
1316  *                                      compute key
1317  * \param[in]           pub_key         Public key returned from peer in
1318  *                                      gss_buffer_desc
1319  * \return      gss error               failure
1320  * \return      GSS_S_COMPLETE          success
1321  */
1322 uint32_t sk_compute_dh_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
1323 {
1324         gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
1325         BIGNUM *remote_pub_key;
1326         int status;
1327         uint32_t rc = GSS_S_FAILURE;
1328
1329         remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
1330         if (!remote_pub_key) {
1331                 printerr(0, "Failed to convert binary to BIGNUM\n");
1332                 return rc;
1333         }
1334
1335         dh_shared->length = DH_size(skc->sc_params);
1336         dh_shared->value = malloc(dh_shared->length);
1337         if (!dh_shared->value) {
1338                 printerr(0, "Failed to allocate memory for computed shared "
1339                          "secret key\n");
1340                 goto out_err;
1341         }
1342
1343         /* This compute the shared key from the DHKE */
1344         status = DH_compute_key(dh_shared->value, remote_pub_key,
1345                                 skc->sc_params);
1346         if (status == -1) {
1347                 printerr(0, "DH_compute_key() failed: %s\n",
1348                          ERR_error_string(ERR_get_error(), NULL));
1349                 goto out_err;
1350         } else if (status < dh_shared->length) {
1351                 /* there is around 1 chance out of 256 that the returned
1352                  * shared key is shorter than expected
1353                  */
1354                 if (status >= dh_shared->length - 2) {
1355                         int shift = dh_shared->length - status;
1356                         /* if the key is short by only 1 or 2 bytes, just
1357                          * prepend it with 0s
1358                          */
1359                         memmove((void *)(dh_shared->value + shift),
1360                                 dh_shared->value, status);
1361                         memset(dh_shared->value, 0, shift);
1362                 } else {
1363                         /* if the key is really too short, return GSS_S_BAD_QOP
1364                          * so that the caller can retry to generate
1365                          */
1366                         printerr(0, "DH_compute_key() returned a short key of %d bytes, expected: %zu\n",
1367                                  status, dh_shared->length);
1368                         rc = GSS_S_BAD_QOP;
1369                         goto out_err;
1370                 }
1371         }
1372
1373         rc = GSS_S_COMPLETE;
1374
1375 out_err:
1376         BN_free(remote_pub_key);
1377         return rc;
1378 }
1379
1380 /**
1381  * Creates a serialized buffer for the kernel in the order of struct
1382  * sk_kernel_ctx.
1383  *
1384  * \param[in,out]       skc             Shared key credentials structure
1385  * \param[in,out]       ctx_token       Serialized buffer for kernel.
1386  *                                      Caller must free this buffer.
1387  *
1388  * \return      0       success
1389  * \return      -1      failure
1390  */
1391 int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
1392 {
1393         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1394         char *p, *end;
1395         size_t bufsize;
1396
1397         bufsize = sizeof(*kctx) + kctx->skc_hmac_key.length +
1398                   kctx->skc_encrypt_key.length;
1399
1400         ctx_token->value = malloc(bufsize);
1401         if (!ctx_token->value)
1402                 return -1;
1403         ctx_token->length = bufsize;
1404
1405         p = ctx_token->value;
1406         end = p + ctx_token->length;
1407
1408         if (WRITE_BYTES(&p, end, kctx->skc_version))
1409                 return -1;
1410         if (WRITE_BYTES(&p, end, kctx->skc_hmac_alg))
1411                 return -1;
1412         if (WRITE_BYTES(&p, end, kctx->skc_crypt_alg))
1413                 return -1;
1414         if (WRITE_BYTES(&p, end, kctx->skc_expire))
1415                 return -1;
1416         if (WRITE_BYTES(&p, end, kctx->skc_host_random))
1417                 return -1;
1418         if (WRITE_BYTES(&p, end, kctx->skc_peer_random))
1419                 return -1;
1420         if (write_buffer(&p, end, &kctx->skc_hmac_key))
1421                 return -1;
1422         if (write_buffer(&p, end, &kctx->skc_encrypt_key))
1423                 return -1;
1424
1425         printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
1426
1427         return 0;
1428 }
1429
1430 /**
1431  * Decodes a netstring \a ns into array of gss_buffer_descs at \a bufs
1432  * up to \a numbufs.  Memory is allocated for each value and length
1433  * will be populated with the length
1434  *
1435  * \param[in,out]       bufs    Array of gss_buffer_descs
1436  * \param[in,out]       numbufs number of gss_buffer_desc in array
1437  * \param[in]           ns      netstring to decode
1438  *
1439  * \return      buffers populated       success
1440  * \return      -1                      failure
1441  */
1442 int sk_decode_netstring(gss_buffer_desc *bufs, int numbufs, gss_buffer_desc *ns)
1443 {
1444         char *ptr = ns->value;
1445         size_t remain = ns->length;
1446         unsigned int size;
1447         int digits;
1448         int sep;
1449         int rc;
1450         int i;
1451
1452         for (i = 0; i < numbufs; i++) {
1453                 /* read the size of first buffer */
1454                 rc = sscanf(ptr, "%9u", &size);
1455                 if (rc < 1)
1456                         goto out_err;
1457                 digits = (size) ? ceil(log10(size + 1)) : 1;
1458
1459                 /* sep of current string */
1460                 sep = size + digits + 2;
1461
1462                 /* check to make sure it's valid */
1463                 if (remain < sep || ptr[digits] != ':' ||
1464                     ptr[sep - 1] != ',')
1465                         goto out_err;
1466
1467                 bufs[i].length = size;
1468                 if (size == 0) {
1469                         bufs[i].value = NULL;
1470                 } else {
1471                         bufs[i].value = malloc(size);
1472                         if (!bufs[i].value)
1473                                 goto out_err;
1474                         memcpy(bufs[i].value, &ptr[digits + 1], size);
1475                 }
1476
1477                 remain -= sep;
1478                 ptr += sep;
1479         }
1480
1481         printerr(2, "Decoded netstring of %zu bytes\n", ns->length);
1482         return i;
1483
1484 out_err:
1485         while (i-- > 0) {
1486                 if (bufs[i].value)
1487                         free(bufs[i].value);
1488                 bufs[i].length = 0;
1489         }
1490         return -1;
1491 }
1492
1493 /**
1494  * Creates a netstring in a gss_buffer_desc that consists of all
1495  * the gss_buffer_desc found in \a bufs.  The netstring should be treated
1496  * as binary as it can contain null characters.
1497  *
1498  * \param[in]           bufs            Array of gss_buffer_desc to use as input
1499  * \param[in]           numbufs         Number of buffers in array
1500  * \param[in,out]       ns              Destination gss_buffer_desc to hold
1501  *                                      netstring
1502  *
1503  * \return      -1      failure
1504  * \return      0       success
1505  */
1506 int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
1507                         gss_buffer_desc *ns)
1508 {
1509         unsigned char *ptr;
1510         int size = 0;
1511         int rc;
1512         int i;
1513
1514         /* size of string in decimal, string size, colon, and comma */
1515         for (i = 0; i < numbufs; i++) {
1516
1517                 if (bufs[i].length == 0)
1518                         size += 3;
1519                 else
1520                         size += ceil(log10(bufs[i].length + 1)) +
1521                                 bufs[i].length + 2;
1522         }
1523
1524         ns->length = size;
1525         ns->value = malloc(ns->length);
1526         if (!ns->value) {
1527                 ns->length = 0;
1528                 return -1;
1529         }
1530
1531         ptr = ns->value;
1532         for (i = 0; i < numbufs; i++) {
1533                 /* size */
1534                 rc = scnprintf((char *) ptr, size, "%zu:", bufs[i].length);
1535                 ptr += rc;
1536
1537                 /* contents */
1538                 memcpy(ptr, bufs[i].value, bufs[i].length);
1539                 ptr += bufs[i].length;
1540
1541                 /* delimeter */
1542                 *ptr++ = ',';
1543
1544                 size -= bufs[i].length + rc + 1;
1545
1546                 /* should not happen */
1547                 if (size < 0)
1548                         abort();
1549         }
1550
1551         printerr(2, "Encoded netstring of %zu bytes\n", ns->length);
1552         return 0;
1553 }