Whamcloud - gitweb
LU-17744 ldiskfs: mballoc stats fixes
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Copyright (c) 2016, 2017, Intel Corporation.
26  *
27  * Author: Jeremy Filizetti <jfilizet@iu.edu>
28  */
29
30 #include <fcntl.h>
31 #include <limits.h>
32 #include <math.h>
33 #include <string.h>
34 #include <stdbool.h>
35 #include <unistd.h>
36 #include <openssl/dh.h>
37 #include <openssl/engine.h>
38 #include <openssl/err.h>
39 #include <openssl/hmac.h>
40 #ifdef HAVE_OPENSSL_EVP_PKEY
41 #include <openssl/param_build.h>
42 #endif
43 #include <sys/types.h>
44 #include <sys/stat.h>
45 #include <libcfs/util/string.h>
46 #include <sys/time.h>
47 #include <signal.h>
48
49 #include "sk_utils.h"
50 #include "write_bytes.h"
51
52 #define SK_PBKDF2_ITERATIONS 10000
53
54 #ifdef _NEW_BUILD_
55 # include "lgss_utils.h"
56 #else
57 # include "gss_util.h"
58 # include "gss_oids.h"
59 # include "err_util.h"
60 #endif
61
#ifdef _ERR_UTIL_H_
/**
 * Set up diagnostic logging for the shared-key utilities.
 *
 * This is a thin wrapper: every argument is forwarded unchanged to
 * initerr().
 *
 * \param[in]	program		program name to output
 * \param[in]	verbose		verbosity flag handed to initerr()
 * \param[in]	fg		whether or not to run in the foreground
 */
