Whamcloud - gitweb
LU-14479 ssk: explicitly set perm on key
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Copyright (c) 2016, 2017, Intel Corporation.
26  *
27  * Author: Jeremy Filizetti <jfilizet@iu.edu>
28  */
29
30 #include <fcntl.h>
31 #include <limits.h>
32 #include <math.h>
33 #include <string.h>
34 #include <stdbool.h>
35 #include <unistd.h>
36 #include <openssl/dh.h>
37 #include <openssl/engine.h>
38 #include <openssl/err.h>
39 #include <openssl/hmac.h>
40 #include <sys/types.h>
41 #include <sys/stat.h>
42 #include <libcfs/util/string.h>
43 #include <sys/time.h>
44
45 #include "sk_utils.h"
46 #include "write_bytes.h"
47
48 #define SK_PBKDF2_ITERATIONS 10000
49
50 #ifdef _NEW_BUILD_
51 # include "lgss_utils.h"
52 #else
53 # include "gss_util.h"
54 # include "gss_oids.h"
55 # include "err_util.h"
56 #endif
57
58 #ifdef _ERR_UTIL_H_
/**
 * Sets up logging for the shared key utilities by handing all three
 * arguments straight to initerr().
 *
 * \param[in]   program         Program name forwarded to initerr()
 * \param[in]   verbose         Verbosity value forwarded to initerr()
 * \param[in]   fg              Foreground flag forwarded to initerr()
 */
