Whamcloud - gitweb
LU-12511 utils: fix regression for UAPI headers for native client
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Copyright (c) 2016, 2017, Intel Corporation.
26  *
27  * Author: Jeremy Filizetti <jfilizet@iu.edu>
28  */
29
30 #include <fcntl.h>
31 #include <limits.h>
32 #include <math.h>
33 #include <string.h>
34 #include <stdbool.h>
35 #include <unistd.h>
36 /* We need to use some deprecated APIs */
37 #define OPENSSL_SUPPRESS_DEPRECATED
38 #include <openssl/dh.h>
39 #include <openssl/engine.h>
40 #include <openssl/err.h>
41 #include <openssl/hmac.h>
42 #include <sys/types.h>
43 #include <sys/stat.h>
44 #include <libcfs/util/string.h>
45 #include <sys/time.h>
46
47 #include "sk_utils.h"
48 #include "write_bytes.h"
49
50 #define SK_PBKDF2_ITERATIONS 10000
51
52 #ifdef _NEW_BUILD_
53 # include "lgss_utils.h"
54 #else
55 # include "gss_util.h"
56 # include "gss_oids.h"
57 # include "err_util.h"
58 #endif
59
#ifdef _ERR_UTIL_H_
/**
 * Initializes logging by forwarding the arguments unchanged to initerr().
 *
 * Only compiled when _ERR_UTIL_H_ is defined (presumably the include guard
 * of err_util.h, which is included above in the non-_NEW_BUILD_ case).
 *
 * \param[in]   program         Program name to output
 * \param[in]   verbose         Verbose flag
 * \param[in]   fg              Whether or not to run in foreground
 */