void sk_init_logging(char *program, int verbose, int fg)
{
	/* all logging state is owned by err_util; nothing is kept here */
	initerr(program,
		verbose,
		fg);
}
#endif
75
76 /**
77  * Loads the key from \a filename and returns the struct sk_keyfile_config.
78  * It should be freed by the caller.
79  *
80  * \param[in]   filename                Disk or key payload data
81  *
82  * \return      sk_keyfile_config       sucess
83  * \return      NULL                    failure
84  */
85 struct sk_keyfile_config *sk_read_file(char *filename)
86 {
87         struct sk_keyfile_config *config;
88         char *ptr;
89         size_t rc;
90         size_t remain;
91         int fd;
92
93         config = malloc(sizeof(*config));
94         if (!config) {
95                 printerr(0, "Failed to allocate memory for config\n");
96                 return NULL;
97         }
98
99         /* allow standard input override */
100         if (strcmp(filename, "-") == 0)
101                 fd = STDIN_FILENO;
102         else
103                 fd = open(filename, O_RDONLY);
104
105         if (fd == -1) {
106                 printerr(0, "Error opening key file '%s': %s\n", filename,
107                          strerror(errno));
108                 goto out_free;
109         } else if (fd != STDIN_FILENO) {
110                 struct stat st;
111
112                 rc = fstat(fd, &st);
113                 if (rc == 0 && (st.st_mode & ~(S_IFREG | 0600)))
114                         fprintf(stderr, "warning: "
115                                 "secret key '%s' has insecure file mode %#o\n",
116                                 filename, st.st_mode);
117         }
118
119         ptr = (char *)config;
120         remain = sizeof(*config);
121         while (remain > 0) {
122                 rc = read(fd, ptr, remain);
123                 if (rc == -1) {
124                         if (errno == EINTR)
125                                 continue;
126                         printerr(0, "read() failed on %s: %s\n", filename,
127                                  strerror(errno));
128                         goto out_close;
129                 } else if (rc == 0) {
130                         printerr(0, "File %s does not have a complete key\n",
131                                  filename);
132                         goto out_close;
133                 }
134                 ptr += rc;
135                 remain -= rc;
136         }
137
138         if (fd != STDIN_FILENO)
139                 close(fd);
140         sk_config_disk_to_cpu(config);
141         return config;
142
143 out_close:
144         close(fd);
145 out_free:
146         free(config);
147         return NULL;
148 }
149
150 /**
151  * Checks if a key matching \a description is found in the keyring for
152  * logging purposes and then attempts to load \a payload of \a psize into a key
153  * with \a description.
154  *
155  * \param[in]   payload         Key payload
156  * \param[in]   psize           Payload size
157  * \param[in]   description     Description used for key in keyring
158  *
159  * \return      0       sucess
160  * \return      -1      failure
161  */
162 static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
163                                 const char *description)
164 {
165         struct sk_keyfile_config payload;
166         key_serial_t key;
167
168         memcpy(&payload, skc, sizeof(*skc));
169
170         /* In the keyring use the disk layout so keyctl pipe can be used */
171         sk_config_cpu_to_disk(&payload);
172
173         /* Check to see if a key is already loaded matching description */
174         key = keyctl_search(KEY_SPEC_USER_KEYRING, "user", description, 0);
175         if (key != -1)
176                 printerr(2, "Key %d found in session keyring, replacing\n",
177                          key);
178
179         key = add_key("user", description, &payload, sizeof(payload),
180                       KEY_SPEC_USER_KEYRING);
181         if (key != -1) {
182                 key_perm_t perm = KEY_POS_ALL | KEY_USR_ALL |
183                         KEY_GRP_ALL | KEY_OTH_ALL;
184
185                 if (keyctl_setperm(key, perm) < 0)
186                         printerr(2, "Failed to set perm 0x%x on key %d\n",
187                                  perm, key);
188                 printerr(2, "Added key %d with description %s\n", key,
189                          description);
190         } else {
191                 printerr(0, "Failed to add key with %s\n", description);
192         }
193
194         return key;
195 }
196
197 /**
198  * Reads the key from \a path, verifies it and loads into the session keyring
199  * using a description determined by the the \a type.  Existing keys with the
200  * same description are replaced.
201  *
202  * \param[in]   path    Path to key file
203  * \param[in]   type    Type of key to load which determines the description
204  *
205  * \return      0       sucess
206  * \return      -1      failure
207  */
208 int sk_load_keyfile(char *path)
209 {
210         struct sk_keyfile_config *config;
211         char description[SK_DESCRIPTION_SIZE + 1];
212         struct stat buf;
213         int i;
214         int rc;
215         int rc2 = -1;
216
217         rc = stat(path, &buf);
218         if (rc == -1) {
219                 printerr(0, "stat() failed for file %s: %s\n", path,
220                          strerror(errno));
221                 return rc2;
222         }
223
224         config = sk_read_file(path);
225         if (!config)
226                 return rc2;
227
228         /* Similar to ssh, require adequate care of key files */
229         if (buf.st_mode & (S_IRGRP | S_IWGRP | S_IWOTH | S_IXOTH)) {
230                 printerr(0, "Shared key files must be read/writeable only by "
231                          "owner\n");
232                 return -1;
233         }
234
235         if (sk_validate_config(config))
236                 goto out;
237
238         /* The server side can have multiple key files per file system so
239          * the nodemap name is appended to the key description to uniquely
240          * identify it */
241         if (config->skc_type & SK_TYPE_MGS) {
242                 /* Any key can be an MGS key as long as we are told to use it */
243                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
244                               config->skc_nodemap);
245                 if (rc >= SK_DESCRIPTION_SIZE)
246                         goto out;
247                 if (sk_load_key(config, description) == -1)
248                         goto out;
249         }
250         if (config->skc_type & SK_TYPE_SERVER) {
251                 /* Server keys need to have the file system name in the key */
252                 if (config->skc_fsname[0] == '\0') {
253                         printerr(0, "Key configuration has no file system "
254                                  "attribute.  Can't load as server type\n");
255                         goto out;
256                 }
257                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s:%s",
258                               config->skc_fsname, config->skc_nodemap);
259                 if (rc >= SK_DESCRIPTION_SIZE)
260                         goto out;
261                 if (sk_load_key(config, description) == -1)
262                         goto out;
263         }
264         if (config->skc_type & SK_TYPE_CLIENT) {
265                 /* Load client file system key */
266                 if (config->skc_fsname[0] != '\0') {
267                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
268                                       "lustre:%s", config->skc_fsname);
269                         if (rc >= SK_DESCRIPTION_SIZE)
270                                 goto out;
271                         if (sk_load_key(config, description) == -1)
272                                 goto out;
273                 }
274
275                 /* Load client MGC keys */
276                 for (i = 0; i < MAX_MGSNIDS; i++) {
277                         if (config->skc_mgsnids[i] == LNET_NID_ANY)
278                                 continue;
279                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
280                                       "lustre:MGC%s",
281                                       libcfs_nid2str(config->skc_mgsnids[i]));
282                         if (rc >= SK_DESCRIPTION_SIZE)
283                                 goto out;
284                         if (sk_load_key(config, description) == -1)
285                                 goto out;
286                 }
287         }
288
289         rc2 = 0;
290
291 out:
292         free(config);
293         return rc2;
294 }
295
296 /**
297  * Byte swaps config from cpu format to disk
298  *
299  * \param[in,out]       config          sk_keyfile_config to swap
300  */
301 void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
302 {
303         int i;
304
305         if (!config)
306                 return;
307
308         config->skc_version = htobe32(config->skc_version);
309         config->skc_hmac_alg = htobe16(config->skc_hmac_alg);
310         config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
311         config->skc_expire = htobe32(config->skc_expire);
312         config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
313         config->skc_prime_bits = htobe32(config->skc_prime_bits);
314
315         for (i = 0; i < MAX_MGSNIDS; i++)
316                 config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
317 }
318
319 /**
320  * Byte swaps config from disk format to cpu
321  *
322  * \param[in,out]       config          sk_keyfile_config to swap
323  */
324 void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
325 {
326         int i;
327
328         if (!config)
329                 return;
330
331         config->skc_version = be32toh(config->skc_version);
332         config->skc_hmac_alg = be16toh(config->skc_hmac_alg);
333         config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
334         config->skc_expire = be32toh(config->skc_expire);
335         config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
336         config->skc_prime_bits = be32toh(config->skc_prime_bits);
337
338         for (i = 0; i < MAX_MGSNIDS; i++)
339                 config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
340 }
341
342 /**
343  * Verifies the on key payload format is valid
344  *
345  * \param[in]   config          sk_keyfile_config
346  *
347  * \return      -1      failure
348  * \return      0       success
349  */
350 int sk_validate_config(const struct sk_keyfile_config *config)
351 {
352         int i;
353
354         if (!config) {
355                 printerr(0, "Null configuration passed\n");
356                 return -1;
357         }
358
359         if (config->skc_version != SK_CONF_VERSION) {
360                 printerr(0, "Invalid version\n");
361                 return -1;
362         }
363
364         if (config->skc_hmac_alg == SK_HMAC_INVALID) {
365                 printerr(0, "Invalid HMAC algorithm\n");
366                 return -1;
367         }
368
369         if (config->skc_crypt_alg == SK_CRYPT_INVALID) {
370                 printerr(0, "Invalid crypt algorithm\n");
371                 return -1;
372         }
373
374         if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
375                 /* Try to limit key expiration to some reasonable minimum and
376                  * also prevent values over INT_MAX because there appears
377                  * to be a type conversion issue */
378                 printerr(0, "Invalid expiration time should be between %d "
379                          "and %d\n", 60, INT_MAX);
380                 return -1;
381         }
382         if (config->skc_prime_bits % 8 != 0 ||
383             config->skc_prime_bits > SK_MAX_P_BYTES * 8) {
384                 printerr(0, "Invalid session key length must be a multiple of 8"
385                          " and less then %d bits\n",
386                          SK_MAX_P_BYTES * 8);
387                 return -1;
388         }
389         if (config->skc_shared_keylen % 8 != 0 ||
390             config->skc_shared_keylen > SK_MAX_KEYLEN_BYTES * 8){
391                 printerr(0, "Invalid shared key max length must be a multiple "
392                          "of 8 and less then %d bits\n",
393                          SK_MAX_KEYLEN_BYTES * 8);
394                 return -1;
395         }
396
397         /* Check for terminating nulls on strings */
398         for (i = 0; i < sizeof(config->skc_fsname) &&
399              config->skc_fsname[i] != '\0';  i++)
400                 ; /* empty loop */
401         if (i == sizeof(config->skc_fsname)) {
402                 printerr(0, "File system name not null terminated\n");
403                 return -1;
404         }
405
406         for (i = 0; i < sizeof(config->skc_nodemap) &&
407              config->skc_nodemap[i] != '\0';  i++)
408                 ; /* empty loop */
409         if (i == sizeof(config->skc_nodemap)) {
410                 printerr(0, "Nodemap name not null terminated\n");
411                 return -1;
412         }
413
414         if (config->skc_type == SK_TYPE_INVALID) {
415                 printerr(0, "Invalid key type\n");
416                 return -1;
417         }
418
419         return 0;
420 }
421
422 /**
423  * Hashes \a string and places the hash in \a hash
424  * at \a hash
425  *
426  * \param[in]           string          Null terminated string to hash
427  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
428  * \param[in,out]       hash            gss_buffer_desc to hold the result
429  *
430  * \return      -1      failure
431  * \return      0       success
432  */
433 static int sk_hash_string(const char *string, const EVP_MD *hash_alg,
434                           gss_buffer_desc *hash)
435 {
436         EVP_MD_CTX *ctx = EVP_MD_CTX_create();
437         size_t len = strlen(string);
438         unsigned int hashlen;
439
440         if (!hash->value || hash->length < EVP_MD_size(hash_alg))
441                 goto out_err;
442         if (!EVP_DigestInit_ex(ctx, hash_alg, NULL))
443                 goto out_err;
444         if (!EVP_DigestUpdate(ctx, string, len))
445                 goto out_err;
446         if (!EVP_DigestFinal_ex(ctx, hash->value, &hashlen))
447                 goto out_err;
448
449         EVP_MD_CTX_destroy(ctx);
450         hash->length = hashlen;
451         return 0;
452
453 out_err:
454         EVP_MD_CTX_destroy(ctx);
455         return -1;
456 }
457
458 /**
459  * Hashes \a string and verifies the resulting hash matches the value
460  * in \a current_hash
461  *
462  * \param[in]           string          Null terminated string to hash
463  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
464  * \param[in,out]       current_hash    gss_buffer_desc to compare to
465  *
466  * \return      gss error       failure
467  * \return      GSS_S_COMPLETE  success
468  */
469 uint32_t sk_verify_hash(const char *string, const EVP_MD *hash_alg,
470                         const gss_buffer_desc *current_hash)
471 {
472         gss_buffer_desc hash;
473         unsigned char hashbuf[EVP_MAX_MD_SIZE];
474
475         hash.value = hashbuf;
476         hash.length = sizeof(hashbuf);
477
478         if (sk_hash_string(string, hash_alg, &hash))
479                 return GSS_S_FAILURE;
480         if (current_hash->length != hash.length)
481                 return GSS_S_DEFECTIVE_TOKEN;
482         if (memcmp(current_hash->value, hash.value, hash.length))
483                 return GSS_S_BAD_SIG;
484
485         return GSS_S_COMPLETE;
486 }
487
488 static inline int sk_config_has_mgsnid(struct sk_keyfile_config *config,
489                                        const char *mgsnid)
490 {
491         lnet_nid_t nid;
492         int i;
493
494         nid = libcfs_str2nid(mgsnid);
495         if (nid == LNET_NID_ANY)
496                 return 0;
497
498         for (i = 0; i < MAX_MGSNIDS; i++)
499                 if  (config->skc_mgsnids[i] == nid)
500                         return 1;
501         return 0;
502 }
503
504 /**
505  * Create an sk_cred structure populated with initial configuration info and the
506  * key.  \a tgt and \a nodemap are used in determining the expected key
507  * description so the key can be found by searching the keyring.
508  * This is done because there is no easy way to pass keys from the mount command
509  * all the way to the request_key call.  In addition any keys can be dynamically
510  * added to the keyrings and still found.  The keyring that needs to be used
511  * must be the session keyring.
512  *
513  * \param[in]   tgt             Target file system
514  * \param[in]   nodemap         Cluster name for the key.  This correlates to
515  *                              the nodemap name and is used by the server side.
516  *                              For the client this will be NULL.
517  * \param[in]   flags           Flags for the credentials
518  *
519  * \return      sk_cred Allocated struct sk_cred on success
520  * \return      NULL    failure
521  */
522 struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
523                                const uint32_t flags)
524 {
525         struct sk_keyfile_config *config;
526         struct sk_kernel_ctx *kctx;
527         struct sk_cred *skc = NULL;
528         char description[SK_DESCRIPTION_SIZE + 1];
529         char fsname[MTI_NAME_MAXLEN + 1];
530         const char *mgsnid = NULL;
531         char *ptr;
532         long sk_key;
533         int keylen;
534         int len;
535         int rc;
536
537         printerr(2, "Creating credentials for target: %s with nodemap: %s\n",
538                  tgt, nodemap);
539
540         memset(description, 0, sizeof(description));
541         memset(fsname, 0, sizeof(fsname));
542
543         /* extract the file system name from target */
544         ptr = index(tgt, '-');
545         if (!ptr) {
546                 len = strlen(tgt);
547
548                 /* This must be an MGC target */
549                 if (strncmp(tgt, "MGC", 3) || len <= 3) {
550                         printerr(0, "Invalid target name\n");
551                         return NULL;
552                 }
553                 mgsnid = tgt + 3;
554         } else {
555                 len = ptr - tgt;
556         }
557
558         if (len > MTI_NAME_MAXLEN) {
559                 printerr(0, "Invalid target name\n");
560                 return NULL;
561         }
562         memcpy(fsname, tgt, len);
563
564         if (nodemap) {
565                 if (mgsnid)
566                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
567                                       "lustre:MGS:%s", nodemap);
568                 else
569                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
570                                       "lustre:%s:%s", fsname, nodemap);
571         } else {
572                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s",
573                               fsname);
574         }
575
576         if (rc >= SK_DESCRIPTION_SIZE) {
577                 printerr(0, "Invalid key description\n");
578                 return NULL;
579         }
580
581         /* It may be a good idea to move Lustre keys to the gss_keyring
582          * (lgssc) type so that they expire when Lustre modules are removed.
583          * Unfortunately it can't be done at mount time because the mount
584          * syscall could trigger the Lustre modules to load and until that
585          * point we don't have a lgssc key type.
586          *
587          * TODO: Query the community for a consensus here  */
588         printerr(2, "Searching for key with description: %s\n", description);
589         sk_key = keyctl_search(KEY_SPEC_USER_KEYRING, "user",
590                                description, 0);
591         if (sk_key == -1) {
592                 printerr(1, "No key found for %s\n", description);
593                 return NULL;
594         }
595
596         keylen = keyctl_read_alloc(sk_key, (void **)&config);
597         if (keylen == -1) {
598                 printerr(0, "keyctl_read() failed for key %ld: %s\n", sk_key,
599                          strerror(errno));
600                 return NULL;
601         } else if (keylen != sizeof(*config)) {
602                 printerr(0, "Unexpected key size: %d returned for key %ld, "
603                          "expected %zu bytes\n",
604                          keylen, sk_key, sizeof(*config));
605                 goto out_err;
606         }
607
608         sk_config_disk_to_cpu(config);
609
610         if (sk_validate_config(config)) {
611                 printerr(0, "Invalid key configuration for key: %ld\n", sk_key);
612                 goto out_err;
613         }
614
615         if (mgsnid && !sk_config_has_mgsnid(config, mgsnid)) {
616                 printerr(0, "Target name does not match key's MGS NIDs\n");
617                 goto out_err;
618         }
619
620         if (!mgsnid && strcmp(fsname, config->skc_fsname)) {
621                 printerr(0, "Target name does not match key's file system\n");
622                 goto out_err;
623         }
624
625         skc = malloc(sizeof(*skc));
626         if (!skc) {
627                 printerr(0, "Failed to allocate memory for sk_cred\n");
628                 goto out_err;
629         }
630
631         /* this initializes all gss_buffer_desc to empty as well */
632         memset(skc, 0, sizeof(*skc));
633
634         skc->sc_flags = flags;
635         skc->sc_tgt.length = strlen(tgt) + 1;
636         skc->sc_tgt.value = malloc(skc->sc_tgt.length);
637         if (!skc->sc_tgt.value) {
638                 printerr(0, "Failed to allocate memory for target\n");
639                 goto out_err;
640         }
641         memcpy(skc->sc_tgt.value, tgt, skc->sc_tgt.length);
642
643         skc->sc_nodemap_hash.length = EVP_MD_size(EVP_sha256());
644         skc->sc_nodemap_hash.value = malloc(skc->sc_nodemap_hash.length);
645         if (!skc->sc_nodemap_hash.value) {
646                 printerr(0, "Failed to allocate memory for nodemap hash\n");
647                 goto out_err;
648         }
649
650         if (sk_hash_string(config->skc_nodemap, EVP_sha256(),
651                            &skc->sc_nodemap_hash)) {
652                 printerr(0, "Failed to generate hash for nodemap name\n");
653                 goto out_err;
654         }
655
656         kctx = &skc->sc_kctx;
657         kctx->skc_version = config->skc_version;
658         strcpy(kctx->skc_hmac_alg, sk_hmac2name(config->skc_hmac_alg));
659         strcpy(kctx->skc_crypt_alg, sk_crypt2name(config->skc_crypt_alg));
660         kctx->skc_expire = config->skc_expire;
661
662         /* key payload format is in bits, convert to bytes */
663         kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
664         kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
665         if (!kctx->skc_shared_key.value) {
666                 printerr(0, "Failed to allocate memory for shared key\n");
667                 goto out_err;
668         }
669         memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
670                kctx->skc_shared_key.length);
671
672         skc->sc_p.length = config->skc_prime_bits / 8;
673         skc->sc_p.value = malloc(skc->sc_p.length);
674         if (!skc->sc_p.value) {
675                 printerr(0, "Failed to allocate p\n");
676                 goto out_err;
677         }
678         memcpy(skc->sc_p.value, config->skc_p, skc->sc_p.length);
679
680         free(config);
681
682         return skc;
683
684 out_err:
685         sk_free_cred(skc);
686
687         free(config);
688         return NULL;
689 }
690
/* Diffie-Hellman generator expected in the key's DH parameters
 * (sk_check_dh() rejects any other value of g). */