void sk_init_logging(char *program, int verbose, int fg)
{
        initerr(program, verbose, fg);
}
70 #endif
71
72 /**
73  * Loads the key from \a filename and returns the struct sk_keyfile_config.
74  * It should be freed by the caller.
75  *
76  * \param[in]   filename                Disk or key payload data
77  *
78  * \return      sk_keyfile_config       sucess
79  * \return      NULL                    failure
80  */
81 struct sk_keyfile_config *sk_read_file(char *filename)
82 {
83         struct sk_keyfile_config *config;
84         char *ptr;
85         size_t rc;
86         size_t remain;
87         int fd;
88
89         config = malloc(sizeof(*config));
90         if (!config) {
91                 printerr(0, "Failed to allocate memory for config\n");
92                 return NULL;
93         }
94
95         /* allow standard input override */
96         if (strcmp(filename, "-") == 0)
97                 fd = STDIN_FILENO;
98         else
99                 fd = open(filename, O_RDONLY);
100
101         if (fd == -1) {
102                 printerr(0, "Error opening key file '%s': %s\n", filename,
103                          strerror(errno));
104                 goto out_free;
105         } else if (fd != STDIN_FILENO) {
106                 struct stat st;
107
108                 rc = fstat(fd, &st);
109                 if (rc == 0 && (st.st_mode & ~(S_IFREG | 0600)))
110                         fprintf(stderr, "warning: "
111                                 "secret key '%s' has insecure file mode %#o\n",
112                                 filename, st.st_mode);
113         }
114
115         ptr = (char *)config;
116         remain = sizeof(*config);
117         while (remain > 0) {
118                 rc = read(fd, ptr, remain);
119                 if (rc == -1) {
120                         if (errno == EINTR)
121                                 continue;
122                         printerr(0, "read() failed on %s: %s\n", filename,
123                                  strerror(errno));
124                         goto out_close;
125                 } else if (rc == 0) {
126                         printerr(0, "File %s does not have a complete key\n",
127                                  filename);
128                         goto out_close;
129                 }
130                 ptr += rc;
131                 remain -= rc;
132         }
133
134         if (fd != STDIN_FILENO)
135                 close(fd);
136         sk_config_disk_to_cpu(config);
137         return config;
138
139 out_close:
140         close(fd);
141 out_free:
142         free(config);
143         return NULL;
144 }
145
146 /**
147  * Checks if a key matching \a description is found in the keyring for
148  * logging purposes and then attempts to load \a payload of \a psize into a key
149  * with \a description.
150  *
151  * \param[in]   payload         Key payload
152  * \param[in]   psize           Payload size
153  * \param[in]   description     Description used for key in keyring
154  *
155  * \return      0       sucess
156  * \return      -1      failure
157  */
158 static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
159                                 const char *description)
160 {
161         struct sk_keyfile_config payload;
162         key_serial_t key;
163
164         memcpy(&payload, skc, sizeof(*skc));
165
166         /* In the keyring use the disk layout so keyctl pipe can be used */
167         sk_config_cpu_to_disk(&payload);
168
169         /* Check to see if a key is already loaded matching description */
170         key = keyctl_search(KEY_SPEC_USER_KEYRING, "user", description, 0);
171         if (key != -1)
172                 printerr(2, "Key %d found in session keyring, replacing\n",
173                          key);
174
175         key = add_key("user", description, &payload, sizeof(payload),
176                       KEY_SPEC_USER_KEYRING);
177         if (key != -1) {
178                 key_perm_t perm = KEY_POS_ALL | KEY_USR_ALL |
179                         KEY_GRP_ALL | KEY_OTH_ALL;
180
181                 if (keyctl_setperm(key, perm) < 0)
182                         printerr(2, "Failed to set perm 0x%x on key %d\n",
183                                  perm, key);
184                 printerr(2, "Added key %d with description %s\n", key,
185                          description);
186         } else {
187                 printerr(0, "Failed to add key with %s\n", description);
188         }
189
190         return key;
191 }
192
193 /**
194  * Reads the key from \a path, verifies it and loads into the session keyring
195  * using a description determined by the the \a type.  Existing keys with the
196  * same description are replaced.
197  *
198  * \param[in]   path    Path to key file
199  * \param[in]   type    Type of key to load which determines the description
200  *
201  * \return      0       sucess
202  * \return      -1      failure
203  */
204 int sk_load_keyfile(char *path)
205 {
206         struct sk_keyfile_config *config;
207         char description[SK_DESCRIPTION_SIZE + 1];
208         struct stat buf;
209         int i;
210         int rc;
211         int rc2 = -1;
212
213         rc = stat(path, &buf);
214         if (rc == -1) {
215                 printerr(0, "stat() failed for file %s: %s\n", path,
216                          strerror(errno));
217                 return rc2;
218         }
219
220         config = sk_read_file(path);
221         if (!config)
222                 return rc2;
223
224         /* Similar to ssh, require adequate care of key files */
225         if (buf.st_mode & (S_IRGRP | S_IWGRP | S_IWOTH | S_IXOTH)) {
226                 printerr(0, "Shared key files must be read/writeable only by "
227                          "owner\n");
228                 return -1;
229         }
230
231         if (sk_validate_config(config))
232                 goto out;
233
234         /* The server side can have multiple key files per file system so
235          * the nodemap name is appended to the key description to uniquely
236          * identify it */
237         if (config->skc_type & SK_TYPE_MGS) {
238                 /* Any key can be an MGS key as long as we are told to use it */
239                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
240                               config->skc_nodemap);
241                 if (rc >= SK_DESCRIPTION_SIZE)
242                         goto out;
243                 if (sk_load_key(config, description) == -1)
244                         goto out;
245         }
246         if (config->skc_type & SK_TYPE_SERVER) {
247                 /* Server keys need to have the file system name in the key */
248                 if (!config->skc_fsname) {
249                         printerr(0, "Key configuration has no file system "
250                                  "attribute.  Can't load as server type\n");
251                         goto out;
252                 }
253                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s:%s",
254                               config->skc_fsname, config->skc_nodemap);
255                 if (rc >= SK_DESCRIPTION_SIZE)
256                         goto out;
257                 if (sk_load_key(config, description) == -1)
258                         goto out;
259         }
260         if (config->skc_type & SK_TYPE_CLIENT) {
261                 /* Load client file system key */
262                 if (config->skc_fsname) {
263                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
264                                       "lustre:%s", config->skc_fsname);
265                         if (rc >= SK_DESCRIPTION_SIZE)
266                                 goto out;
267                         if (sk_load_key(config, description) == -1)
268                                 goto out;
269                 }
270
271                 /* Load client MGC keys */
272                 for (i = 0; i < MAX_MGSNIDS; i++) {
273                         if (config->skc_mgsnids[i] == LNET_NID_ANY)
274                                 continue;
275                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
276                                       "lustre:MGC%s",
277                                       libcfs_nid2str(config->skc_mgsnids[i]));
278                         if (rc >= SK_DESCRIPTION_SIZE)
279                                 goto out;
280                         if (sk_load_key(config, description) == -1)
281                                 goto out;
282                 }
283         }
284
285         rc2 = 0;
286
287 out:
288         free(config);
289         return rc2;
290 }
291
292 /**
293  * Byte swaps config from cpu format to disk
294  *
295  * \param[in,out]       config          sk_keyfile_config to swap
296  */
297 void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
298 {
299         int i;
300
301         if (!config)
302                 return;
303
304         config->skc_version = htobe32(config->skc_version);
305         config->skc_hmac_alg = htobe16(config->skc_hmac_alg);
306         config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
307         config->skc_expire = htobe32(config->skc_expire);
308         config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
309         config->skc_prime_bits = htobe32(config->skc_prime_bits);
310
311         for (i = 0; i < MAX_MGSNIDS; i++)
312                 config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
313 }
314
315 /**
316  * Byte swaps config from disk format to cpu
317  *
318  * \param[in,out]       config          sk_keyfile_config to swap
319  */
320 void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
321 {
322         int i;
323
324         if (!config)
325                 return;
326
327         config->skc_version = be32toh(config->skc_version);
328         config->skc_hmac_alg = be16toh(config->skc_hmac_alg);
329         config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
330         config->skc_expire = be32toh(config->skc_expire);
331         config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
332         config->skc_prime_bits = be32toh(config->skc_prime_bits);
333
334         for (i = 0; i < MAX_MGSNIDS; i++)
335                 config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
336 }
337
338 /**
339  * Verifies the on key payload format is valid
340  *
341  * \param[in]   config          sk_keyfile_config
342  *
343  * \return      -1      failure
344  * \return      0       success
345  */
346 int sk_validate_config(const struct sk_keyfile_config *config)
347 {
348         int i;
349
350         if (!config) {
351                 printerr(0, "Null configuration passed\n");
352                 return -1;
353         }
354
355         if (config->skc_version != SK_CONF_VERSION) {
356                 printerr(0, "Invalid version\n");
357                 return -1;
358         }
359
360         if (config->skc_hmac_alg == SK_HMAC_INVALID) {
361                 printerr(0, "Invalid HMAC algorithm\n");
362                 return -1;
363         }
364
365         if (config->skc_crypt_alg == SK_CRYPT_INVALID) {
366                 printerr(0, "Invalid crypt algorithm\n");
367                 return -1;
368         }
369
370         if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
371                 /* Try to limit key expiration to some reasonable minimum and
372                  * also prevent values over INT_MAX because there appears
373                  * to be a type conversion issue */
374                 printerr(0, "Invalid expiration time should be between %d "
375                          "and %d\n", 60, INT_MAX);
376                 return -1;
377         }
378         if (config->skc_prime_bits % 8 != 0 ||
379             config->skc_prime_bits > SK_MAX_P_BYTES * 8) {
380                 printerr(0, "Invalid session key length must be a multiple of 8"
381                          " and less then %d bits\n",
382                          SK_MAX_P_BYTES * 8);
383                 return -1;
384         }
385         if (config->skc_shared_keylen % 8 != 0 ||
386             config->skc_shared_keylen > SK_MAX_KEYLEN_BYTES * 8){
387                 printerr(0, "Invalid shared key max length must be a multiple "
388                          "of 8 and less then %d bits\n",
389                          SK_MAX_KEYLEN_BYTES * 8);
390                 return -1;
391         }
392
393         /* Check for terminating nulls on strings */
394         for (i = 0; i < sizeof(config->skc_fsname) &&
395              config->skc_fsname[i] != '\0';  i++)
396                 ; /* empty loop */
397         if (i == sizeof(config->skc_fsname)) {
398                 printerr(0, "File system name not null terminated\n");
399                 return -1;
400         }
401
402         for (i = 0; i < sizeof(config->skc_nodemap) &&
403              config->skc_nodemap[i] != '\0';  i++)
404                 ; /* empty loop */
405         if (i == sizeof(config->skc_nodemap)) {
406                 printerr(0, "Nodemap name not null terminated\n");
407                 return -1;
408         }
409
410         if (config->skc_type == SK_TYPE_INVALID) {
411                 printerr(0, "Invalid key type\n");
412                 return -1;
413         }
414
415         return 0;
416 }
417
418 /**
419  * Hashes \a string and places the hash in \a hash
420  * at \a hash
421  *
422  * \param[in]           string          Null terminated string to hash
423  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
424  * \param[in,out]       hash            gss_buffer_desc to hold the result
425  *
426  * \return      -1      failure
427  * \return      0       success
428  */
429 static int sk_hash_string(const char *string, const EVP_MD *hash_alg,
430                           gss_buffer_desc *hash)
431 {
432         EVP_MD_CTX *ctx = EVP_MD_CTX_create();
433         size_t len = strlen(string);
434         unsigned int hashlen;
435
436         if (!hash->value || hash->length < EVP_MD_size(hash_alg))
437                 goto out_err;
438         if (!EVP_DigestInit_ex(ctx, hash_alg, NULL))
439                 goto out_err;
440         if (!EVP_DigestUpdate(ctx, string, len))
441                 goto out_err;
442         if (!EVP_DigestFinal_ex(ctx, hash->value, &hashlen))
443                 goto out_err;
444
445         EVP_MD_CTX_destroy(ctx);
446         hash->length = hashlen;
447         return 0;
448
449 out_err:
450         EVP_MD_CTX_destroy(ctx);
451         return -1;
452 }
453
454 /**
455  * Hashes \a string and verifies the resulting hash matches the value
456  * in \a current_hash
457  *
458  * \param[in]           string          Null terminated string to hash
459  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
460  * \param[in,out]       current_hash    gss_buffer_desc to compare to
461  *
462  * \return      gss error       failure
463  * \return      GSS_S_COMPLETE  success
464  */
465 uint32_t sk_verify_hash(const char *string, const EVP_MD *hash_alg,
466                         const gss_buffer_desc *current_hash)
467 {
468         gss_buffer_desc hash;
469         unsigned char hashbuf[EVP_MAX_MD_SIZE];
470
471         hash.value = hashbuf;
472         hash.length = sizeof(hashbuf);
473
474         if (sk_hash_string(string, hash_alg, &hash))
475                 return GSS_S_FAILURE;
476         if (current_hash->length != hash.length)
477                 return GSS_S_DEFECTIVE_TOKEN;
478         if (memcmp(current_hash->value, hash.value, hash.length))
479                 return GSS_S_BAD_SIG;
480
481         return GSS_S_COMPLETE;
482 }
483
484 static inline int sk_config_has_mgsnid(struct sk_keyfile_config *config,
485                                        const char *mgsnid)
486 {
487         lnet_nid_t nid;
488         int i;
489
490         nid = libcfs_str2nid(mgsnid);
491         if (nid == LNET_NID_ANY)
492                 return 0;
493
494         for (i = 0; i < MAX_MGSNIDS; i++)
495                 if  (config->skc_mgsnids[i] == nid)
496                         return 1;
497         return 0;
498 }
499
500 /**
501  * Create an sk_cred structure populated with initial configuration info and the
502  * key.  \a tgt and \a nodemap are used in determining the expected key
503  * description so the key can be found by searching the keyring.
504  * This is done because there is no easy way to pass keys from the mount command
505  * all the way to the request_key call.  In addition any keys can be dynamically
506  * added to the keyrings and still found.  The keyring that needs to be used
507  * must be the session keyring.
508  *
509  * \param[in]   tgt             Target file system
510  * \param[in]   nodemap         Cluster name for the key.  This correlates to
511  *                              the nodemap name and is used by the server side.
512  *                              For the client this will be NULL.
513  * \param[in]   flags           Flags for the credentials
514  *
515  * \return      sk_cred Allocated struct sk_cred on success
516  * \return      NULL    failure
517  */
518 struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
519                                const uint32_t flags)
520 {
521         struct sk_keyfile_config *config;
522         struct sk_kernel_ctx *kctx;
523         struct sk_cred *skc = NULL;
524         char description[SK_DESCRIPTION_SIZE + 1];
525         char fsname[MTI_NAME_MAXLEN + 1];
526         const char *mgsnid = NULL;
527         char *ptr;
528         long sk_key;
529         int keylen;
530         int len;
531         int rc;
532
533         printerr(2, "Creating credentials for target: %s with nodemap: %s\n",
534                  tgt, nodemap);
535
536         memset(description, 0, sizeof(description));
537         memset(fsname, 0, sizeof(fsname));
538
539         /* extract the file system name from target */
540         ptr = index(tgt, '-');
541         if (!ptr) {
542                 len = strlen(tgt);
543
544                 /* This must be an MGC target */
545                 if (strncmp(tgt, "MGC", 3) || len <= 3) {
546                         printerr(0, "Invalid target name\n");
547                         return NULL;
548                 }
549                 mgsnid = tgt + 3;
550         } else {
551                 len = ptr - tgt;
552         }
553
554         if (len > MTI_NAME_MAXLEN) {
555                 printerr(0, "Invalid target name\n");
556                 return NULL;
557         }
558         memcpy(fsname, tgt, len);
559
560         if (nodemap) {
561                 if (mgsnid)
562                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
563                                       "lustre:MGS:%s", nodemap);
564                 else
565                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
566                                       "lustre:%s:%s", fsname, nodemap);
567         } else {
568                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s",
569                               fsname);
570         }
571
572         if (rc >= SK_DESCRIPTION_SIZE) {
573                 printerr(0, "Invalid key description\n");
574                 return NULL;
575         }
576
577         /* It may be a good idea to move Lustre keys to the gss_keyring
578          * (lgssc) type so that they expire when Lustre modules are removed.
579          * Unfortunately it can't be done at mount time because the mount
580          * syscall could trigger the Lustre modules to load and until that
581          * point we don't have a lgssc key type.
582          *
583          * TODO: Query the community for a consensus here  */
584         printerr(2, "Searching for key with description: %s\n", description);
585         sk_key = keyctl_search(KEY_SPEC_USER_KEYRING, "user",
586                                description, 0);
587         if (sk_key == -1) {
588                 printerr(1, "No key found for %s\n", description);
589                 return NULL;
590         }
591
592         keylen = keyctl_read_alloc(sk_key, (void **)&config);
593         if (keylen == -1) {
594                 printerr(0, "keyctl_read() failed for key %ld: %s\n", sk_key,
595                          strerror(errno));
596                 return NULL;
597         } else if (keylen != sizeof(*config)) {
598                 printerr(0, "Unexpected key size: %d returned for key %ld, "
599                          "expected %zu bytes\n",
600                          keylen, sk_key, sizeof(*config));
601                 goto out_err;
602         }
603
604         sk_config_disk_to_cpu(config);
605
606         if (sk_validate_config(config)) {
607                 printerr(0, "Invalid key configuration for key: %ld\n", sk_key);
608                 goto out_err;
609         }
610
611         if (mgsnid && !sk_config_has_mgsnid(config, mgsnid)) {
612                 printerr(0, "Target name does not match key's MGS NIDs\n");
613                 goto out_err;
614         }
615
616         if (!mgsnid && strcmp(fsname, config->skc_fsname)) {
617                 printerr(0, "Target name does not match key's file system\n");
618                 goto out_err;
619         }
620
621         skc = malloc(sizeof(*skc));
622         if (!skc) {
623                 printerr(0, "Failed to allocate memory for sk_cred\n");
624                 goto out_err;
625         }
626
627         /* this initializes all gss_buffer_desc to empty as well */
628         memset(skc, 0, sizeof(*skc));
629
630         skc->sc_flags = flags;
631         skc->sc_tgt.length = strlen(tgt) + 1;
632         skc->sc_tgt.value = malloc(skc->sc_tgt.length);
633         if (!skc->sc_tgt.value) {
634                 printerr(0, "Failed to allocate memory for target\n");
635                 goto out_err;
636         }
637         memcpy(skc->sc_tgt.value, tgt, skc->sc_tgt.length);
638
639         skc->sc_nodemap_hash.length = EVP_MD_size(EVP_sha256());
640         skc->sc_nodemap_hash.value = malloc(skc->sc_nodemap_hash.length);
641         if (!skc->sc_nodemap_hash.value) {
642                 printerr(0, "Failed to allocate memory for nodemap hash\n");
643                 goto out_err;
644         }
645
646         if (sk_hash_string(config->skc_nodemap, EVP_sha256(),
647                            &skc->sc_nodemap_hash)) {
648                 printerr(0, "Failed to generate hash for nodemap name\n");
649                 goto out_err;
650         }
651
652         kctx = &skc->sc_kctx;
653         kctx->skc_version = config->skc_version;
654         strcpy(kctx->skc_hmac_alg, sk_hmac2name(config->skc_hmac_alg));
655         strcpy(kctx->skc_crypt_alg, sk_crypt2name(config->skc_crypt_alg));
656         kctx->skc_expire = config->skc_expire;
657
658         /* key payload format is in bits, convert to bytes */
659         kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
660         kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
661         if (!kctx->skc_shared_key.value) {
662                 printerr(0, "Failed to allocate memory for shared key\n");
663                 goto out_err;
664         }
665         memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
666                kctx->skc_shared_key.length);
667
668         skc->sc_p.length = config->skc_prime_bits / 8;
669         skc->sc_p.value = malloc(skc->sc_p.length);
670         if (!skc->sc_p.value) {
671                 printerr(0, "Failed to allocate p\n");
672                 goto out_err;
673         }
674         memcpy(skc->sc_p.value, config->skc_p, skc->sc_p.length);
675
676         free(config);
677
678         return skc;
679
680 out_err:
681         sk_free_cred(skc);
682
683         free(config);
684         return NULL;
685 }
686
687 #define SK_GENERATOR 2
688 #define DH_NUMBER_ITERATIONS_FOR_PRIME 64
689
690 /* OpenSSL 1.1.1c increased the number of rounds used for Miller-Rabin testing
691  * of the prime provided as input parameter to DH_check(). This makes the check
692  * roughly x10 longer, and causes request timeouts when an SSK flavor is being
693  * used.
694  *
695  * Instead, use a dynamic number Miller-Rabin rounds based on the speed of the
696  * check on the current system, evaluated when the lsvcgssd daemon starts, but
697  * at least as many as OpenSSL 1.1.1b used for the same key size. If default
698  * DH_check() duration is OK, use it directly instead of limiting the rounds.
699  *
700  * If \a num_rounds == 0, we just call original DH_check() directly.
701  */
702 static bool sk_is_dh_valid(const DH *dh, int num_rounds)
703 {
704         const BIGNUM *p, *g;
705         BN_ULONG word;
706         BN_CTX *ctx;
707         BIGNUM *r;
708         bool valid = false;
709         int rc;
710
711         if (num_rounds == 0) {
712                 int codes = 0;
713
714                 rc = DH_check(dh, &codes);
715                 if (rc != 1 || codes) {
716                         printerr(0, "DH_check(0) failed: codes=%#x: rc=%d\n",
717                                  codes, rc);
718                         return false;
719                 }
720                 return true;
721         }
722
723         DH_get0_pqg(dh, &p, NULL, &g);
724
725         if (!BN_is_word(g, SK_GENERATOR)) {
726                 printerr(0, "%s: Diffie-Hellman generator is not %u\n",
727                          program_invocation_short_name, SK_GENERATOR);
728                 return false;
729         }
730
731         word = BN_mod_word(p, 24);
732         if (word != 11) {
733                 printerr(0, "%s: Diffie-Hellman prime modulo=%lu unsuitable\n",
734                          program_invocation_short_name, word);
735                 return false;
736         }
737
738         ctx = BN_CTX_new();
739         if (ctx == NULL) {
740                 printerr(0, "%s: Diffie-Hellman error allocating context\n",
741                          program_invocation_short_name);
742                 return false;
743         }
744         BN_CTX_start(ctx);
745         r = BN_CTX_get(ctx); /* must be called before "ctx" used elsewhere */
746
747         rc = BN_is_prime_ex(p, num_rounds, ctx, NULL);
748         if (rc == 0)
749                 printerr(0, "%s: Diffie-Hellman 'p' not prime in %u rounds\n",
750                          program_invocation_short_name, num_rounds);
751         if (rc <= 0)
752                 goto out_free;
753
754         if (!BN_rshift1(r, p)) {
755                 printerr(0, "%s: error shifting BigNum 'r' by 'p'\n",
756                          program_invocation_short_name);
757                 goto out_free;
758         }
759         rc = BN_is_prime_ex(r, num_rounds, ctx, NULL);
760         if (rc == 0)
761                 printerr(0, "%s: Diffie-Hellman 'r' not prime in %u rounds\n",
762                          program_invocation_short_name, num_rounds);
763         if (rc <= 0)
764                 goto out_free;
765
766         valid = true;
767
768 out_free:
769         BN_CTX_end(ctx);
770         BN_CTX_free(ctx);
771
772         return valid;
773 }
774
#define VALUE_LENGTH 256
/* Fixed VALUE_LENGTH byte (2048 bit) value used as the prime 'p' when timing
 * the Diffie-Hellman parameter check.  It is only ever read, hence const.
 * The initializer fills the array exactly, so no terminating NUL is stored.
 */