void sk_init_logging(char *program, int verbose, int fg)
{
	initerr(program, verbose, fg);
}
#endif
73
74 /**
75  * Loads the key from \a filename and returns the struct sk_keyfile_config.
76  * It should be freed by the caller.
77  *
78  * \param[in]   filename                Disk or key payload data
79  *
80  * \return      sk_keyfile_config       sucess
81  * \return      NULL                    failure
82  */
83 struct sk_keyfile_config *sk_read_file(char *filename)
84 {
85         struct sk_keyfile_config *config;
86         char *ptr;
87         size_t rc;
88         size_t remain;
89         int fd;
90
91         config = malloc(sizeof(*config));
92         if (!config) {
93                 printerr(0, "Failed to allocate memory for config\n");
94                 return NULL;
95         }
96
97         /* allow standard input override */
98         if (strcmp(filename, "-") == 0)
99                 fd = STDIN_FILENO;
100         else
101                 fd = open(filename, O_RDONLY);
102
103         if (fd == -1) {
104                 printerr(0, "Error opening key file '%s': %s\n", filename,
105                          strerror(errno));
106                 goto out_free;
107         } else if (fd != STDIN_FILENO) {
108                 struct stat st;
109
110                 rc = fstat(fd, &st);
111                 if (rc == 0 && (st.st_mode & ~(S_IFREG | 0600)))
112                         fprintf(stderr, "warning: "
113                                 "secret key '%s' has insecure file mode %#o\n",
114                                 filename, st.st_mode);
115         }
116
117         ptr = (char *)config;
118         remain = sizeof(*config);
119         while (remain > 0) {
120                 rc = read(fd, ptr, remain);
121                 if (rc == -1) {
122                         if (errno == EINTR)
123                                 continue;
124                         printerr(0, "read() failed on %s: %s\n", filename,
125                                  strerror(errno));
126                         goto out_close;
127                 } else if (rc == 0) {
128                         printerr(0, "File %s does not have a complete key\n",
129                                  filename);
130                         goto out_close;
131                 }
132                 ptr += rc;
133                 remain -= rc;
134         }
135
136         if (fd != STDIN_FILENO)
137                 close(fd);
138         sk_config_disk_to_cpu(config);
139         return config;
140
141 out_close:
142         close(fd);
143 out_free:
144         free(config);
145         return NULL;
146 }
147
148 /**
149  * Checks if a key matching \a description is found in the keyring for
150  * logging purposes and then attempts to load \a payload of \a psize into a key
151  * with \a description.
152  *
153  * \param[in]   payload         Key payload
154  * \param[in]   psize           Payload size
155  * \param[in]   description     Description used for key in keyring
156  *
157  * \return      0       sucess
158  * \return      -1      failure
159  */
160 static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
161                                 const char *description)
162 {
163         struct sk_keyfile_config payload;
164         key_serial_t key;
165
166         memcpy(&payload, skc, sizeof(*skc));
167
168         /* In the keyring use the disk layout so keyctl pipe can be used */
169         sk_config_cpu_to_disk(&payload);
170
171         /* Check to see if a key is already loaded matching description */
172         key = keyctl_search(KEY_SPEC_USER_KEYRING, "user", description, 0);
173         if (key != -1)
174                 printerr(2, "Key %d found in session keyring, replacing\n",
175                          key);
176
177         key = add_key("user", description, &payload, sizeof(payload),
178                       KEY_SPEC_USER_KEYRING);
179         if (key != -1) {
180                 key_perm_t perm = KEY_POS_ALL | KEY_USR_ALL |
181                         KEY_GRP_ALL | KEY_OTH_ALL;
182
183                 if (keyctl_setperm(key, perm) < 0)
184                         printerr(2, "Failed to set perm 0x%x on key %d\n",
185                                  perm, key);
186                 printerr(2, "Added key %d with description %s\n", key,
187                          description);
188         } else {
189                 printerr(0, "Failed to add key with %s\n", description);
190         }
191
192         return key;
193 }
194
195 /**
196  * Reads the key from \a path, verifies it and loads into the session keyring
197  * using a description determined by the the \a type.  Existing keys with the
198  * same description are replaced.
199  *
200  * \param[in]   path    Path to key file
201  * \param[in]   type    Type of key to load which determines the description
202  *
203  * \return      0       sucess
204  * \return      -1      failure
205  */
206 int sk_load_keyfile(char *path)
207 {
208         struct sk_keyfile_config *config;
209         char description[SK_DESCRIPTION_SIZE + 1];
210         struct stat buf;
211         int i;
212         int rc;
213         int rc2 = -1;
214
215         rc = stat(path, &buf);
216         if (rc == -1) {
217                 printerr(0, "stat() failed for file %s: %s\n", path,
218                          strerror(errno));
219                 return rc2;
220         }
221
222         config = sk_read_file(path);
223         if (!config)
224                 return rc2;
225
226         /* Similar to ssh, require adequate care of key files */
227         if (buf.st_mode & (S_IRGRP | S_IWGRP | S_IWOTH | S_IXOTH)) {
228                 printerr(0, "Shared key files must be read/writeable only by "
229                          "owner\n");
230                 return -1;
231         }
232
233         if (sk_validate_config(config))
234                 goto out;
235
236         /* The server side can have multiple key files per file system so
237          * the nodemap name is appended to the key description to uniquely
238          * identify it */
239         if (config->skc_type & SK_TYPE_MGS) {
240                 /* Any key can be an MGS key as long as we are told to use it */
241                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
242                               config->skc_nodemap);
243                 if (rc >= SK_DESCRIPTION_SIZE)
244                         goto out;
245                 if (sk_load_key(config, description) == -1)
246                         goto out;
247         }
248         if (config->skc_type & SK_TYPE_SERVER) {
249                 /* Server keys need to have the file system name in the key */
250                 if (!config->skc_fsname) {
251                         printerr(0, "Key configuration has no file system "
252                                  "attribute.  Can't load as server type\n");
253                         goto out;
254                 }
255                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s:%s",
256                               config->skc_fsname, config->skc_nodemap);
257                 if (rc >= SK_DESCRIPTION_SIZE)
258                         goto out;
259                 if (sk_load_key(config, description) == -1)
260                         goto out;
261         }
262         if (config->skc_type & SK_TYPE_CLIENT) {
263                 /* Load client file system key */
264                 if (config->skc_fsname) {
265                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
266                                       "lustre:%s", config->skc_fsname);
267                         if (rc >= SK_DESCRIPTION_SIZE)
268                                 goto out;
269                         if (sk_load_key(config, description) == -1)
270                                 goto out;
271                 }
272
273                 /* Load client MGC keys */
274                 for (i = 0; i < MAX_MGSNIDS; i++) {
275                         if (config->skc_mgsnids[i] == LNET_NID_ANY)
276                                 continue;
277                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
278                                       "lustre:MGC%s",
279                                       libcfs_nid2str(config->skc_mgsnids[i]));
280                         if (rc >= SK_DESCRIPTION_SIZE)
281                                 goto out;
282                         if (sk_load_key(config, description) == -1)
283                                 goto out;
284                 }
285         }
286
287         rc2 = 0;
288
289 out:
290         free(config);
291         return rc2;
292 }
293
294 /**
295  * Byte swaps config from cpu format to disk
296  *
297  * \param[in,out]       config          sk_keyfile_config to swap
298  */
299 void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
300 {
301         int i;
302
303         if (!config)
304                 return;
305
306         config->skc_version = htobe32(config->skc_version);
307         config->skc_hmac_alg = htobe16(config->skc_hmac_alg);
308         config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
309         config->skc_expire = htobe32(config->skc_expire);
310         config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
311         config->skc_prime_bits = htobe32(config->skc_prime_bits);
312
313         for (i = 0; i < MAX_MGSNIDS; i++)
314                 config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
315 }
316
317 /**
318  * Byte swaps config from disk format to cpu
319  *
320  * \param[in,out]       config          sk_keyfile_config to swap
321  */
322 void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
323 {
324         int i;
325
326         if (!config)
327                 return;
328
329         config->skc_version = be32toh(config->skc_version);
330         config->skc_hmac_alg = be16toh(config->skc_hmac_alg);
331         config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
332         config->skc_expire = be32toh(config->skc_expire);
333         config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
334         config->skc_prime_bits = be32toh(config->skc_prime_bits);
335
336         for (i = 0; i < MAX_MGSNIDS; i++)
337                 config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
338 }
339
340 /**
341  * Verifies the on key payload format is valid
342  *
343  * \param[in]   config          sk_keyfile_config
344  *
345  * \return      -1      failure
346  * \return      0       success
347  */
348 int sk_validate_config(const struct sk_keyfile_config *config)
349 {
350         int i;
351
352         if (!config) {
353                 printerr(0, "Null configuration passed\n");
354                 return -1;
355         }
356
357         if (config->skc_version != SK_CONF_VERSION) {
358                 printerr(0, "Invalid version\n");
359                 return -1;
360         }
361
362         if (config->skc_hmac_alg == SK_HMAC_INVALID) {
363                 printerr(0, "Invalid HMAC algorithm\n");
364                 return -1;
365         }
366
367         if (config->skc_crypt_alg == SK_CRYPT_INVALID) {
368                 printerr(0, "Invalid crypt algorithm\n");
369                 return -1;
370         }
371
372         if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
373                 /* Try to limit key expiration to some reasonable minimum and
374                  * also prevent values over INT_MAX because there appears
375                  * to be a type conversion issue */
376                 printerr(0, "Invalid expiration time should be between %d "
377                          "and %d\n", 60, INT_MAX);
378                 return -1;
379         }
380         if (config->skc_prime_bits % 8 != 0 ||
381             config->skc_prime_bits > SK_MAX_P_BYTES * 8) {
382                 printerr(0, "Invalid session key length must be a multiple of 8"
383                          " and less then %d bits\n",
384                          SK_MAX_P_BYTES * 8);
385                 return -1;
386         }
387         if (config->skc_shared_keylen % 8 != 0 ||
388             config->skc_shared_keylen > SK_MAX_KEYLEN_BYTES * 8){
389                 printerr(0, "Invalid shared key max length must be a multiple "
390                          "of 8 and less then %d bits\n",
391                          SK_MAX_KEYLEN_BYTES * 8);
392                 return -1;
393         }
394
395         /* Check for terminating nulls on strings */
396         for (i = 0; i < sizeof(config->skc_fsname) &&
397              config->skc_fsname[i] != '\0';  i++)
398                 ; /* empty loop */
399         if (i == sizeof(config->skc_fsname)) {
400                 printerr(0, "File system name not null terminated\n");
401                 return -1;
402         }
403
404         for (i = 0; i < sizeof(config->skc_nodemap) &&
405              config->skc_nodemap[i] != '\0';  i++)
406                 ; /* empty loop */
407         if (i == sizeof(config->skc_nodemap)) {
408                 printerr(0, "Nodemap name not null terminated\n");
409                 return -1;
410         }
411
412         if (config->skc_type == SK_TYPE_INVALID) {
413                 printerr(0, "Invalid key type\n");
414                 return -1;
415         }
416
417         return 0;
418 }
419
420 /**
421  * Hashes \a string and places the hash in \a hash
422  * at \a hash
423  *
424  * \param[in]           string          Null terminated string to hash
425  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
426  * \param[in,out]       hash            gss_buffer_desc to hold the result
427  *
428  * \return      -1      failure
429  * \return      0       success
430  */
431 static int sk_hash_string(const char *string, const EVP_MD *hash_alg,
432                           gss_buffer_desc *hash)
433 {
434         EVP_MD_CTX *ctx = EVP_MD_CTX_create();
435         size_t len = strlen(string);
436         unsigned int hashlen;
437
438         if (!hash->value || hash->length < EVP_MD_size(hash_alg))
439                 goto out_err;
440         if (!EVP_DigestInit_ex(ctx, hash_alg, NULL))
441                 goto out_err;
442         if (!EVP_DigestUpdate(ctx, string, len))
443                 goto out_err;
444         if (!EVP_DigestFinal_ex(ctx, hash->value, &hashlen))
445                 goto out_err;
446
447         EVP_MD_CTX_destroy(ctx);
448         hash->length = hashlen;
449         return 0;
450
451 out_err:
452         EVP_MD_CTX_destroy(ctx);
453         return -1;
454 }
455
456 /**
457  * Hashes \a string and verifies the resulting hash matches the value
458  * in \a current_hash
459  *
460  * \param[in]           string          Null terminated string to hash
461  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
462  * \param[in,out]       current_hash    gss_buffer_desc to compare to
463  *
464  * \return      gss error       failure
465  * \return      GSS_S_COMPLETE  success
466  */
467 uint32_t sk_verify_hash(const char *string, const EVP_MD *hash_alg,
468                         const gss_buffer_desc *current_hash)
469 {
470         gss_buffer_desc hash;
471         unsigned char hashbuf[EVP_MAX_MD_SIZE];
472
473         hash.value = hashbuf;
474         hash.length = sizeof(hashbuf);
475
476         if (sk_hash_string(string, hash_alg, &hash))
477                 return GSS_S_FAILURE;
478         if (current_hash->length != hash.length)
479                 return GSS_S_DEFECTIVE_TOKEN;
480         if (memcmp(current_hash->value, hash.value, hash.length))
481                 return GSS_S_BAD_SIG;
482
483         return GSS_S_COMPLETE;
484 }
485
486 static inline int sk_config_has_mgsnid(struct sk_keyfile_config *config,
487                                        const char *mgsnid)
488 {
489         lnet_nid_t nid;
490         int i;
491
492         nid = libcfs_str2nid(mgsnid);
493         if (nid == LNET_NID_ANY)
494                 return 0;
495
496         for (i = 0; i < MAX_MGSNIDS; i++)
497                 if  (config->skc_mgsnids[i] == nid)
498                         return 1;
499         return 0;
500 }
501
502 /**
503  * Create an sk_cred structure populated with initial configuration info and the
504  * key.  \a tgt and \a nodemap are used in determining the expected key
505  * description so the key can be found by searching the keyring.
506  * This is done because there is no easy way to pass keys from the mount command
507  * all the way to the request_key call.  In addition any keys can be dynamically
508  * added to the keyrings and still found.  The keyring that needs to be used
509  * must be the session keyring.
510  *
511  * \param[in]   tgt             Target file system
512  * \param[in]   nodemap         Cluster name for the key.  This correlates to
513  *                              the nodemap name and is used by the server side.
514  *                              For the client this will be NULL.
515  * \param[in]   flags           Flags for the credentials
516  *
517  * \return      sk_cred Allocated struct sk_cred on success
518  * \return      NULL    failure
519  */
520 struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
521                                const uint32_t flags)
522 {
523         struct sk_keyfile_config *config;
524         struct sk_kernel_ctx *kctx;
525         struct sk_cred *skc = NULL;
526         char description[SK_DESCRIPTION_SIZE + 1];
527         char fsname[MTI_NAME_MAXLEN + 1];
528         const char *mgsnid = NULL;
529         char *ptr;
530         long sk_key;
531         int keylen;
532         int len;
533         int rc;
534
535         printerr(2, "Creating credentials for target: %s with nodemap: %s\n",
536                  tgt, nodemap);
537
538         memset(description, 0, sizeof(description));
539         memset(fsname, 0, sizeof(fsname));
540
541         /* extract the file system name from target */
542         ptr = index(tgt, '-');
543         if (!ptr) {
544                 len = strlen(tgt);
545
546                 /* This must be an MGC target */
547                 if (strncmp(tgt, "MGC", 3) || len <= 3) {
548                         printerr(0, "Invalid target name\n");
549                         return NULL;
550                 }
551                 mgsnid = tgt + 3;
552         } else {
553                 len = ptr - tgt;
554         }
555
556         if (len > MTI_NAME_MAXLEN) {
557                 printerr(0, "Invalid target name\n");
558                 return NULL;
559         }
560         memcpy(fsname, tgt, len);
561
562         if (nodemap) {
563                 if (mgsnid)
564                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
565                                       "lustre:MGS:%s", nodemap);
566                 else
567                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
568                                       "lustre:%s:%s", fsname, nodemap);
569         } else {
570                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s",
571                               fsname);
572         }
573
574         if (rc >= SK_DESCRIPTION_SIZE) {
575                 printerr(0, "Invalid key description\n");
576                 return NULL;
577         }
578
579         /* It may be a good idea to move Lustre keys to the gss_keyring
580          * (lgssc) type so that they expire when Lustre modules are removed.
581          * Unfortunately it can't be done at mount time because the mount
582          * syscall could trigger the Lustre modules to load and until that
583          * point we don't have a lgssc key type.
584          *
585          * TODO: Query the community for a consensus here  */
586         printerr(2, "Searching for key with description: %s\n", description);
587         sk_key = keyctl_search(KEY_SPEC_USER_KEYRING, "user",
588                                description, 0);
589         if (sk_key == -1) {
590                 printerr(1, "No key found for %s\n", description);
591                 return NULL;
592         }
593
594         keylen = keyctl_read_alloc(sk_key, (void **)&config);
595         if (keylen == -1) {
596                 printerr(0, "keyctl_read() failed for key %ld: %s\n", sk_key,
597                          strerror(errno));
598                 return NULL;
599         } else if (keylen != sizeof(*config)) {
600                 printerr(0, "Unexpected key size: %d returned for key %ld, "
601                          "expected %zu bytes\n",
602                          keylen, sk_key, sizeof(*config));
603                 goto out_err;
604         }
605
606         sk_config_disk_to_cpu(config);
607
608         if (sk_validate_config(config)) {
609                 printerr(0, "Invalid key configuration for key: %ld\n", sk_key);
610                 goto out_err;
611         }
612
613         if (mgsnid && !sk_config_has_mgsnid(config, mgsnid)) {
614                 printerr(0, "Target name does not match key's MGS NIDs\n");
615                 goto out_err;
616         }
617
618         if (!mgsnid && strcmp(fsname, config->skc_fsname)) {
619                 printerr(0, "Target name does not match key's file system\n");
620                 goto out_err;
621         }
622
623         skc = malloc(sizeof(*skc));
624         if (!skc) {
625                 printerr(0, "Failed to allocate memory for sk_cred\n");
626                 goto out_err;
627         }
628
629         /* this initializes all gss_buffer_desc to empty as well */
630         memset(skc, 0, sizeof(*skc));
631
632         skc->sc_flags = flags;
633         skc->sc_tgt.length = strlen(tgt) + 1;
634         skc->sc_tgt.value = malloc(skc->sc_tgt.length);
635         if (!skc->sc_tgt.value) {
636                 printerr(0, "Failed to allocate memory for target\n");
637                 goto out_err;
638         }
639         memcpy(skc->sc_tgt.value, tgt, skc->sc_tgt.length);
640
641         skc->sc_nodemap_hash.length = EVP_MD_size(EVP_sha256());
642         skc->sc_nodemap_hash.value = malloc(skc->sc_nodemap_hash.length);
643         if (!skc->sc_nodemap_hash.value) {
644                 printerr(0, "Failed to allocate memory for nodemap hash\n");
645                 goto out_err;
646         }
647
648         if (sk_hash_string(config->skc_nodemap, EVP_sha256(),
649                            &skc->sc_nodemap_hash)) {
650                 printerr(0, "Failed to generate hash for nodemap name\n");
651                 goto out_err;
652         }
653
654         kctx = &skc->sc_kctx;
655         kctx->skc_version = config->skc_version;
656         strcpy(kctx->skc_hmac_alg, sk_hmac2name(config->skc_hmac_alg));
657         strcpy(kctx->skc_crypt_alg, sk_crypt2name(config->skc_crypt_alg));
658         kctx->skc_expire = config->skc_expire;
659
660         /* key payload format is in bits, convert to bytes */
661         kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
662         kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
663         if (!kctx->skc_shared_key.value) {
664                 printerr(0, "Failed to allocate memory for shared key\n");
665                 goto out_err;
666         }
667         memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
668                kctx->skc_shared_key.length);
669
670         skc->sc_p.length = config->skc_prime_bits / 8;
671         skc->sc_p.value = malloc(skc->sc_p.length);
672         if (!skc->sc_p.value) {
673                 printerr(0, "Failed to allocate p\n");
674                 goto out_err;
675         }
676         memcpy(skc->sc_p.value, config->skc_p, skc->sc_p.length);
677
678         free(config);
679
680         return skc;
681
682 out_err:
683         sk_free_cred(skc);
684
685         free(config);
686         return NULL;
687 }
688
/* Diffie-Hellman generator: required by sk_is_dh_valid() and set on the
 * speed-test parameters in sk_speedtest_dh_valid() */