#define SK_GENERATOR 2
/* NOTE(review): not referenced in the visible part of this file; the name
 * suggests a primality-check round count -- confirm where it is used. */
#define DH_NUMBER_ITERATIONS_FOR_PRIME 64
693
694 /* OpenSSL 1.1.1c increased the number of rounds used for Miller-Rabin testing
695  * of the prime provided as input parameter to DH_check(). This makes the check
696  * roughly x10 longer, and causes request timeouts when an SSK flavor is being
697  * used.
698  * Instead, use a dynamic number of Miller-Rabin rounds based on the speed of the
699  * check on the current system, evaluated when the lsvcgssd daemon starts, but
700  * at least as many as OpenSSL 1.1.1b used for the same key size. If default
701  * DH_check() duration is OK, use it directly instead of limiting the rounds.
702  * If \a num_rounds == 0, we just call original DH_check() directly.
703  *
704  * OpenSSL v3 internally forces a minimum of 64 rounds when checking prime, so
705  * it is no longer possible to test prime check speed with fewer rounds. In this
706  * case, do not bother and directly call EVP_PKEY_param_check.
707  */
708 #ifdef HAVE_OPENSSL_EVP_PKEY
709 static bool sk_is_dh_valid(EVP_PKEY_CTX *ctx)
710 {
711         if (EVP_PKEY_param_check(ctx) != 1) {
712                 printerr(0, "EVP_PKEY_param_check failed\n");
713                 ERR_print_errors_fp(stderr);
714                 return false;
715         }
716         return true;
717 }
718 #else
719 static inline bool sk_check_dh(const DH *dh, int num_rounds, bool fullcheck)
720 {
721         const BIGNUM *p = NULL, *g = NULL;
722         BN_ULONG word;
723         BN_CTX *ctx;
724         BIGNUM *r;
725         bool valid = false;
726         int rc;
727
728         DH_get0_pqg(dh, &p, NULL, &g);
729         if (!p || !g)
730                 return false;
731
732         if (!BN_is_word(g, SK_GENERATOR)) {
733                 printerr(0, "%s: Diffie-Hellman generator is not %u\n",
734                          program_invocation_short_name, SK_GENERATOR);
735                 return false;
736         }
737
738         word = BN_mod_word(p, 24);
739         /* OpenSSL v3 changed the way the prime is generated,
740          * using p mod 24 == 23.
741          * So we must accept word == 23 if the prime was generated
742          * by a client with OpenSSL v3.
743          */
744         if ((word != 11) && (word != 23)) {
745                 printerr(0, "%s: Diffie-Hellman prime modulo=%lu unsuitable\n",
746                          program_invocation_short_name, word);
747                 return false;
748         }
749
750         if (!fullcheck)
751                 return true;
752
753         ctx = BN_CTX_new();
754         if (ctx == NULL) {
755                 printerr(0, "%s: Diffie-Hellman error allocating context\n",
756                          program_invocation_short_name);
757                 return false;
758         }
759         BN_CTX_start(ctx);
760         r = BN_CTX_get(ctx); /* must be called before "ctx" used elsewhere */
761
762         rc = BN_is_prime_ex(p, num_rounds, ctx, NULL);
763         if (rc == 0)
764                 printerr(0, "%s: Diffie-Hellman 'p' not prime in %u rounds\n",
765                          program_invocation_short_name, num_rounds);
766         if (rc <= 0)
767                 goto out_free;
768
769         if (!BN_rshift1(r, p)) {
770                 printerr(0, "%s: error shifting BigNum 'r' by 'p'\n",
771                          program_invocation_short_name);
772                 goto out_free;
773         }
774         rc = BN_is_prime_ex(r, num_rounds, ctx, NULL);
775         if (rc == 0)
776                 printerr(0, "%s: Diffie-Hellman 'r' not prime in %u rounds\n",
777                          program_invocation_short_name, num_rounds);
778         if (rc <= 0)
779                 goto out_free;
780
781         valid = true;
782
783 out_free:
784         BN_CTX_end(ctx);
785         BN_CTX_free(ctx);
786
787         return valid;
788 }
789
790 static bool sk_is_dh_valid(const DH *dh, int num_rounds)
791 {
792         int rc;
793
794         if (num_rounds == 0) {
795                 int codes = 0;
796
797                 rc = DH_check(dh, &codes);
798                 if (codes == DH_NOT_SUITABLE_GENERATOR &&
799                     sk_check_dh(dh, num_rounds, false))
800                         return true;
801                 if (rc != 1 || codes) {
802                         printerr(0, "DH_check(0) failed: codes=%#x: rc=%d\n",
803                                  codes, rc);
804                         return false;
805                 }
806                 return true;
807         }
808
809         return sk_check_dh(dh, num_rounds, true);
810 }
811 #endif
812
813 #ifndef HAVE_OPENSSL_EVP_PKEY
/* Size in bytes of test_prime: 256 bytes, i.e. a 2048-bit value. */
#define VALUE_LENGTH 256
/*
 * Fixed 2048-bit prime used by sk_speedtest_dh_valid() to time primality
 * checks; it is converted with BN_bin2bn(), i.e. read as a big-endian
 * number.  The string literal holds exactly VALUE_LENGTH characters, so no
 * terminating NUL is stored in the array.
 */