static const unsigned char test_prime[VALUE_LENGTH] =
        "\xf7\xfa\x49\xd8\xec\xb1\x3b\xff\x26\x10\x3f\xc5\x3a\xc5\xcc\x40"
        "\x4f\xbf\x92\xe1\x8b\x83\xe7\xa2\xba\x0f\x51\x5a\x91\x48\xe0\xa3"
        "\xf1\x4d\xbc\xbb\x8a\x28\x14\xac\x02\x23\x76\x42\x17\x4d\x3c\xdc"
        "\x5e\x4f\x80\x1f\xd7\x54\x1c\x50\xac\x3b\x28\x68\x8d\x71\x41\x7f"
        "\xa7\x1c\x2f\x22\xd3\xa8\x91\xb2\x64\xb6\x84\xa6\xcf\x06\x16\x91"
        "\x2f\xb8\xb4\x42\x1d\x3a\x4e\x3a\x0c\x7f\x04\x69\x78\xb5\x8f\x92"
        "\x07\x89\xac\x24\x06\x53\x2c\x23\xec\xaa\x5c\xb4\x7b\x49\xbc\xf4"
        "\x90\x67\x71\x9c\x24\x2c\x1d\x8d\x76\xc8\x85\x4e\x19\xf1\xf9\x33"
        "\x45\xbd\x9f\x7d\x0a\x08\x8c\x22\xcc\x35\xf3\x5b\xab\x3f\x24\x9d"
        "\x61\x70\x86\xbb\xbe\xd8\xb0\xf8\x34\xfa\xeb\x5b\x8e\xf2\x62\x23"
        "\xd1\xfb\xbb\xb8\x21\x71\x1e\x39\x39\x59\xe0\x82\x98\x41\x84\x40"
        "\x1f\xd3\x9b\xa3\x73\xdb\xec\xe0\xc0\xde\x2d\x1c\xea\x43\x40\x93"
        "\x98\x38\x03\x36\x1e\xe1\xe7\x39\x7b\x35\x92\x4a\x51\xa5\x91\x63"
        "\xd5\x31\x98\x3d\x89\x27\x6f\xcc\x69\xff\xbe\x31\x13\xdc\x2f\x72"
        "\x2d\xab\x6a\xb7\x13\xd3\x47\xda\xaa\xf3\x3c\xa0\xfd\xaa\x0f\x02"
        "\x96\x81\x1a\x26\xe8\xf7\x25\x65\x33\x78\xd9\x6b\x6d\xb0\xd9\xfb";