#define SK_GENERATOR 2
/* Upper bound on the Miller-Rabin round counts tried by
 * sk_speedtest_dh_valid() */
#define DH_NUMBER_ITERATIONS_FOR_PRIME 64
691
692 /* OpenSSL 1.1.1c increased the number of rounds used for Miller-Rabin testing
693  * of the prime provided as input parameter to DH_check(). This makes the check
694  * roughly x10 longer, and causes request timeouts when an SSK flavor is being
695  * used.
696  *
697  * Instead, use a dynamic number Miller-Rabin rounds based on the speed of the
698  * check on the current system, evaluated when the lsvcgssd daemon starts, but
699  * at least as many as OpenSSL 1.1.1b used for the same key size. If default
700  * DH_check() duration is OK, use it directly instead of limiting the rounds.
701  *
702  * If \a num_rounds == 0, we just call original DH_check() directly.
703  */
704 static bool sk_is_dh_valid(const DH *dh, int num_rounds)
705 {
706         const BIGNUM *p, *g;
707         BN_ULONG word;
708         BN_CTX *ctx;
709         BIGNUM *r;
710         bool valid = false;
711         int rc;
712
713         if (num_rounds == 0) {
714                 int codes = 0;
715
716                 rc = DH_check(dh, &codes);
717                 if (rc != 1 || codes) {
718                         printerr(0, "DH_check(0) failed: codes=%#x: rc=%d\n",
719                                  codes, rc);
720                         return false;
721                 }
722                 return true;
723         }
724
725         DH_get0_pqg(dh, &p, NULL, &g);
726
727         if (!BN_is_word(g, SK_GENERATOR)) {
728                 printerr(0, "%s: Diffie-Hellman generator is not %u\n",
729                          program_invocation_short_name, SK_GENERATOR);
730                 return false;
731         }
732
733         word = BN_mod_word(p, 24);
734         if (word != 11) {
735                 printerr(0, "%s: Diffie-Hellman prime modulo=%lu unsuitable\n",
736                          program_invocation_short_name, word);
737                 return false;
738         }
739
740         ctx = BN_CTX_new();
741         if (ctx == NULL) {
742                 printerr(0, "%s: Diffie-Hellman error allocating context\n",
743                          program_invocation_short_name);
744                 return false;
745         }
746         BN_CTX_start(ctx);
747         r = BN_CTX_get(ctx); /* must be called before "ctx" used elsewhere */
748
749         rc = BN_is_prime_ex(p, num_rounds, ctx, NULL);
750         if (rc == 0)
751                 printerr(0, "%s: Diffie-Hellman 'p' not prime in %u rounds\n",
752                          program_invocation_short_name, num_rounds);
753         if (rc <= 0)
754                 goto out_free;
755
756         if (!BN_rshift1(r, p)) {
757                 printerr(0, "%s: error shifting BigNum 'r' by 'p'\n",
758                          program_invocation_short_name);
759                 goto out_free;
760         }
761         rc = BN_is_prime_ex(r, num_rounds, ctx, NULL);
762         if (rc == 0)
763                 printerr(0, "%s: Diffie-Hellman 'r' not prime in %u rounds\n",
764                          program_invocation_short_name, num_rounds);
765         if (rc <= 0)
766                 goto out_free;
767
768         valid = true;
769
770 out_free:
771         BN_CTX_end(ctx);
772         BN_CTX_free(ctx);
773
774         return valid;
775 }
776
#define VALUE_LENGTH 256
/* Fixed 2048-bit (VALUE_LENGTH bytes, big-endian) Diffie-Hellman prime used
 * only to time sk_is_dh_valid() in sk_speedtest_dh_valid().  It is read-only
 * data (passed to BN_bin2bn(), which takes a const buffer).  The literal
 * fills the array exactly, so no trailing NUL is stored. */