static unsigned char test_prime[VALUE_LENGTH] =
	"\xf7\xfa\x49\xd8\xec\xb1\x3b\xff\x26\x10\x3f\xc5\x3a\xc5\xcc\x40"
	"\x4f\xbf\x92\xe1\x8b\x83\xe7\xa2\xba\x0f\x51\x5a\x91\x48\xe0\xa3"
	"\xf1\x4d\xbc\xbb\x8a\x28\x14\xac\x02\x23\x76\x42\x17\x4d\x3c\xdc"
	"\x5e\x4f\x80\x1f\xd7\x54\x1c\x50\xac\x3b\x28\x68\x8d\x71\x41\x7f"
	"\xa7\x1c\x2f\x22\xd3\xa8\x91\xb2\x64\xb6\x84\xa6\xcf\x06\x16\x91"
	"\x2f\xb8\xb4\x42\x1d\x3a\x4e\x3a\x0c\x7f\x04\x69\x78\xb5\x8f\x92"
	"\x07\x89\xac\x24\x06\x53\x2c\x23\xec\xaa\x5c\xb4\x7b\x49\xbc\xf4"
	"\x90\x67\x71\x9c\x24\x2c\x1d\x8d\x76\xc8\x85\x4e\x19\xf1\xf9\x33"
	"\x45\xbd\x9f\x7d\x0a\x08\x8c\x22\xcc\x35\xf3\x5b\xab\x3f\x24\x9d"
	"\x61\x70\x86\xbb\xbe\xd8\xb0\xf8\x34\xfa\xeb\x5b\x8e\xf2\x62\x23"
	"\xd1\xfb\xbb\xb8\x21\x71\x1e\x39\x39\x59\xe0\x82\x98\x41\x84\x40"
	"\x1f\xd3\x9b\xa3\x73\xdb\xec\xe0\xc0\xde\x2d\x1c\xea\x43\x40\x93"
	"\x98\x38\x03\x36\x1e\xe1\xe7\x39\x7b\x35\x92\x4a\x51\xa5\x91\x63"
	"\xd5\x31\x98\x3d\x89\x27\x6f\xcc\x69\xff\xbe\x31\x13\xdc\x2f\x72"
	"\x2d\xab\x6a\xb7\x13\xd3\x47\xda\xaa\xf3\x3c\xa0\xfd\xaa\x0f\x02"
	"\x96\x81\x1a\x26\xe8\xf7\x25\x65\x33\x78\xd9\x6b\x6d\xb0\xd9\xfb";