793
794 /**
795  * Measure time taken by prime testing routine for a 2048 bit long prime,
796  * depending on the number of check rounds.
797  *
798  * \param[in]   usec_check_max    max time allowed for DH_check completion
799  *
800  * \retval      max number of rounds to keep prime testing under usec_check_max
801  *              return 0 if we should use the default DH_check rounds
802  */
803 int sk_speedtest_dh_valid(unsigned int usec_check_max)
804 {
805         DH *dh;
806         BIGNUM *p, *g;
807         int num_rounds, prev_rounds = 0;
808
809         dh = DH_new();
810         if (!dh)
811                 return 0;
812
813         p = BN_bin2bn(test_prime, VALUE_LENGTH, NULL);
814         if (!p)
815                 goto free_dh;
816
817         g = BN_new();
818         if (!g)
819                 goto free_p;
820
821         if (!BN_set_word(g, SK_GENERATOR))
822                 goto free_g;
823
824         /* "dh" takes over freeing of 'p' and 'g' if this succeeds */
825         if (!DH_set0_pqg(dh, p, NULL, g)) {
826         free_g:
827                 BN_free(g);
828         free_p:
829                 BN_free(p);
830                 goto free_dh;
831         }
832
833         for (num_rounds = 0;
834              num_rounds <= DH_NUMBER_ITERATIONS_FOR_PRIME;
835              num_rounds += (num_rounds <= 4 ? 4 : 8)) {
836                 unsigned int usec_this;
837                 int j;
838
839                 /* get max duration of 4 runs at current number of rounds */
840                 usec_this = 0;
841                 for (j = 0; j < 4; j++) {
842                         struct timeval now, prev;
843                         unsigned int usec_curr;
844
845                         gettimeofday(&prev, NULL);
846                         if (!sk_is_dh_valid(dh, num_rounds)) {
847                                 /* if test_prime is found bad, use default */
848                                 prev_rounds = 0;
849                                 goto free_dh;
850                         }
851                         gettimeofday(&now, NULL);
852                         usec_curr = (now.tv_sec - prev.tv_sec) * 1000000 +
853                                     now.tv_usec - prev.tv_usec;
854                         if (usec_curr > usec_this)
855                                 usec_this = usec_curr;
856                 }
857                 printerr(2, "%s: %d rounds: %d usec\n",
858                          program_invocation_short_name, num_rounds, usec_this);
859                 if (num_rounds == 0) {
860                         if (usec_this <= usec_check_max)
861                         /* using original check rounds as implemented in
862                          * DH_check() took less time than the max allowed,
863                          * so just use original DH_check()
864                          */
865                                 break;
866                 } else if (usec_this >= usec_check_max) {
867                         break;
868                 }
869                 prev_rounds = num_rounds;
870         }
871
872 free_dh:
873         DH_free(dh);
874
875         return prev_rounds;
876 }
877
878 /**
879  * Populates the DH parameters for the DHKE
880  *
881  * \param[in,out]       skc             Shared key credentials structure to
882  *                                      populate with DH parameters
883  *
884  * \retval      GSS_S_COMPLETE  success
885  * \retval      GSS_S_FAILURE   failure
886  */
887 uint32_t sk_gen_params(struct sk_cred *skc, int num_rounds)
888 {
889         uint32_t random;
890         BIGNUM *p, *g;
891         const BIGNUM *pub_key;
892
893         /* Random value used by both the request and response as part of the
894          * key binding material.  This also should ensure we have unqiue
895          * tokens that are sent to the remote server which is important because
896          * the token is hashed for the sunrpc cache lookups and a failure there
897          * would cause connection attempts to fail indefinitely due to the large
898          * timeout value on the server side */
899         if (RAND_bytes((unsigned char *)&random, sizeof(random)) != 1) {
900                 printerr(0, "Failed to get data for random parameter: %s\n",
901                          ERR_error_string(ERR_get_error(), NULL));
902                 return GSS_S_FAILURE;
903         }
904
905         /* The random value will always be used in byte range operations
906          * so we keep it as big endian from this point on */
907         skc->sc_kctx.skc_host_random = random;
908
909         /* Populate DH parameters */
910         skc->sc_params = DH_new();
911         if (!skc->sc_params) {
912                 printerr(0, "Failed to allocate DH\n");
913                 return GSS_S_FAILURE;
914         }
915
916         p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
917         if (!p) {
918                 printerr(0, "Failed to convert binary to BIGNUM\n");
919                 return GSS_S_FAILURE;
920         }
921
922         /* We use a static generator for shared key */
923         g = BN_new();
924         if (!g) {
925                 printerr(0, "Failed to allocate new BIGNUM\n");
926                 return GSS_S_FAILURE;
927         }
928         if (BN_set_word(g, SK_GENERATOR) != 1) {
929                 printerr(0, "Failed to set g value for DH params\n");
930                 return GSS_S_FAILURE;
931         }
932
933         if (!DH_set0_pqg(skc->sc_params, p, NULL, g)) {
934                 printerr(0, "Failed to set pqg\n");
935                 return GSS_S_FAILURE;
936         }
937
938         /* Verify that we have a safe prime and valid generator */
939         if (!sk_is_dh_valid(skc->sc_params, num_rounds))
940                 return GSS_S_FAILURE;
941
942         if (DH_generate_key(skc->sc_params) != 1) {
943                 printerr(0, "Failed to generate public DH key: %s\n",
944                          ERR_error_string(ERR_get_error(), NULL));
945                 return GSS_S_FAILURE;
946         }
947
948         DH_get0_key(skc->sc_params, &pub_key, NULL);
949         skc->sc_pub_key.length = BN_num_bytes(pub_key);
950         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
951         if (!skc->sc_pub_key.value) {
952                 printerr(0, "Failed to allocate memory for public key\n");
953                 return GSS_S_FAILURE;
954         }
955
956         BN_bn2bin(pub_key, skc->sc_pub_key.value);
957
958         return GSS_S_COMPLETE;
959 }
960
961 /**
962  * Convert SK hash algorithm into openssl message digest
963  *
964  * \param[in,out]       alg             SK hash algorithm
965  *
966  * \retval              EVP_MD
967  */
968 static inline const EVP_MD *sk_hash_to_evp_md(enum cfs_crypto_hash_alg alg)
969 {
970         switch (alg) {
971         case CFS_HASH_ALG_SHA256:
972                 return EVP_sha256();
973         case CFS_HASH_ALG_SHA512:
974                 return EVP_sha512();
975         default:
976                 return EVP_md_null();
977         }
978 }
979
980 /**
981  * Signs (via HMAC) the parameters used only in the key initialization protocol.
982  *
983  * \param[in]           key             Key to use for HMAC
984  * \param[in]           bufs            Array of gss_buffer_desc to generate
985  *                                      HMAC for
986  * \param[in]           numbufs         Number of buffers in array
987  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
988  * \param[in,out]       hmac            HMAC of buffers is allocated and placed
989  *                                      in this gss_buffer_desc.  Caller must
990  *                                      free this.
991  *
992  * \retval      0       success
993  * \retval      -1      failure
994  */
995 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
996                  const EVP_MD *hash_alg, gss_buffer_desc *hmac)
997 {
998         HMAC_CTX *hctx;
999         unsigned int hashlen = EVP_MD_size(hash_alg);
1000         int i;
1001         int rc = -1;
1002
1003         if (hash_alg == EVP_md_null()) {
1004                 printerr(0, "Invalid hash algorithm\n");
1005                 return -1;
1006         }
1007
1008         hctx = HMAC_CTX_new();
1009
1010         hmac->length = hashlen;
1011         hmac->value = malloc(hashlen);
1012         if (!hmac->value) {
1013                 printerr(0, "Failed to allocate memory for HMAC\n");
1014                 goto out;
1015         }
1016
1017         if (HMAC_Init_ex(hctx, key->value, key->length, hash_alg, NULL) != 1) {
1018                 printerr(0, "Failed to init HMAC\n");
1019                 goto out;
1020         }
1021
1022         for (i = 0; i < numbufs; i++) {
1023                 if (HMAC_Update(hctx, bufs[i].value, bufs[i].length) != 1) {
1024                         printerr(0, "Failed to update HMAC\n");
1025                         goto out;
1026                 }
1027         }
1028
1029         /* The result gets populated in hmac */
1030         if (HMAC_Final(hctx, hmac->value, &hashlen) != 1) {
1031                 printerr(0, "Failed to finalize HMAC\n");
1032                 goto out;
1033         }
1034
1035         if (hmac->length != hashlen) {
1036                 printerr(0, "HMAC size does not match expected\n");
1037                 goto out;
1038         }
1039
1040         rc = 0;
1041 out:
1042         HMAC_CTX_free(hctx);
1043         return rc;
1044 }
1045
1046 /**
1047  * Generates an HMAC for gss_buffer_desc array in \a bufs of \a numbufs
1048  * and verifies against \a hmac.
1049  *
1050  * \param[in]   skc             Shared key credentials
1051  * \param[in]   bufs            Array of gss_buffer_desc to generate HMAC for
1052  * \param[in]   numbufs         Number of buffers in array
1053  * \param[in]   hash_alg        OpenSSL EVP_MD to use for hash
1054  * \param[in]   hmac            HMAC to verify against
1055  *
1056  * \retval      GSS_S_COMPLETE  success (match)
1057  * \retval      gss error       failure
1058  */
1059 uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
1060                         const int numbufs, const EVP_MD *hash_alg,
1061                         gss_buffer_desc *hmac)
1062 {
1063         gss_buffer_desc bufs_hmac;
1064         int rc;
1065
1066         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, numbufs, hash_alg,
1067                          &bufs_hmac)) {
1068                 printerr(0, "Failed to sign buffers to verify HMAC\n");
1069                 if (bufs_hmac.value)
1070                         free(bufs_hmac.value);
1071                 return GSS_S_FAILURE;
1072         }
1073
1074         if (hmac->length != bufs_hmac.length) {
1075                 printerr(0, "Invalid HMAC size\n");
1076                 free(bufs_hmac.value);
1077                 return GSS_S_BAD_SIG;
1078         }
1079
1080         rc = memcmp(hmac->value, bufs_hmac.value, bufs_hmac.length);
1081         free(bufs_hmac.value);
1082
1083         if (rc)
1084                 return GSS_S_BAD_SIG;
1085
1086         return GSS_S_COMPLETE;
1087 }
1088
1089 /**
1090  * Cleanup an sk_cred freeing any resources
1091  *
1092  * \param[in,out]       skc     Shared key credentials to free
1093  */
1094 void sk_free_cred(struct sk_cred *skc)
1095 {
1096         if (!skc)
1097                 return;
1098
1099         if (skc->sc_p.value)
1100                 free(skc->sc_p.value);
1101         if (skc->sc_pub_key.value)
1102                 free(skc->sc_pub_key.value);
1103         if (skc->sc_tgt.value)
1104                 free(skc->sc_tgt.value);
1105         if (skc->sc_nodemap_hash.value)
1106                 free(skc->sc_nodemap_hash.value);
1107         if (skc->sc_hmac.value)
1108                 free(skc->sc_hmac.value);
1109
1110         /* Overwrite keys and IV before freeing */
1111         if (skc->sc_dh_shared_key.value) {
1112                 memset(skc->sc_dh_shared_key.value, 0,
1113                        skc->sc_dh_shared_key.length);
1114                 free(skc->sc_dh_shared_key.value);
1115         }
1116         if (skc->sc_kctx.skc_hmac_key.value) {
1117                 memset(skc->sc_kctx.skc_hmac_key.value, 0,
1118                        skc->sc_kctx.skc_hmac_key.length);
1119                 free(skc->sc_kctx.skc_hmac_key.value);
1120         }
1121         if (skc->sc_kctx.skc_encrypt_key.value) {
1122                 memset(skc->sc_kctx.skc_encrypt_key.value, 0,
1123                        skc->sc_kctx.skc_encrypt_key.length);
1124                 free(skc->sc_kctx.skc_encrypt_key.value);
1125         }
1126         if (skc->sc_kctx.skc_shared_key.value) {
1127                 memset(skc->sc_kctx.skc_shared_key.value, 0,
1128                        skc->sc_kctx.skc_shared_key.length);
1129                 free(skc->sc_kctx.skc_shared_key.value);
1130         }
1131         if (skc->sc_kctx.skc_session_key.value) {
1132                 memset(skc->sc_kctx.skc_session_key.value, 0,
1133                        skc->sc_kctx.skc_session_key.length);
1134                 free(skc->sc_kctx.skc_session_key.value);
1135         }
1136
1137         if (skc->sc_params)
1138                 DH_free(skc->sc_params);
1139
1140         free(skc);
1141         skc = NULL;
1142 }
1143
1144 /* This function handles key derivation using the hash algorithm specified in
1145  * \a hash_alg, buffers in \a key_binding_bufs, and original key in
1146  * \a origin_key to produce a \a derived_key.  The first element of the
1147  * key_binding_bufs array is reserved for the counter used in the KDF.  The
1148  * derived key in \a derived_key could differ in size from \a origin_key and
1149  * must be populated with the expected size and a valid buffer to hold the
1150  * contents.
1151  *
1152  * If the derived key size is greater than the HMAC algorithm size it will be
1153  * a done using several iterations of a counter and the key binding bufs.
1154  *
1155  * If the size is smaller it will take copy the first N bytes necessary to
1156  * fill the derived key. */
1157 int sk_kdf(gss_buffer_desc *derived_key , gss_buffer_desc *origin_key,
1158            gss_buffer_desc *key_binding_bufs, int numbufs,
1159            enum cfs_crypto_hash_alg hmac_alg)
1160 {
1161         size_t remain;
1162         size_t bytes;
1163         uint32_t counter;
1164         char *keydata;
1165         gss_buffer_desc tmp_hash;
1166         int i;
1167         int rc;
1168
1169         if (numbufs < 1)
1170                 return -EINVAL;
1171
1172         /* Use a counter as the first buffer followed by the key binding
1173          * buffers in the event we need more than one a single cycle to
1174          * produced a symmetric key large enough in size */
1175         key_binding_bufs[0].value = &counter;
1176         key_binding_bufs[0].length = sizeof(counter);
1177
1178         remain = derived_key->length;
1179         keydata = derived_key->value;
1180         i = 0;
1181         while (remain > 0) {
1182                 counter = htobe32(i++);
1183                 rc = sk_sign_bufs(origin_key, key_binding_bufs, numbufs,
1184                                   sk_hash_to_evp_md(hmac_alg), &tmp_hash);
1185                 if (rc) {
1186                         if (tmp_hash.value)
1187                                 free(tmp_hash.value);
1188                         return rc;
1189                 }
1190
1191                 if (cfs_crypto_hash_digestsize(hmac_alg) != tmp_hash.length) {
1192                         free(tmp_hash.value);
1193                         return -EINVAL;
1194                 }
1195
1196                 bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
1197                 memcpy(keydata, tmp_hash.value, bytes);
1198                 free(tmp_hash.value);
1199                 remain -= bytes;
1200                 keydata += bytes;
1201         }
1202
1203         return 0;
1204 }
1205
1206 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
1207  * based on the recommendations in NIST Special Publication SP 800-56B Rev 1
1208  * (Sep 2014) Section 5.5.1
1209  *
1210  * \param[in,out]       skc             Shared key credentials structure with
1211  *
1212  * \return      -1              failure
1213  * \return      0               success
1214  */
1215 int sk_session_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
1216                    gss_buffer_desc *client_token, gss_buffer_desc *server_token)
1217 {
1218         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1219         gss_buffer_desc *session_key = &kctx->skc_session_key;
1220         gss_buffer_desc bufs[5];
1221         enum cfs_crypto_crypt_alg crypt_alg;
1222         int rc = -1;
1223
1224         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1225         session_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1226         session_key->value = malloc(session_key->length);
1227         if (!session_key->value) {
1228                 printerr(0, "Failed to allocate memory for session key\n");
1229                 return rc;
1230         }
1231
1232         /* Key binding info ordering
1233          * 1. Reserved for counter
1234          * 1. DH shared key
1235          * 2. Client's NIDs
1236          * 3. Client's token
1237          * 4. Server's token */
1238         bufs[0].value = NULL;
1239         bufs[0].length = 0;
1240         bufs[1] = skc->sc_dh_shared_key;
1241         bufs[2].value = &client_nid;
1242         bufs[2].length = sizeof(client_nid);
1243         bufs[3] = *client_token;
1244         bufs[4] = *server_token;
1245
1246         return sk_kdf(&kctx->skc_session_key, &kctx->skc_shared_key, bufs,
1247                       5, cfs_crypto_hash_alg(kctx->skc_hmac_alg));
1248 }
1249
1250 /* Uses the session key to create an HMAC key and encryption key.  In
1251  * integrity mode the session key used to generate the HMAC key uses
1252  * session information which is available on the wire but by creating
1253  * a session based HMAC key we can prevent potential replay as both the
1254  * client and server have random numbers used as part of the key creation.
1255  *
1256  * The keys used for integrity and privacy are formulated as below using
1257  * the session key that is the output of the key derivation function.  The
1258  * HMAC algorithm is determined by the shared key algorithm selected in the
1259  * key file.
1260  *
1261  * For ski mode:
1262  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1263  *
1264  * For skpi mode:
1265  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1266  * Session Encryption Key = PBKDF2("Encrypt", KDF derived Session Key)
1267  *
1268  * \param[in,out]       skc             Shared key credentials structure with
1269  *
1270  * \return      -1              failure
1271  * \return      0               success
1272  */
1273 int sk_compute_keys(struct sk_cred *skc)
1274 {
1275         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1276         gss_buffer_desc *session_key = &kctx->skc_session_key;
1277         gss_buffer_desc *hmac_key = &kctx->skc_hmac_key;
1278         gss_buffer_desc *encrypt_key = &kctx->skc_encrypt_key;
1279         enum cfs_crypto_hash_alg hmac_alg;
1280         enum cfs_crypto_crypt_alg crypt_alg;
1281         char *encrypt = "Encrypt";
1282         char *integrity = "Integrity";
1283         int rc;
1284
1285         hmac_alg = cfs_crypto_hash_alg(kctx->skc_hmac_alg);
1286         hmac_key->length = cfs_crypto_hash_digestsize(hmac_alg);
1287         hmac_key->value = malloc(hmac_key->length);
1288         if (!hmac_key->value)
1289                 return -ENOMEM;
1290
1291         rc = PKCS5_PBKDF2_HMAC(integrity, -1, session_key->value,
1292                                session_key->length, SK_PBKDF2_ITERATIONS,
1293                                sk_hash_to_evp_md(hmac_alg),
1294                                hmac_key->length, hmac_key->value);
1295         if (rc == 0)
1296                 return -EINVAL;
1297
1298         /* Encryption key is only populated in privacy mode */
1299         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1300                 return 0;
1301
1302         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1303         encrypt_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1304         encrypt_key->value = malloc(encrypt_key->length);
1305         if (!encrypt_key->value)
1306                 return -ENOMEM;
1307
1308         rc = PKCS5_PBKDF2_HMAC(encrypt, -1, session_key->value,
1309                                session_key->length, SK_PBKDF2_ITERATIONS,
1310                                sk_hash_to_evp_md(hmac_alg),
1311                                encrypt_key->length, encrypt_key->value);
1312         if (rc == 0)
1313                 return -EINVAL;
1314
1315         return 0;
1316 }
1317
1318 /**
1319  * Computes a session key based on the DH parameters from the host and its peer
1320  *
1321  * \param[in,out]       skc             Shared key credentials structure with
1322  *                                      the session key populated with the
1323  *                                      compute key
1324  * \param[in]           pub_key         Public key returned from peer in
1325  *                                      gss_buffer_desc
1326  * \return      gss error               failure
1327  * \return      GSS_S_COMPLETE          success
1328  */
1329 uint32_t sk_compute_dh_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
1330 {
1331         gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
1332         BIGNUM *remote_pub_key;
1333         int status;
1334         uint32_t rc = GSS_S_FAILURE;
1335
1336         remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
1337         if (!remote_pub_key) {
1338                 printerr(0, "Failed to convert binary to BIGNUM\n");
1339                 return rc;
1340         }
1341
1342         dh_shared->length = DH_size(skc->sc_params);
1343         dh_shared->value = malloc(dh_shared->length);
1344         if (!dh_shared->value) {
1345                 printerr(0, "Failed to allocate memory for computed shared "
1346                          "secret key\n");
1347                 goto out_err;
1348         }
1349
1350         /* This compute the shared key from the DHKE */
1351         status = DH_compute_key(dh_shared->value, remote_pub_key,
1352                                 skc->sc_params);
1353         if (status == -1) {
1354                 printerr(0, "DH_compute_key() failed: %s\n",
1355                          ERR_error_string(ERR_get_error(), NULL));
1356                 goto out_err;
1357         } else if (status < dh_shared->length) {
1358                 /* there is around 1 chance out of 256 that the returned
1359                  * shared key is shorter than expected
1360                  */
1361                 if (status >= dh_shared->length - 2) {
1362                         int shift = dh_shared->length - status;
1363                         /* if the key is short by only 1 or 2 bytes, just
1364                          * prepend it with 0s
1365                          */
1366                         memmove((void *)(dh_shared->value + shift),
1367                                 dh_shared->value, status);
1368                         memset(dh_shared->value, 0, shift);
1369                 } else {
1370                         /* if the key is really too short, return GSS_S_BAD_QOP
1371                          * so that the caller can retry to generate
1372                          */
1373                         printerr(0, "DH_compute_key() returned a short key of %d bytes, expected: %zu\n",
1374                                  status, dh_shared->length);
1375                         rc = GSS_S_BAD_QOP;
1376                         goto out_err;
1377                 }
1378         }
1379
1380         rc = GSS_S_COMPLETE;
1381
1382 out_err:
1383         BN_free(remote_pub_key);
1384         return rc;
1385 }
1386
1387 /**
1388  * Creates a serialized buffer for the kernel in the order of struct
1389  * sk_kernel_ctx.
1390  *
1391  * \param[in,out]       skc             Shared key credentials structure
1392  * \param[in,out]       ctx_token       Serialized buffer for kernel.
1393  *                                      Caller must free this buffer.
1394  *
1395  * \return      0       success
1396  * \return      -1      failure
1397  */
1398 int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
1399 {
1400         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1401         char *p, *end;
1402         size_t bufsize;
1403
1404         bufsize = sizeof(*kctx) + kctx->skc_hmac_key.length +
1405                   kctx->skc_encrypt_key.length;
1406
1407         ctx_token->value = malloc(bufsize);
1408         if (!ctx_token->value)
1409                 return -1;
1410         ctx_token->length = bufsize;
1411
1412         p = ctx_token->value;
1413         end = p + ctx_token->length;
1414
1415         if (WRITE_BYTES(&p, end, kctx->skc_version))
1416                 return -1;
1417         if (WRITE_BYTES(&p, end, kctx->skc_hmac_alg))
1418                 return -1;
1419         if (WRITE_BYTES(&p, end, kctx->skc_crypt_alg))
1420                 return -1;
1421         if (WRITE_BYTES(&p, end, kctx->skc_expire))
1422                 return -1;
1423         if (WRITE_BYTES(&p, end, kctx->skc_host_random))
1424                 return -1;
1425         if (WRITE_BYTES(&p, end, kctx->skc_peer_random))
1426                 return -1;
1427         if (write_buffer(&p, end, &kctx->skc_hmac_key))
1428                 return -1;
1429         if (write_buffer(&p, end, &kctx->skc_encrypt_key))
1430                 return -1;
1431
1432         printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
1433
1434         return 0;
1435 }
1436
1437 /**
1438  * Decodes a netstring \a ns into array of gss_buffer_descs at \a bufs
1439  * up to \a numbufs.  Memory is allocated for each value and length
1440  * will be populated with the length
1441  *
1442  * \param[in,out]       bufs    Array of gss_buffer_descs
1443  * \param[in,out]       numbufs number of gss_buffer_desc in array
1444  * \param[in]           ns      netstring to decode
1445  *
1446  * \return      buffers populated       success
1447  * \return      -1                      failure
1448  */
1449 int sk_decode_netstring(gss_buffer_desc *bufs, int numbufs, gss_buffer_desc *ns)
1450 {
1451         char *ptr = ns->value;
1452         size_t remain = ns->length;
1453         unsigned int size;
1454         int digits;
1455         int sep;
1456         int rc;
1457         int i;
1458
1459         for (i = 0; i < numbufs; i++) {
1460                 /* read the size of first buffer */
1461                 rc = sscanf(ptr, "%9u", &size);
1462                 if (rc < 1)
1463                         goto out_err;
1464                 digits = (size) ? ceil(log10(size + 1)) : 1;
1465
1466                 /* sep of current string */
1467                 sep = size + digits + 2;
1468
1469                 /* check to make sure it's valid */
1470                 if (remain < sep || ptr[digits] != ':' ||
1471                     ptr[sep - 1] != ',')
1472                         goto out_err;
1473
1474                 bufs[i].length = size;
1475                 if (size == 0) {
1476                         bufs[i].value = NULL;
1477                 } else {
1478                         bufs[i].value = malloc(size);
1479                         if (!bufs[i].value)
1480                                 goto out_err;
1481                         memcpy(bufs[i].value, &ptr[digits + 1], size);
1482                 }
1483
1484                 remain -= sep;
1485                 ptr += sep;
1486         }
1487
1488         printerr(2, "Decoded netstring of %zu bytes\n", ns->length);
1489         return i;
1490
1491 out_err:
1492         while (i-- > 0) {
1493                 if (bufs[i].value)
1494                         free(bufs[i].value);
1495                 bufs[i].length = 0;
1496         }
1497         return -1;
1498 }
1499
1500 /**
1501  * Creates a netstring in a gss_buffer_desc that consists of all
1502  * the gss_buffer_desc found in \a bufs.  The netstring should be treated
1503  * as binary as it can contain null characters.
1504  *
1505  * \param[in]           bufs            Array of gss_buffer_desc to use as input
1506  * \param[in]           numbufs         Number of buffers in array
1507  * \param[in,out]       ns              Destination gss_buffer_desc to hold
1508  *                                      netstring
1509  *
1510  * \return      -1      failure
1511  * \return      0       success
1512  */
1513 int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
1514                         gss_buffer_desc *ns)
1515 {
1516         unsigned char *ptr;
1517         int size = 0;
1518         int rc;
1519         int i;
1520
1521         /* size of string in decimal, string size, colon, and comma */
1522         for (i = 0; i < numbufs; i++) {
1523
1524                 if (bufs[i].length == 0)
1525                         size += 3;
1526                 else
1527                         size += ceil(log10(bufs[i].length + 1)) +
1528                                 bufs[i].length + 2;
1529         }
1530
1531         ns->length = size;
1532         ns->value = malloc(ns->length);
1533         if (!ns->value) {
1534                 ns->length = 0;
1535                 return -1;
1536         }
1537
1538         ptr = ns->value;
1539         for (i = 0; i < numbufs; i++) {
1540                 /* size */
1541                 rc = scnprintf((char *) ptr, size, "%zu:", bufs[i].length);
1542                 ptr += rc;
1543
1544                 /* contents */
1545                 memcpy(ptr, bufs[i].value, bufs[i].length);
1546                 ptr += bufs[i].length;
1547
1548                 /* delimeter */
1549                 *ptr++ = ',';
1550
1551                 size -= bufs[i].length + rc + 1;
1552
1553                 /* should not happen */
1554                 if (size < 0)
1555                         abort();
1556         }
1557
1558         printerr(2, "Encoded netstring of %zu bytes\n", ns->length);
1559         return 0;
1560 }