static const unsigned char test_prime[VALUE_LENGTH] =
	"\xf7\xfa\x49\xd8\xec\xb1\x3b\xff\x26\x10\x3f\xc5\x3a\xc5\xcc\x40"
	"\x4f\xbf\x92\xe1\x8b\x83\xe7\xa2\xba\x0f\x51\x5a\x91\x48\xe0\xa3"
	"\xf1\x4d\xbc\xbb\x8a\x28\x14\xac\x02\x23\x76\x42\x17\x4d\x3c\xdc"
	"\x5e\x4f\x80\x1f\xd7\x54\x1c\x50\xac\x3b\x28\x68\x8d\x71\x41\x7f"
	"\xa7\x1c\x2f\x22\xd3\xa8\x91\xb2\x64\xb6\x84\xa6\xcf\x06\x16\x91"
	"\x2f\xb8\xb4\x42\x1d\x3a\x4e\x3a\x0c\x7f\x04\x69\x78\xb5\x8f\x92"
	"\x07\x89\xac\x24\x06\x53\x2c\x23\xec\xaa\x5c\xb4\x7b\x49\xbc\xf4"
	"\x90\x67\x71\x9c\x24\x2c\x1d\x8d\x76\xc8\x85\x4e\x19\xf1\xf9\x33"
	"\x45\xbd\x9f\x7d\x0a\x08\x8c\x22\xcc\x35\xf3\x5b\xab\x3f\x24\x9d"
	"\x61\x70\x86\xbb\xbe\xd8\xb0\xf8\x34\xfa\xeb\x5b\x8e\xf2\x62\x23"
	"\xd1\xfb\xbb\xb8\x21\x71\x1e\x39\x39\x59\xe0\x82\x98\x41\x84\x40"
	"\x1f\xd3\x9b\xa3\x73\xdb\xec\xe0\xc0\xde\x2d\x1c\xea\x43\x40\x93"
	"\x98\x38\x03\x36\x1e\xe1\xe7\x39\x7b\x35\x92\x4a\x51\xa5\x91\x63"
	"\xd5\x31\x98\x3d\x89\x27\x6f\xcc\x69\xff\xbe\x31\x13\xdc\x2f\x72"
	"\x2d\xab\x6a\xb7\x13\xd3\x47\xda\xaa\xf3\x3c\xa0\xfd\xaa\x0f\x02"
	"\x96\x81\x1a\x26\xe8\xf7\x25\x65\x33\x78\xd9\x6b\x6d\xb0\xd9\xfb";