832
833 /**
834  * Measure time taken by prime testing routine for a 2048 bit long prime,
835  * depending on the number of check rounds.
836  *
837  * \param[in]   usec_check_max    max time allowed for DH_check completion
838  *
839  * \retval      max number of rounds to keep prime testing under usec_check_max
840  *              return 0 if we should use the default DH_check rounds
841  */
842 int sk_speedtest_dh_valid(unsigned int usec_check_max, pid_t *child)
843 {
844         DH *dh;
845         BIGNUM *p, *g;
846         struct sigaction sa;
847         int num_rounds, prev_rounds = 0;
848
849         /* Set SIGCHLD disposition so that child that terminates
850          * does not become a zombie.
851          */
852         sa.sa_handler = NULL;
853         sigemptyset(&sa.sa_mask);
854         sa.sa_flags = SA_NOCLDWAIT;
855         if (sigaction(SIGCHLD, &sa, NULL) == -1)
856                 printerr(2, "SIGCHLD sigaction failed\n");
857
858         *child = fork();
859         if (*child == -1) {
860                 printerr(0, "cannot fork child for speedtest: %s\n",
861                          strerror(errno));
862                 /* in this case the speedtest cannot be run, so return 0
863                  * to use default num rounds for prime testing
864                  */
865                 return 0;
866         } else if (*child != 0) {
867                 /* parent returns immediately,
868                  * 0 means num rounds is not determined yet
869                  */
870                 return 0;
871         }
872
873         /* now in forked child process, start speed test */
874
875         dh = DH_new();
876         if (!dh)
877                 return 0;
878
879         p = BN_bin2bn(test_prime, VALUE_LENGTH, NULL);
880         if (!p)
881                 goto free_dh;
882
883         g = BN_new();
884         if (!g)
885                 goto free_p;
886
887         if (!BN_set_word(g, SK_GENERATOR))
888                 goto free_g;
889
890         /* "dh" takes over freeing of 'p' and 'g' if this succeeds */
891         if (!DH_set0_pqg(dh, p, NULL, g)) {
892         free_g:
893                 BN_free(g);
894         free_p:
895                 BN_free(p);
896                 goto free_dh;
897         }
898
899         for (num_rounds = 0;
900              num_rounds <= DH_NUMBER_ITERATIONS_FOR_PRIME;
901              num_rounds += (num_rounds <= 4 ? 4 : 8)) {
902                 unsigned int usec_this;
903                 int j;
904
905                 /* get max duration of 4 runs at current number of rounds */
906                 usec_this = 0;
907                 for (j = 0; j < 4; j++) {
908                         struct timeval now, prev;
909                         unsigned int usec_curr;
910
911                         gettimeofday(&prev, NULL);
912                         if (!sk_is_dh_valid(dh, num_rounds)) {
913                                 /* if test_prime is found bad, use default */
914                                 prev_rounds = 0;
915                                 goto free_dh;
916                         }
917                         gettimeofday(&now, NULL);
918                         usec_curr = (now.tv_sec - prev.tv_sec) * 1000000 +
919                                     now.tv_usec - prev.tv_usec;
920                         if (usec_curr > usec_this)
921                                 usec_this = usec_curr;
922                 }
923                 printerr(2, "%s: %d rounds: %d usec\n",
924                          program_invocation_short_name, num_rounds, usec_this);
925                 if (num_rounds == 0) {
926                         if (usec_this <= usec_check_max)
927                         /* using original check rounds as implemented in
928                          * DH_check() took less time than the max allowed,
929                          * so just use original DH_check()
930                          */
931                                 break;
932                 } else if (usec_this >= usec_check_max) {
933                         break;
934                 }
935                 prev_rounds = num_rounds;
936         }
937
938 free_dh:
939         DH_free(dh);
940
941         return prev_rounds;
942 }
943 #endif /* !HAVE_OPENSSL_EVP_PKEY */
944
945 #ifdef HAVE_OPENSSL_EVP_PKEY
946 static uint32_t __sk_gen_params(struct sk_cred *skc, BIGNUM *p, BIGNUM *g,
947                                 int num_rounds)
948 {
949         EVP_PKEY_CTX *ctx = NULL, *ctx_from_key = NULL;
950         OSSL_PARAM_BLD *tmpl = NULL;
951         OSSL_PARAM *params = NULL;
952         EVP_PKEY *key = NULL;
953         uint32_t rc = GSS_S_FAILURE;
954
955         tmpl = OSSL_PARAM_BLD_new();
956         if (!tmpl ||
957             !OSSL_PARAM_BLD_push_BN(tmpl, OSSL_PKEY_PARAM_FFC_P, p) ||
958             !OSSL_PARAM_BLD_push_BN(tmpl, OSSL_PKEY_PARAM_FFC_G, g)) {
959                 printerr(0, "error: params cannot be pushed\n");
960                 goto err;
961         }
962         params = OSSL_PARAM_BLD_to_param(tmpl);
963         if (!params) {
964                 printerr(0, "error: params cannot be allocated\n");
965                 goto err;
966         }
967
968         ctx = EVP_PKEY_CTX_new_from_name(NULL, "DH", NULL);
969         if (!ctx ||
970             EVP_PKEY_fromdata_init(ctx) != 1 ||
971             EVP_PKEY_fromdata(ctx, &key,
972                               EVP_PKEY_KEY_PARAMETERS, params) != 1) {
973                 printerr(0, "error: params cannot be set\n");
974                 goto err;
975         }
976
977         ctx_from_key = EVP_PKEY_CTX_new_from_pkey(NULL, key, NULL);
978         if (!ctx_from_key) {
979                 printerr(0, "error: ctx_from_key cannot be allocated\n");
980                 goto err;
981         }
982
983         /* Verify that we have a safe prime and valid generator */
984         if (!sk_is_dh_valid(ctx_from_key))
985                 goto err;
986
987         skc->sc_params = NULL;
988         if (EVP_PKEY_keygen_init(ctx_from_key) != 1 ||
989             EVP_PKEY_keygen(ctx_from_key, &skc->sc_params) != 1) {
990                 printerr(0, "Failed to generate public DH key: %s\n",
991                          ERR_error_string(ERR_get_error(), NULL));
992                 goto err;
993         }
994
995         /* skc->sc_pub_key.value is allocated by
996          * EVP_PKEY_get1_encoded_public_key
997          */
998         skc->sc_pub_key.length =
999           EVP_PKEY_get1_encoded_public_key(skc->sc_params,
1000                                       (unsigned char **)&skc->sc_pub_key.value);
1001         if (skc->sc_pub_key.length == 0) {
1002                 printerr(0, "error: cannot get pub key\n");
1003                 skc->sc_pub_key.value = NULL;
1004                 goto err;
1005         }
1006         rc = GSS_S_COMPLETE;
1007
1008 err:
1009         EVP_PKEY_CTX_free(ctx_from_key);
1010         EVP_PKEY_free(key);
1011         EVP_PKEY_CTX_free(ctx);
1012         OSSL_PARAM_free(params);
1013         OSSL_PARAM_BLD_free(tmpl);
1014         BN_free(g);
1015         BN_free(p);
1016         return rc;
1017 }
1018 #else /* !HAVE_OPENSSL_EVP_PKEY */
1019 static uint32_t __sk_gen_params(struct sk_cred *skc, BIGNUM *p, BIGNUM *g,
1020                                 int num_rounds)
1021 {
1022         const BIGNUM *pub_key;
1023
1024         /* Populate DH parameters */
1025         /* "dh" takes over freeing of 'p' and 'g' if this succeeds */
1026         skc->sc_params = DH_new();
1027         if (!skc->sc_params || !DH_set0_pqg(skc->sc_params, p, NULL, g)) {
1028                 printerr(0, "Failed to set pqg\n");
1029                 BN_free(g);
1030                 BN_free(p);
1031                 return GSS_S_FAILURE;
1032         }
1033
1034         /* Verify that we have a safe prime and valid generator */
1035         if (!sk_is_dh_valid(skc->sc_params, num_rounds))
1036                 return GSS_S_FAILURE;
1037
1038         if (DH_generate_key(skc->sc_params) != 1) {
1039                 printerr(0, "Failed to generate public DH key: %s\n",
1040                          ERR_error_string(ERR_get_error(), NULL));
1041                 return GSS_S_FAILURE;
1042         }
1043
1044         DH_get0_key(skc->sc_params, &pub_key, NULL);
1045         skc->sc_pub_key.length = BN_num_bytes(pub_key);
1046         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
1047         if (!skc->sc_pub_key.value) {
1048                 printerr(0, "Failed to allocate memory for public key\n");
1049                 return GSS_S_FAILURE;
1050         }
1051
1052         BN_bn2bin(pub_key, skc->sc_pub_key.value);
1053
1054         return GSS_S_COMPLETE;
1055 }
1056 #endif /* HAVE_OPENSSL_EVP_PKEY */
1057
1058 /**
1059  * Populates the DH parameters for the DHKE
1060  *
1061  * \param[in,out]       skc             Shared key credentials structure to
1062  *                                      populate with DH parameters
1063  *
1064  * \retval      GSS_S_COMPLETE  success
1065  * \retval      GSS_S_FAILURE   failure
1066  */
1067 uint32_t sk_gen_params(struct sk_cred *skc, int num_rounds)
1068 {
1069         uint32_t random;
1070         BIGNUM *p, *g;
1071
1072         /* Random value used by both the request and response as part of the
1073          * key binding material.  This also should ensure we have unqiue
1074          * tokens that are sent to the remote server which is important because
1075          * the token is hashed for the sunrpc cache lookups and a failure there
1076          * would cause connection attempts to fail indefinitely due to the large
1077          * timeout value on the server side.
1078          */
1079         if (RAND_bytes((unsigned char *)&random, sizeof(random)) != 1) {
1080                 printerr(0, "Failed to get data for random parameter: %s\n",
1081                          ERR_error_string(ERR_get_error(), NULL));
1082                 return GSS_S_FAILURE;
1083         }
1084
1085         /* The random value will always be used in byte range operations
1086          * so we keep it as big endian from this point on.
1087          */
1088         skc->sc_kctx.skc_host_random = random;
1089
1090         p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
1091         if (!p) {
1092                 printerr(0, "Failed to convert binary to BIGNUM\n");
1093                 return GSS_S_FAILURE;
1094         }
1095
1096         /* We use a static generator for shared key */
1097         g = BN_new();
1098         if (!g) {
1099                 printerr(0, "Failed to allocate new BIGNUM\n");
1100                 goto free_p;
1101         }
1102         if (BN_set_word(g, SK_GENERATOR) != 1) {
1103                 printerr(0, "Failed to set g value for DH params\n");
1104                 goto free_g;
1105         }
1106
1107         return __sk_gen_params(skc, p, g, num_rounds);
1108
1109 free_g:
1110         BN_free(g);
1111 free_p:
1112         BN_free(p);
1113
1114         return GSS_S_FAILURE;
1115 }
1116
1117 /**
1118  * Convert SK hash algorithm into openssl message digest
1119  *
1120  * \param[in,out]       alg             SK hash algorithm
1121  *
1122  * \retval              EVP_MD
1123  */
1124 static inline const EVP_MD *sk_hash_to_evp_md(enum cfs_crypto_hash_alg alg)
1125 {
1126         switch (alg) {
1127         case CFS_HASH_ALG_SHA256:
1128                 return EVP_sha256();
1129         case CFS_HASH_ALG_SHA512:
1130                 return EVP_sha512();
1131         default:
1132                 return EVP_md_null();
1133         }
1134 }
1135
1136 /**
1137  * Signs (via HMAC) the parameters used only in the key initialization protocol.
1138  *
1139  * \param[in]           key             Key to use for HMAC
1140  * \param[in]           bufs            Array of gss_buffer_desc to generate
1141  *                                      HMAC for
1142  * \param[in]           numbufs         Number of buffers in array
1143  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
1144  * \param[in,out]       hmac            HMAC of buffers is allocated and placed
1145  *                                      in this gss_buffer_desc.  Caller must
1146  *                                      free this.
1147  *
1148  * \retval      0       success
1149  * \retval      -1      failure
1150  */
1151 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
1152                  const EVP_MD *hash_alg, gss_buffer_desc *hmac)
1153 {
1154         unsigned int hashlen = EVP_MD_size(hash_alg);
1155         EVP_MAC_CTX *ctx = NULL;
1156         EVP_MAC *mac = NULL;
1157         size_t len = 0;
1158         int i, rc = -1;
1159         DECLARE_EVP_MD(subalg, hash_alg);
1160
1161         if (hash_alg == EVP_md_null()) {
1162                 printerr(0, "Invalid hash algorithm\n");
1163                 return -1;
1164         }
1165
1166         hmac->length = hashlen;
1167         hmac->value = malloc(hashlen);
1168         if (!hmac->value) {
1169                 printerr(0, "Failed to allocate memory for HMAC\n");
1170                 goto out;
1171         }
1172
1173         mac = EVP_MAC_fetch(NULL, "HMAC", NULL);
1174         if (!mac) {
1175                 printerr(0, "Failed to fetch HMAC\n");
1176                 goto out;
1177         }
1178
1179         ctx = EVP_MAC_CTX_new(mac);
1180         if (!ctx) {
1181                 printerr(0, "Failed to init HMAC ctx\n");
1182                 goto out;
1183         }
1184
1185         if (EVP_MAC_init(ctx, key->value, key->length, subalg) != 1) {
1186                 printerr(0, "Failed to init HMAC\n");
1187                 goto out;
1188         }
1189
1190         for (i = 0; i < numbufs; i++) {
1191                 if (EVP_MAC_update(ctx, bufs[i].value, bufs[i].length) != 1) {
1192                         printerr(0, "Failed to update HMAC\n");
1193                         goto out;
1194                 }
1195         }
1196
1197         /* The result gets populated in hmac */
1198         if (EVP_MAC_final(ctx, hmac->value, &len, hashlen) != 1) {
1199                 printerr(0, "Failed to finalize HMAC\n");
1200                 goto out;
1201         }
1202         if (hmac->length != len) {
1203                 printerr(0, "HMAC size %zu does not match expected %zu\n",
1204                          len, hmac->length);
1205                 goto out;
1206         }
1207
1208         rc = 0;
1209 out:
1210         EVP_MAC_CTX_free(ctx);
1211         EVP_MAC_free(mac);
1212         return rc;
1213 }
1214
1215 /**
1216  * Generates an HMAC for gss_buffer_desc array in \a bufs of \a numbufs
1217  * and verifies against \a hmac.
1218  *
1219  * \param[in]   skc             Shared key credentials
1220  * \param[in]   bufs            Array of gss_buffer_desc to generate HMAC for
1221  * \param[in]   numbufs         Number of buffers in array
1222  * \param[in]   hash_alg        OpenSSL EVP_MD to use for hash
1223  * \param[in]   hmac            HMAC to verify against
1224  *
1225  * \retval      GSS_S_COMPLETE  success (match)
1226  * \retval      gss error       failure
1227  */
1228 uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
1229                         const int numbufs, const EVP_MD *hash_alg,
1230                         gss_buffer_desc *hmac)
1231 {
1232         gss_buffer_desc bufs_hmac;
1233         int rc;
1234
1235         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, numbufs, hash_alg,
1236                          &bufs_hmac)) {
1237                 printerr(0, "Failed to sign buffers to verify HMAC\n");
1238                 if (bufs_hmac.value)
1239                         free(bufs_hmac.value);
1240                 return GSS_S_FAILURE;
1241         }
1242
1243         if (hmac->length != bufs_hmac.length) {
1244                 printerr(0, "Invalid HMAC size\n");
1245                 free(bufs_hmac.value);
1246                 return GSS_S_BAD_SIG;
1247         }
1248
1249         rc = memcmp(hmac->value, bufs_hmac.value, bufs_hmac.length);
1250         free(bufs_hmac.value);
1251
1252         if (rc)
1253                 return GSS_S_BAD_SIG;
1254
1255         return GSS_S_COMPLETE;
1256 }
1257
1258 /**
1259  * Cleanup an sk_cred freeing any resources
1260  *
1261  * \param[in,out]       skc     Shared key credentials to free
1262  */
1263 void sk_free_cred(struct sk_cred *skc)
1264 {
1265         if (!skc)
1266                 return;
1267
1268         if (skc->sc_p.value)
1269                 free(skc->sc_p.value);
1270         if (skc->sc_pub_key.value)
1271                 free(skc->sc_pub_key.value);
1272         if (skc->sc_tgt.value)
1273                 free(skc->sc_tgt.value);
1274         if (skc->sc_nodemap_hash.value)
1275                 free(skc->sc_nodemap_hash.value);
1276         if (skc->sc_hmac.value)
1277                 free(skc->sc_hmac.value);
1278
1279         /* Overwrite keys and IV before freeing */
1280         if (skc->sc_dh_shared_key.value) {
1281                 memset(skc->sc_dh_shared_key.value, 0,
1282                        skc->sc_dh_shared_key.length);
1283                 free(skc->sc_dh_shared_key.value);
1284         }
1285         if (skc->sc_kctx.skc_hmac_key.value) {
1286                 memset(skc->sc_kctx.skc_hmac_key.value, 0,
1287                        skc->sc_kctx.skc_hmac_key.length);
1288                 free(skc->sc_kctx.skc_hmac_key.value);
1289         }
1290         if (skc->sc_kctx.skc_encrypt_key.value) {
1291                 memset(skc->sc_kctx.skc_encrypt_key.value, 0,
1292                        skc->sc_kctx.skc_encrypt_key.length);
1293                 free(skc->sc_kctx.skc_encrypt_key.value);
1294         }
1295         if (skc->sc_kctx.skc_shared_key.value) {
1296                 memset(skc->sc_kctx.skc_shared_key.value, 0,
1297                        skc->sc_kctx.skc_shared_key.length);
1298                 free(skc->sc_kctx.skc_shared_key.value);
1299         }
1300         if (skc->sc_kctx.skc_session_key.value) {
1301                 memset(skc->sc_kctx.skc_session_key.value, 0,
1302                        skc->sc_kctx.skc_session_key.length);
1303                 free(skc->sc_kctx.skc_session_key.value);
1304         }
1305
1306         if (skc->sc_params) {
1307                 EVP_PKEY_free(skc->sc_params);
1308                 skc->sc_params = NULL;
1309         }
1310
1311         free(skc);
1312         skc = NULL;
1313 }
1314
1315 /* This function handles key derivation using the hash algorithm specified in
1316  * \a hash_alg, buffers in \a key_binding_bufs, and original key in
1317  * \a origin_key to produce a \a derived_key.  The first element of the
1318  * key_binding_bufs array is reserved for the counter used in the KDF.  The
1319  * derived key in \a derived_key could differ in size from \a origin_key and
1320  * must be populated with the expected size and a valid buffer to hold the
1321  * contents.
1322  *
1323  * If the derived key size is greater than the HMAC algorithm size it will be
1324  * a done using several iterations of a counter and the key binding bufs.
1325  *
1326  * If the size is smaller it will take copy the first N bytes necessary to
1327  * fill the derived key. */
1328 static int sk_kdf(gss_buffer_desc *derived_key, gss_buffer_desc *origin_key,
1329                   gss_buffer_desc *key_binding_bufs, int numbufs,
1330                   enum cfs_crypto_hash_alg hmac_alg)
1331 {
1332         size_t remain;
1333         size_t bytes;
1334         uint32_t counter;
1335         char *keydata;
1336         gss_buffer_desc tmp_hash;
1337         int i;
1338         int rc;
1339
1340         if (numbufs < 1)
1341                 return -EINVAL;
1342
1343         /* Use a counter as the first buffer followed by the key binding
1344          * buffers in the event we need more than one a single cycle to
1345          * produced a symmetric key large enough in size */
1346         key_binding_bufs[0].value = &counter;
1347         key_binding_bufs[0].length = sizeof(counter);
1348
1349         remain = derived_key->length;
1350         keydata = derived_key->value;
1351         i = 0;
1352         while (remain > 0) {
1353                 counter = htobe32(i++);
1354                 rc = sk_sign_bufs(origin_key, key_binding_bufs, numbufs,
1355                                   sk_hash_to_evp_md(hmac_alg), &tmp_hash);
1356                 if (rc) {
1357                         if (tmp_hash.value)
1358                                 free(tmp_hash.value);
1359                         return rc;
1360                 }
1361
1362                 if (cfs_crypto_hash_digestsize(hmac_alg) != tmp_hash.length) {
1363                         free(tmp_hash.value);
1364                         return -EINVAL;
1365                 }
1366
1367                 bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
1368                 memcpy(keydata, tmp_hash.value, bytes);
1369                 free(tmp_hash.value);
1370                 remain -= bytes;
1371                 keydata += bytes;
1372         }
1373
1374         return 0;
1375 }
1376
1377 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
1378  * based on the recommendations in NIST Special Publication SP 800-56B Rev 1
1379  * (Sep 2014) Section 5.5.1
1380  *
1381  * \param[in,out]       skc             Shared key credentials structure with
1382  *
1383  * \return      -1              failure
1384  * \return      0               success
1385  */
1386 int sk_session_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
1387                    gss_buffer_desc *client_token, gss_buffer_desc *server_token)
1388 {
1389         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1390         gss_buffer_desc *session_key = &kctx->skc_session_key;
1391         gss_buffer_desc bufs[5];
1392         enum cfs_crypto_crypt_alg crypt_alg;
1393         int rc = -1;
1394
1395         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1396         session_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1397         session_key->value = malloc(session_key->length);
1398         if (!session_key->value) {
1399                 printerr(0, "Failed to allocate memory for session key\n");
1400                 return rc;
1401         }
1402
1403         /* Key binding info ordering
1404          * 1. Reserved for counter
1405          * 1. DH shared key
1406          * 2. Client's NIDs
1407          * 3. Client's token
1408          * 4. Server's token */
1409         bufs[0].value = NULL;
1410         bufs[0].length = 0;
1411         bufs[1] = skc->sc_dh_shared_key;
1412         bufs[2].value = &client_nid;
1413         bufs[2].length = sizeof(client_nid);
1414         bufs[3] = *client_token;
1415         bufs[4] = *server_token;
1416
1417         return sk_kdf(&kctx->skc_session_key, &kctx->skc_shared_key, bufs,
1418                       5, cfs_crypto_hash_alg(kctx->skc_hmac_alg));
1419 }
1420
1421 /* Uses the session key to create an HMAC key and encryption key.  In
1422  * integrity mode the session key used to generate the HMAC key uses
1423  * session information which is available on the wire but by creating
1424  * a session based HMAC key we can prevent potential replay as both the
1425  * client and server have random numbers used as part of the key creation.
1426  *
1427  * The keys used for integrity and privacy are formulated as below using
1428  * the session key that is the output of the key derivation function.  The
1429  * HMAC algorithm is determined by the shared key algorithm selected in the
1430  * key file.
1431  *
1432  * For ski mode:
1433  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1434  *
1435  * For skpi mode:
1436  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1437  * Session Encryption Key = PBKDF2("Encrypt", KDF derived Session Key)
1438  *
1439  * \param[in,out]       skc             Shared key credentials structure with
1440  *
1441  * \return      -1              failure
1442  * \return      0               success
1443  */
1444 int sk_compute_keys(struct sk_cred *skc)
1445 {
1446         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1447         gss_buffer_desc *session_key = &kctx->skc_session_key;
1448         gss_buffer_desc *hmac_key = &kctx->skc_hmac_key;
1449         gss_buffer_desc *encrypt_key = &kctx->skc_encrypt_key;
1450         enum cfs_crypto_hash_alg hmac_alg;
1451         enum cfs_crypto_crypt_alg crypt_alg;
1452         char *encrypt = "Encrypt";
1453         char *integrity = "Integrity";
1454         int rc;
1455
1456         hmac_alg = cfs_crypto_hash_alg(kctx->skc_hmac_alg);
1457         hmac_key->length = cfs_crypto_hash_digestsize(hmac_alg);
1458         hmac_key->value = malloc(hmac_key->length);
1459         if (!hmac_key->value)
1460                 return -ENOMEM;
1461
1462         rc = PKCS5_PBKDF2_HMAC(integrity, -1, session_key->value,
1463                                session_key->length, SK_PBKDF2_ITERATIONS,
1464                                sk_hash_to_evp_md(hmac_alg),
1465                                hmac_key->length, hmac_key->value);
1466         if (rc == 0)
1467                 return -EINVAL;
1468
1469         /* Encryption key is only populated in privacy mode */
1470         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1471                 return 0;
1472
1473         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1474         encrypt_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1475         encrypt_key->value = malloc(encrypt_key->length);
1476         if (!encrypt_key->value)
1477                 return -ENOMEM;
1478
1479         rc = PKCS5_PBKDF2_HMAC(encrypt, -1, session_key->value,
1480                                session_key->length, SK_PBKDF2_ITERATIONS,
1481                                sk_hash_to_evp_md(hmac_alg),
1482                                encrypt_key->length, encrypt_key->value);
1483         if (rc == 0)
1484                 return -EINVAL;
1485
1486         return 0;
1487 }
1488
1489 static uint32_t __sk_compute_dh_key(struct sk_cred *skc,
1490                                     const gss_buffer_desc *pub_key,
1491                                     size_t *expected_len)
1492 {
1493         gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
1494         uint32_t rc = GSS_S_FAILURE;
1495 #ifdef HAVE_OPENSSL_EVP_PKEY
1496         EVP_PKEY_CTX *ctx = NULL;
1497         EVP_PKEY *peerkey = NULL;
1498
1499         peerkey = EVP_PKEY_new();
1500         if (!peerkey ||
1501             EVP_PKEY_copy_parameters(peerkey, skc->sc_params) != 1) {
1502                 printerr(0, "error: peerkey cannot be init\n");
1503                 goto out_err;
1504         }
1505
1506         if (EVP_PKEY_set1_encoded_public_key(peerkey,
1507                                              pub_key->value,
1508                                              pub_key->length) != 1) {
1509                 printerr(0, "error: peerkey cannot be set\n");
1510                 goto out_err;
1511         }
1512
1513         ctx = EVP_PKEY_CTX_new_from_pkey(NULL, skc->sc_params, NULL);
1514         if (!ctx) {
1515                 printerr(0, "error: ctx cannot be allocated\n");
1516                 goto out_err;
1517         }
1518
1519         if (EVP_PKEY_derive_init(ctx) != 1 ||
1520             EVP_PKEY_derive_set_peer(ctx, peerkey) != 1) {
1521                 printerr(0, "error: ctx cannot be init\n");
1522                 goto out_err;
1523         }
1524
1525         if (EVP_PKEY_derive(ctx, NULL, expected_len) != 1) {
1526                 printerr(0, "error: cannot get dh length\n");
1527                 goto out_err;
1528         }
1529
1530         dh_shared->length = *expected_len;
1531         dh_shared->value = malloc(*expected_len);
1532         if (!dh_shared->value) {
1533                 printerr(0, "error: cannot allocate memory for shared key\n");
1534                 goto out_err;
1535         }
1536
1537         if (EVP_PKEY_derive(ctx, dh_shared->value, &dh_shared->length) != 1) {
1538                 printerr(0, "error: cannot derive dh key\n");
1539                 ERR_print_errors_fp(stderr);
1540                 goto out_err;
1541         }
1542
1543         rc = GSS_S_COMPLETE;
1544 out_err:
1545         EVP_PKEY_CTX_free(ctx);
1546         EVP_PKEY_free(peerkey);
1547 #else /* !HAVE_OPENSSL_EVP_PKEY */
1548         BIGNUM *remote_pub_key;
1549
1550         remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
1551         if (!remote_pub_key) {
1552                 printerr(0, "Failed to convert binary to BIGNUM\n");
1553                 return rc;
1554         }
1555
1556         *expected_len = DH_size(skc->sc_params);
1557         dh_shared->length = *expected_len;
1558         dh_shared->value = malloc(*expected_len);
1559         if (!dh_shared->value) {
1560                 printerr(0,
1561                          "Failed to allocate memory for computed shared secret key\n");
1562                 goto out_err;
1563         }
1564
1565         /* This computes the shared key from the DHKE */
1566         dh_shared->length = DH_compute_key(dh_shared->value, remote_pub_key,
1567                                            skc->sc_params);
1568         if (dh_shared->length == -1) {
1569                 printerr(0, "DH key derivation failed: %s\n",
1570                          ERR_error_string(ERR_get_error(), NULL));
1571                 goto out_err;
1572         }
1573
1574         rc = GSS_S_COMPLETE;
1575 out_err:
1576         BN_free(remote_pub_key);
1577 #endif /* HAVE_OPENSSL_EVP_PKEY */
1578         return rc;
1579 }
1580
1581 /**
1582  * Computes a session key based on the DH parameters from the host and its peer
1583  *
1584  * \param[in,out]       skc             Shared key credentials structure with
1585  *                                      the session key populated with the
1586  *                                      compute key
1587  * \param[in]           pub_key         Public key returned from peer in
1588  *                                      gss_buffer_desc
1589  * \return      gss error               failure
1590  * \return      GSS_S_COMPLETE          success
1591  */
1592 uint32_t sk_compute_dh_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
1593 {
1594         size_t expected_len;
1595         uint32_t rc;
1596
1597         rc = __sk_compute_dh_key(skc, pub_key, &expected_len);
1598         if (rc != GSS_S_COMPLETE)
1599                 return rc;
1600
1601         if (skc->sc_dh_shared_key.length < expected_len) {
1602                 /* there is around 1 chance out of 256 that the returned
1603                  * shared key is shorter than expected
1604                  */
1605                 if (skc->sc_dh_shared_key.length >= expected_len - 2) {
1606                         int shift = expected_len - skc->sc_dh_shared_key.length;
1607
1608                         /* if the key is short by only 1 or 2 bytes, just
1609                          * prepend it with 0s
1610                          */
1611                         memmove((void *)(skc->sc_dh_shared_key.value + shift),
1612                                 skc->sc_dh_shared_key.value,
1613                                 skc->sc_dh_shared_key.length);
1614                         memset(skc->sc_dh_shared_key.value, 0, shift);
1615                 } else {
1616                         /* if the key is really too short, return GSS_S_BAD_QOP
1617                          * so that the caller can retry to generate
1618                          */
1619                         printerr(0,
1620                                  "DH derivation returned a short key of %zu bytes, expected: %zu\n",
1621                                  skc->sc_dh_shared_key.length, expected_len);
1622                         rc = GSS_S_BAD_QOP;
1623                 }
1624         }
1625         return rc;
1626 }
1627
1628 /**
1629  * Creates a serialized buffer for the kernel in the order of struct
1630  * sk_kernel_ctx.
1631  *
1632  * \param[in,out]       skc             Shared key credentials structure
1633  * \param[in,out]       ctx_token       Serialized buffer for kernel.
1634  *                                      Caller must free this buffer.
1635  *
1636  * \return      0       success
1637  * \return      -1      failure
1638  */
1639 int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
1640 {
1641         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1642         char *p, *end;
1643         size_t bufsize;
1644
1645         bufsize = sizeof(*kctx) + kctx->skc_hmac_key.length +
1646                   kctx->skc_encrypt_key.length;
1647
1648         ctx_token->value = malloc(bufsize);
1649         if (!ctx_token->value)
1650                 return -1;
1651         ctx_token->length = bufsize;
1652
1653         p = ctx_token->value;
1654         end = p + ctx_token->length;
1655
1656         if (WRITE_BYTES(&p, end, kctx->skc_version))
1657                 return -1;
1658         if (WRITE_BYTES(&p, end, kctx->skc_hmac_alg))
1659                 return -1;
1660         if (WRITE_BYTES(&p, end, kctx->skc_crypt_alg))
1661                 return -1;
1662         if (WRITE_BYTES(&p, end, kctx->skc_expire))
1663                 return -1;
1664         if (WRITE_BYTES(&p, end, kctx->skc_host_random))
1665                 return -1;
1666         if (WRITE_BYTES(&p, end, kctx->skc_peer_random))
1667                 return -1;
1668         if (write_buffer(&p, end, &kctx->skc_hmac_key))
1669                 return -1;
1670         if (write_buffer(&p, end, &kctx->skc_encrypt_key))
1671                 return -1;
1672
1673         printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
1674
1675         return 0;
1676 }
1677
1678 /**
1679  * Decodes a netstring \a ns into array of gss_buffer_descs at \a bufs
1680  * up to \a numbufs.  Memory is allocated for each value and length
1681  * will be populated with the length
1682  *
1683  * \param[in,out]       bufs    Array of gss_buffer_descs
1684  * \param[in,out]       numbufs number of gss_buffer_desc in array
1685  * \param[in]           ns      netstring to decode
1686  *
1687  * \return      buffers populated       success
1688  * \return      -1                      failure
1689  */
1690 int sk_decode_netstring(gss_buffer_desc *bufs, int numbufs, gss_buffer_desc *ns)
1691 {
1692         char *ptr = ns->value;
1693         size_t remain = ns->length;
1694         unsigned int size;
1695         int digits;
1696         int sep;
1697         int rc;
1698         int i;
1699
1700         for (i = 0; i < numbufs; i++) {
1701                 /* read the size of first buffer */
1702                 rc = sscanf(ptr, "%9u", &size);
1703                 if (rc < 1)
1704                         goto out_err;
1705                 digits = (size) ? ceil(log10(size + 1)) : 1;
1706
1707                 /* sep of current string */
1708                 sep = size + digits + 2;
1709
1710                 /* check to make sure it's valid */
1711                 if (remain < sep || ptr[digits] != ':' ||
1712                     ptr[sep - 1] != ',')
1713                         goto out_err;
1714
1715                 bufs[i].length = size;
1716                 if (size == 0) {
1717                         bufs[i].value = NULL;
1718                 } else {
1719                         bufs[i].value = malloc(size);
1720                         if (!bufs[i].value)
1721                                 goto out_err;
1722                         memcpy(bufs[i].value, &ptr[digits + 1], size);
1723                 }
1724
1725                 remain -= sep;
1726                 ptr += sep;
1727         }
1728
1729         printerr(2, "Decoded netstring of %zu bytes\n", ns->length);
1730         return i;
1731
1732 out_err:
1733         while (i-- > 0) {
1734                 if (bufs[i].value)
1735                         free(bufs[i].value);
1736                 bufs[i].length = 0;
1737         }
1738         return -1;
1739 }
1740
1741 /**
1742  * Creates a netstring in a gss_buffer_desc that consists of all
1743  * the gss_buffer_desc found in \a bufs.  The netstring should be treated
1744  * as binary as it can contain null characters.
1745  *
1746  * \param[in]           bufs            Array of gss_buffer_desc to use as input
1747  * \param[in]           numbufs         Number of buffers in array
1748  * \param[in,out]       ns              Destination gss_buffer_desc to hold
1749  *                                      netstring
1750  *
1751  * \return      -1      failure
1752  * \return      0       success
1753  */
1754 int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
1755                         gss_buffer_desc *ns)
1756 {
1757         unsigned char *ptr;
1758         int size = 0;
1759         int rc;
1760         int i;
1761
1762         /* size of string in decimal, string size, colon, and comma */
1763         for (i = 0; i < numbufs; i++) {
1764
1765                 if (bufs[i].length == 0)
1766                         size += 3;
1767                 else
1768                         size += ceil(log10(bufs[i].length + 1)) +
1769                                 bufs[i].length + 2;
1770         }
1771
1772         ns->length = size;
1773         ns->value = malloc(ns->length);
1774         if (!ns->value) {
1775                 ns->length = 0;
1776                 return -1;
1777         }
1778
1779         ptr = ns->value;
1780         for (i = 0; i < numbufs; i++) {
1781                 /* size */
1782                 rc = scnprintf((char *) ptr, size, "%zu:", bufs[i].length);
1783                 ptr += rc;
1784
1785                 /* contents */
1786                 memcpy(ptr, bufs[i].value, bufs[i].length);
1787                 ptr += bufs[i].length;
1788
1789                 /* delimeter */
1790                 *ptr++ = ',';
1791
1792                 size -= bufs[i].length + rc + 1;
1793
1794                 /* should not happen */
1795                 if (size < 0)
1796                         abort();
1797         }
1798
1799         printerr(2, "Encoded netstring of %zu bytes\n", ns->length);
1800         return 0;
1801 }