795
796 /**
797  * Measure time taken by prime testing routine for a 2048 bit long prime,
798  * depending on the number of check rounds.
799  *
800  * \param[in]   usec_check_max    max time allowed for DH_check completion
801  *
802  * \retval      max number of rounds to keep prime testing under usec_check_max
803  *              return 0 if we should use the default DH_check rounds
804  */
805 int sk_speedtest_dh_valid(unsigned int usec_check_max)
806 {
807         DH *dh;
808         BIGNUM *p, *g;
809         int num_rounds, prev_rounds = 0;
810
811         dh = DH_new();
812         if (!dh)
813                 return 0;
814
815         p = BN_bin2bn(test_prime, VALUE_LENGTH, NULL);
816         if (!p)
817                 goto free_dh;
818
819         g = BN_new();
820         if (!g)
821                 goto free_p;
822
823         if (!BN_set_word(g, SK_GENERATOR))
824                 goto free_g;
825
826         /* "dh" takes over freeing of 'p' and 'g' if this succeeds */
827         if (!DH_set0_pqg(dh, p, NULL, g)) {
828         free_g:
829                 BN_free(g);
830         free_p:
831                 BN_free(p);
832                 goto free_dh;
833         }
834
835         for (num_rounds = 0;
836              num_rounds <= DH_NUMBER_ITERATIONS_FOR_PRIME;
837              num_rounds += (num_rounds <= 4 ? 4 : 8)) {
838                 unsigned int usec_this;
839                 int j;
840
841                 /* get max duration of 4 runs at current number of rounds */
842                 usec_this = 0;
843                 for (j = 0; j < 4; j++) {
844                         struct timeval now, prev;
845                         unsigned int usec_curr;
846
847                         gettimeofday(&prev, NULL);
848                         if (!sk_is_dh_valid(dh, num_rounds)) {
849                                 /* if test_prime is found bad, use default */
850                                 prev_rounds = 0;
851                                 goto free_dh;
852                         }
853                         gettimeofday(&now, NULL);
854                         usec_curr = (now.tv_sec - prev.tv_sec) * 1000000 +
855                                     now.tv_usec - prev.tv_usec;
856                         if (usec_curr > usec_this)
857                                 usec_this = usec_curr;
858                 }
859                 printerr(2, "%s: %d rounds: %d usec\n",
860                          program_invocation_short_name, num_rounds, usec_this);
861                 if (num_rounds == 0) {
862                         if (usec_this <= usec_check_max)
863                         /* using original check rounds as implemented in
864                          * DH_check() took less time than the max allowed,
865                          * so just use original DH_check()
866                          */
867                                 break;
868                 } else if (usec_this >= usec_check_max) {
869                         break;
870                 }
871                 prev_rounds = num_rounds;
872         }
873
874 free_dh:
875         DH_free(dh);
876
877         return prev_rounds;
878 }
879
880 /**
881  * Populates the DH parameters for the DHKE
882  *
883  * \param[in,out]       skc             Shared key credentials structure to
884  *                                      populate with DH parameters
885  *
886  * \retval      GSS_S_COMPLETE  success
887  * \retval      GSS_S_FAILURE   failure
888  */
889 uint32_t sk_gen_params(struct sk_cred *skc, int num_rounds)
890 {
891         uint32_t random;
892         BIGNUM *p, *g;
893         const BIGNUM *pub_key;
894
895         /* Random value used by both the request and response as part of the
896          * key binding material.  This also should ensure we have unqiue
897          * tokens that are sent to the remote server which is important because
898          * the token is hashed for the sunrpc cache lookups and a failure there
899          * would cause connection attempts to fail indefinitely due to the large
900          * timeout value on the server side */
901         if (RAND_bytes((unsigned char *)&random, sizeof(random)) != 1) {
902                 printerr(0, "Failed to get data for random parameter: %s\n",
903                          ERR_error_string(ERR_get_error(), NULL));
904                 return GSS_S_FAILURE;
905         }
906
907         /* The random value will always be used in byte range operations
908          * so we keep it as big endian from this point on */
909         skc->sc_kctx.skc_host_random = random;
910
911         /* Populate DH parameters */
912         skc->sc_params = DH_new();
913         if (!skc->sc_params) {
914                 printerr(0, "Failed to allocate DH\n");
915                 return GSS_S_FAILURE;
916         }
917
918         p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
919         if (!p) {
920                 printerr(0, "Failed to convert binary to BIGNUM\n");
921                 return GSS_S_FAILURE;
922         }
923
924         /* We use a static generator for shared key */
925         g = BN_new();
926         if (!g) {
927                 printerr(0, "Failed to allocate new BIGNUM\n");
928                 return GSS_S_FAILURE;
929         }
930         if (BN_set_word(g, SK_GENERATOR) != 1) {
931                 printerr(0, "Failed to set g value for DH params\n");
932                 return GSS_S_FAILURE;
933         }
934
935         if (!DH_set0_pqg(skc->sc_params, p, NULL, g)) {
936                 printerr(0, "Failed to set pqg\n");
937                 return GSS_S_FAILURE;
938         }
939
940         /* Verify that we have a safe prime and valid generator */
941         if (!sk_is_dh_valid(skc->sc_params, num_rounds))
942                 return GSS_S_FAILURE;
943
944         if (DH_generate_key(skc->sc_params) != 1) {
945                 printerr(0, "Failed to generate public DH key: %s\n",
946                          ERR_error_string(ERR_get_error(), NULL));
947                 return GSS_S_FAILURE;
948         }
949
950         DH_get0_key(skc->sc_params, &pub_key, NULL);
951         skc->sc_pub_key.length = BN_num_bytes(pub_key);
952         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
953         if (!skc->sc_pub_key.value) {
954                 printerr(0, "Failed to allocate memory for public key\n");
955                 return GSS_S_FAILURE;
956         }
957
958         BN_bn2bin(pub_key, skc->sc_pub_key.value);
959
960         return GSS_S_COMPLETE;
961 }
962
963 /**
964  * Convert SK hash algorithm into openssl message digest
965  *
966  * \param[in,out]       alg             SK hash algorithm
967  *
968  * \retval              EVP_MD
969  */
970 static inline const EVP_MD *sk_hash_to_evp_md(enum cfs_crypto_hash_alg alg)
971 {
972         switch (alg) {
973         case CFS_HASH_ALG_SHA256:
974                 return EVP_sha256();
975         case CFS_HASH_ALG_SHA512:
976                 return EVP_sha512();
977         default:
978                 return EVP_md_null();
979         }
980 }
981
982 /**
983  * Signs (via HMAC) the parameters used only in the key initialization protocol.
984  *
985  * \param[in]           key             Key to use for HMAC
986  * \param[in]           bufs            Array of gss_buffer_desc to generate
987  *                                      HMAC for
988  * \param[in]           numbufs         Number of buffers in array
989  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
990  * \param[in,out]       hmac            HMAC of buffers is allocated and placed
991  *                                      in this gss_buffer_desc.  Caller must
992  *                                      free this.
993  *
994  * \retval      0       success
995  * \retval      -1      failure
996  */
997 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
998                  const EVP_MD *hash_alg, gss_buffer_desc *hmac)
999 {
1000         HMAC_CTX *hctx;
1001         unsigned int hashlen = EVP_MD_size(hash_alg);
1002         int i;
1003         int rc = -1;
1004
1005         if (hash_alg == EVP_md_null()) {
1006                 printerr(0, "Invalid hash algorithm\n");
1007                 return -1;
1008         }
1009
1010         hctx = HMAC_CTX_new();
1011
1012         hmac->length = hashlen;
1013         hmac->value = malloc(hashlen);
1014         if (!hmac->value) {
1015                 printerr(0, "Failed to allocate memory for HMAC\n");
1016                 goto out;
1017         }
1018
1019         if (HMAC_Init_ex(hctx, key->value, key->length, hash_alg, NULL) != 1) {
1020                 printerr(0, "Failed to init HMAC\n");
1021                 goto out;
1022         }
1023
1024         for (i = 0; i < numbufs; i++) {
1025                 if (HMAC_Update(hctx, bufs[i].value, bufs[i].length) != 1) {
1026                         printerr(0, "Failed to update HMAC\n");
1027                         goto out;
1028                 }
1029         }
1030
1031         /* The result gets populated in hmac */
1032         if (HMAC_Final(hctx, hmac->value, &hashlen) != 1) {
1033                 printerr(0, "Failed to finalize HMAC\n");
1034                 goto out;
1035         }
1036
1037         if (hmac->length != hashlen) {
1038                 printerr(0, "HMAC size does not match expected\n");
1039                 goto out;
1040         }
1041
1042         rc = 0;
1043 out:
1044         HMAC_CTX_free(hctx);
1045         return rc;
1046 }
1047
1048 /**
1049  * Generates an HMAC for gss_buffer_desc array in \a bufs of \a numbufs
1050  * and verifies against \a hmac.
1051  *
1052  * \param[in]   skc             Shared key credentials
1053  * \param[in]   bufs            Array of gss_buffer_desc to generate HMAC for
1054  * \param[in]   numbufs         Number of buffers in array
1055  * \param[in]   hash_alg        OpenSSL EVP_MD to use for hash
1056  * \param[in]   hmac            HMAC to verify against
1057  *
1058  * \retval      GSS_S_COMPLETE  success (match)
1059  * \retval      gss error       failure
1060  */
1061 uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
1062                         const int numbufs, const EVP_MD *hash_alg,
1063                         gss_buffer_desc *hmac)
1064 {
1065         gss_buffer_desc bufs_hmac;
1066         int rc;
1067
1068         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, numbufs, hash_alg,
1069                          &bufs_hmac)) {
1070                 printerr(0, "Failed to sign buffers to verify HMAC\n");
1071                 if (bufs_hmac.value)
1072                         free(bufs_hmac.value);
1073                 return GSS_S_FAILURE;
1074         }
1075
1076         if (hmac->length != bufs_hmac.length) {
1077                 printerr(0, "Invalid HMAC size\n");
1078                 free(bufs_hmac.value);
1079                 return GSS_S_BAD_SIG;
1080         }
1081
1082         rc = memcmp(hmac->value, bufs_hmac.value, bufs_hmac.length);
1083         free(bufs_hmac.value);
1084
1085         if (rc)
1086                 return GSS_S_BAD_SIG;
1087
1088         return GSS_S_COMPLETE;
1089 }
1090
1091 /**
1092  * Cleanup an sk_cred freeing any resources
1093  *
1094  * \param[in,out]       skc     Shared key credentials to free
1095  */
1096 void sk_free_cred(struct sk_cred *skc)
1097 {
1098         if (!skc)
1099                 return;
1100
1101         if (skc->sc_p.value)
1102                 free(skc->sc_p.value);
1103         if (skc->sc_pub_key.value)
1104                 free(skc->sc_pub_key.value);
1105         if (skc->sc_tgt.value)
1106                 free(skc->sc_tgt.value);
1107         if (skc->sc_nodemap_hash.value)
1108                 free(skc->sc_nodemap_hash.value);
1109         if (skc->sc_hmac.value)
1110                 free(skc->sc_hmac.value);
1111
1112         /* Overwrite keys and IV before freeing */
1113         if (skc->sc_dh_shared_key.value) {
1114                 memset(skc->sc_dh_shared_key.value, 0,
1115                        skc->sc_dh_shared_key.length);
1116                 free(skc->sc_dh_shared_key.value);
1117         }
1118         if (skc->sc_kctx.skc_hmac_key.value) {
1119                 memset(skc->sc_kctx.skc_hmac_key.value, 0,
1120                        skc->sc_kctx.skc_hmac_key.length);
1121                 free(skc->sc_kctx.skc_hmac_key.value);
1122         }
1123         if (skc->sc_kctx.skc_encrypt_key.value) {
1124                 memset(skc->sc_kctx.skc_encrypt_key.value, 0,
1125                        skc->sc_kctx.skc_encrypt_key.length);
1126                 free(skc->sc_kctx.skc_encrypt_key.value);
1127         }
1128         if (skc->sc_kctx.skc_shared_key.value) {
1129                 memset(skc->sc_kctx.skc_shared_key.value, 0,
1130                        skc->sc_kctx.skc_shared_key.length);
1131                 free(skc->sc_kctx.skc_shared_key.value);
1132         }
1133         if (skc->sc_kctx.skc_session_key.value) {
1134                 memset(skc->sc_kctx.skc_session_key.value, 0,
1135                        skc->sc_kctx.skc_session_key.length);
1136                 free(skc->sc_kctx.skc_session_key.value);
1137         }
1138
1139         if (skc->sc_params)
1140                 DH_free(skc->sc_params);
1141
1142         free(skc);
1143         skc = NULL;
1144 }
1145
1146 /* This function handles key derivation using the hash algorithm specified in
1147  * \a hash_alg, buffers in \a key_binding_bufs, and original key in
1148  * \a origin_key to produce a \a derived_key.  The first element of the
1149  * key_binding_bufs array is reserved for the counter used in the KDF.  The
1150  * derived key in \a derived_key could differ in size from \a origin_key and
1151  * must be populated with the expected size and a valid buffer to hold the
1152  * contents.
1153  *
1154  * If the derived key size is greater than the HMAC algorithm size it will be
1155  * a done using several iterations of a counter and the key binding bufs.
1156  *
1157  * If the size is smaller it will take copy the first N bytes necessary to
1158  * fill the derived key. */
1159 int sk_kdf(gss_buffer_desc *derived_key , gss_buffer_desc *origin_key,
1160            gss_buffer_desc *key_binding_bufs, int numbufs,
1161            enum cfs_crypto_hash_alg hmac_alg)
1162 {
1163         size_t remain;
1164         size_t bytes;
1165         uint32_t counter;
1166         char *keydata;
1167         gss_buffer_desc tmp_hash;
1168         int i;
1169         int rc;
1170
1171         if (numbufs < 1)
1172                 return -EINVAL;
1173
1174         /* Use a counter as the first buffer followed by the key binding
1175          * buffers in the event we need more than one a single cycle to
1176          * produced a symmetric key large enough in size */
1177         key_binding_bufs[0].value = &counter;
1178         key_binding_bufs[0].length = sizeof(counter);
1179
1180         remain = derived_key->length;
1181         keydata = derived_key->value;
1182         i = 0;
1183         while (remain > 0) {
1184                 counter = htobe32(i++);
1185                 rc = sk_sign_bufs(origin_key, key_binding_bufs, numbufs,
1186                                   sk_hash_to_evp_md(hmac_alg), &tmp_hash);
1187                 if (rc) {
1188                         if (tmp_hash.value)
1189                                 free(tmp_hash.value);
1190                         return rc;
1191                 }
1192
1193                 if (cfs_crypto_hash_digestsize(hmac_alg) != tmp_hash.length) {
1194                         free(tmp_hash.value);
1195                         return -EINVAL;
1196                 }
1197
1198                 bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
1199                 memcpy(keydata, tmp_hash.value, bytes);
1200                 free(tmp_hash.value);
1201                 remain -= bytes;
1202                 keydata += bytes;
1203         }
1204
1205         return 0;
1206 }
1207
1208 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
1209  * based on the recommendations in NIST Special Publication SP 800-56B Rev 1
1210  * (Sep 2014) Section 5.5.1
1211  *
1212  * \param[in,out]       skc             Shared key credentials structure with
1213  *
1214  * \return      -1              failure
1215  * \return      0               success
1216  */
1217 int sk_session_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
1218                    gss_buffer_desc *client_token, gss_buffer_desc *server_token)
1219 {
1220         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1221         gss_buffer_desc *session_key = &kctx->skc_session_key;
1222         gss_buffer_desc bufs[5];
1223         enum cfs_crypto_crypt_alg crypt_alg;
1224         int rc = -1;
1225
1226         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1227         session_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1228         session_key->value = malloc(session_key->length);
1229         if (!session_key->value) {
1230                 printerr(0, "Failed to allocate memory for session key\n");
1231                 return rc;
1232         }
1233
1234         /* Key binding info ordering
1235          * 1. Reserved for counter
1236          * 1. DH shared key
1237          * 2. Client's NIDs
1238          * 3. Client's token
1239          * 4. Server's token */
1240         bufs[0].value = NULL;
1241         bufs[0].length = 0;
1242         bufs[1] = skc->sc_dh_shared_key;
1243         bufs[2].value = &client_nid;
1244         bufs[2].length = sizeof(client_nid);
1245         bufs[3] = *client_token;
1246         bufs[4] = *server_token;
1247
1248         return sk_kdf(&kctx->skc_session_key, &kctx->skc_shared_key, bufs,
1249                       5, cfs_crypto_hash_alg(kctx->skc_hmac_alg));
1250 }
1251
1252 /* Uses the session key to create an HMAC key and encryption key.  In
1253  * integrity mode the session key used to generate the HMAC key uses
1254  * session information which is available on the wire but by creating
1255  * a session based HMAC key we can prevent potential replay as both the
1256  * client and server have random numbers used as part of the key creation.
1257  *
1258  * The keys used for integrity and privacy are formulated as below using
1259  * the session key that is the output of the key derivation function.  The
1260  * HMAC algorithm is determined by the shared key algorithm selected in the
1261  * key file.
1262  *
1263  * For ski mode:
1264  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1265  *
1266  * For skpi mode:
1267  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1268  * Session Encryption Key = PBKDF2("Encrypt", KDF derived Session Key)
1269  *
1270  * \param[in,out]       skc             Shared key credentials structure with
1271  *
1272  * \return      -1              failure
1273  * \return      0               success
1274  */
1275 int sk_compute_keys(struct sk_cred *skc)
1276 {
1277         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1278         gss_buffer_desc *session_key = &kctx->skc_session_key;
1279         gss_buffer_desc *hmac_key = &kctx->skc_hmac_key;
1280         gss_buffer_desc *encrypt_key = &kctx->skc_encrypt_key;
1281         enum cfs_crypto_hash_alg hmac_alg;
1282         enum cfs_crypto_crypt_alg crypt_alg;
1283         char *encrypt = "Encrypt";
1284         char *integrity = "Integrity";
1285         int rc;
1286
1287         hmac_alg = cfs_crypto_hash_alg(kctx->skc_hmac_alg);
1288         hmac_key->length = cfs_crypto_hash_digestsize(hmac_alg);
1289         hmac_key->value = malloc(hmac_key->length);
1290         if (!hmac_key->value)
1291                 return -ENOMEM;
1292
1293         rc = PKCS5_PBKDF2_HMAC(integrity, -1, session_key->value,
1294                                session_key->length, SK_PBKDF2_ITERATIONS,
1295                                sk_hash_to_evp_md(hmac_alg),
1296                                hmac_key->length, hmac_key->value);
1297         if (rc == 0)
1298                 return -EINVAL;
1299
1300         /* Encryption key is only populated in privacy mode */
1301         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1302                 return 0;
1303
1304         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1305         encrypt_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1306         encrypt_key->value = malloc(encrypt_key->length);
1307         if (!encrypt_key->value)
1308                 return -ENOMEM;
1309
1310         rc = PKCS5_PBKDF2_HMAC(encrypt, -1, session_key->value,
1311                                session_key->length, SK_PBKDF2_ITERATIONS,
1312                                sk_hash_to_evp_md(hmac_alg),
1313                                encrypt_key->length, encrypt_key->value);
1314         if (rc == 0)
1315                 return -EINVAL;
1316
1317         return 0;
1318 }
1319
1320 /**
1321  * Computes a session key based on the DH parameters from the host and its peer
1322  *
1323  * \param[in,out]       skc             Shared key credentials structure with
1324  *                                      the session key populated with the
1325  *                                      compute key
1326  * \param[in]           pub_key         Public key returned from peer in
1327  *                                      gss_buffer_desc
1328  * \return      gss error               failure
1329  * \return      GSS_S_COMPLETE          success
1330  */
1331 uint32_t sk_compute_dh_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
1332 {
1333         gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
1334         BIGNUM *remote_pub_key;
1335         int status;
1336         uint32_t rc = GSS_S_FAILURE;
1337
1338         remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
1339         if (!remote_pub_key) {
1340                 printerr(0, "Failed to convert binary to BIGNUM\n");
1341                 return rc;
1342         }
1343
1344         dh_shared->length = DH_size(skc->sc_params);
1345         dh_shared->value = malloc(dh_shared->length);
1346         if (!dh_shared->value) {
1347                 printerr(0, "Failed to allocate memory for computed shared "
1348                          "secret key\n");
1349                 goto out_err;
1350         }
1351
1352         /* This compute the shared key from the DHKE */
1353         status = DH_compute_key(dh_shared->value, remote_pub_key,
1354                                 skc->sc_params);
1355         if (status == -1) {
1356                 printerr(0, "DH_compute_key() failed: %s\n",
1357                          ERR_error_string(ERR_get_error(), NULL));
1358                 goto out_err;
1359         } else if (status < dh_shared->length) {
1360                 /* there is around 1 chance out of 256 that the returned
1361                  * shared key is shorter than expected
1362                  */
1363                 if (status >= dh_shared->length - 2) {
1364                         int shift = dh_shared->length - status;
1365                         /* if the key is short by only 1 or 2 bytes, just
1366                          * prepend it with 0s
1367                          */
1368                         memmove((void *)(dh_shared->value + shift),
1369                                 dh_shared->value, status);
1370                         memset(dh_shared->value, 0, shift);
1371                 } else {
1372                         /* if the key is really too short, return GSS_S_BAD_QOP
1373                          * so that the caller can retry to generate
1374                          */
1375                         printerr(0, "DH_compute_key() returned a short key of %d bytes, expected: %zu\n",
1376                                  status, dh_shared->length);
1377                         rc = GSS_S_BAD_QOP;
1378                         goto out_err;
1379                 }
1380         }
1381
1382         rc = GSS_S_COMPLETE;
1383
1384 out_err:
1385         BN_free(remote_pub_key);
1386         return rc;
1387 }
1388
1389 /**
1390  * Creates a serialized buffer for the kernel in the order of struct
1391  * sk_kernel_ctx.
1392  *
1393  * \param[in,out]       skc             Shared key credentials structure
1394  * \param[in,out]       ctx_token       Serialized buffer for kernel.
1395  *                                      Caller must free this buffer.
1396  *
1397  * \return      0       success
1398  * \return      -1      failure
1399  */
1400 int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
1401 {
1402         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1403         char *p, *end;
1404         size_t bufsize;
1405
1406         bufsize = sizeof(*kctx) + kctx->skc_hmac_key.length +
1407                   kctx->skc_encrypt_key.length;
1408
1409         ctx_token->value = malloc(bufsize);
1410         if (!ctx_token->value)
1411                 return -1;
1412         ctx_token->length = bufsize;
1413
1414         p = ctx_token->value;
1415         end = p + ctx_token->length;
1416
1417         if (WRITE_BYTES(&p, end, kctx->skc_version))
1418                 return -1;
1419         if (WRITE_BYTES(&p, end, kctx->skc_hmac_alg))
1420                 return -1;
1421         if (WRITE_BYTES(&p, end, kctx->skc_crypt_alg))
1422                 return -1;
1423         if (WRITE_BYTES(&p, end, kctx->skc_expire))
1424                 return -1;
1425         if (WRITE_BYTES(&p, end, kctx->skc_host_random))
1426                 return -1;
1427         if (WRITE_BYTES(&p, end, kctx->skc_peer_random))
1428                 return -1;
1429         if (write_buffer(&p, end, &kctx->skc_hmac_key))
1430                 return -1;
1431         if (write_buffer(&p, end, &kctx->skc_encrypt_key))
1432                 return -1;
1433
1434         printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
1435
1436         return 0;
1437 }
1438
1439 /**
1440  * Decodes a netstring \a ns into array of gss_buffer_descs at \a bufs
1441  * up to \a numbufs.  Memory is allocated for each value and length
1442  * will be populated with the length
1443  *
1444  * \param[in,out]       bufs    Array of gss_buffer_descs
1445  * \param[in,out]       numbufs number of gss_buffer_desc in array
1446  * \param[in]           ns      netstring to decode
1447  *
1448  * \return      buffers populated       success
1449  * \return      -1                      failure
1450  */
1451 int sk_decode_netstring(gss_buffer_desc *bufs, int numbufs, gss_buffer_desc *ns)
1452 {
1453         char *ptr = ns->value;
1454         size_t remain = ns->length;
1455         unsigned int size;
1456         int digits;
1457         int sep;
1458         int rc;
1459         int i;
1460
1461         for (i = 0; i < numbufs; i++) {
1462                 /* read the size of first buffer */
1463                 rc = sscanf(ptr, "%9u", &size);
1464                 if (rc < 1)
1465                         goto out_err;
1466                 digits = (size) ? ceil(log10(size + 1)) : 1;
1467
1468                 /* sep of current string */
1469                 sep = size + digits + 2;
1470
1471                 /* check to make sure it's valid */
1472                 if (remain < sep || ptr[digits] != ':' ||
1473                     ptr[sep - 1] != ',')
1474                         goto out_err;
1475
1476                 bufs[i].length = size;
1477                 if (size == 0) {
1478                         bufs[i].value = NULL;
1479                 } else {
1480                         bufs[i].value = malloc(size);
1481                         if (!bufs[i].value)
1482                                 goto out_err;
1483                         memcpy(bufs[i].value, &ptr[digits + 1], size);
1484                 }
1485
1486                 remain -= sep;
1487                 ptr += sep;
1488         }
1489
1490         printerr(2, "Decoded netstring of %zu bytes\n", ns->length);
1491         return i;
1492
1493 out_err:
1494         while (i-- > 0) {
1495                 if (bufs[i].value)
1496                         free(bufs[i].value);
1497                 bufs[i].length = 0;
1498         }
1499         return -1;
1500 }
1501
1502 /**
1503  * Creates a netstring in a gss_buffer_desc that consists of all
1504  * the gss_buffer_desc found in \a bufs.  The netstring should be treated
1505  * as binary as it can contain null characters.
1506  *
1507  * \param[in]           bufs            Array of gss_buffer_desc to use as input
1508  * \param[in]           numbufs         Number of buffers in array
1509  * \param[in,out]       ns              Destination gss_buffer_desc to hold
1510  *                                      netstring
1511  *
1512  * \return      -1      failure
1513  * \return      0       success
1514  */
1515 int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
1516                         gss_buffer_desc *ns)
1517 {
1518         unsigned char *ptr;
1519         int size = 0;
1520         int rc;
1521         int i;
1522
1523         /* size of string in decimal, string size, colon, and comma */
1524         for (i = 0; i < numbufs; i++) {
1525
1526                 if (bufs[i].length == 0)
1527                         size += 3;
1528                 else
1529                         size += ceil(log10(bufs[i].length + 1)) +
1530                                 bufs[i].length + 2;
1531         }
1532
1533         ns->length = size;
1534         ns->value = malloc(ns->length);
1535         if (!ns->value) {
1536                 ns->length = 0;
1537                 return -1;
1538         }
1539
1540         ptr = ns->value;
1541         for (i = 0; i < numbufs; i++) {
1542                 /* size */
1543                 rc = scnprintf((char *) ptr, size, "%zu:", bufs[i].length);
1544                 ptr += rc;
1545
1546                 /* contents */
1547                 memcpy(ptr, bufs[i].value, bufs[i].length);
1548                 ptr += bufs[i].length;
1549
1550                 /* delimeter */
1551                 *ptr++ = ',';
1552
1553                 size -= bufs[i].length + rc + 1;
1554
1555                 /* should not happen */
1556                 if (size < 0)
1557                         abort();
1558         }
1559
1560         printerr(2, "Encoded netstring of %zu bytes\n", ns->length);
1561         return 0;
1562 }