Whamcloud - gitweb
LU-15896 gss: support OpenSSLv3
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Copyright (c) 2016, 2017, Intel Corporation.
26  *
27  * Author: Jeremy Filizetti <jfilizet@iu.edu>
28  */
29
30 #include <fcntl.h>
31 #include <limits.h>
32 #include <math.h>
33 #include <string.h>
34 #include <stdbool.h>
35 #include <unistd.h>
36 #include <openssl/dh.h>
37 #include <openssl/engine.h>
38 #include <openssl/err.h>
39 #include <openssl/hmac.h>
40 #ifdef HAVE_OPENSSL_EVP_PKEY
41 #include <openssl/param_build.h>
42 #endif
43 #include <sys/types.h>
44 #include <sys/stat.h>
45 #include <libcfs/util/string.h>
46 #include <sys/time.h>
47
48 #include "sk_utils.h"
49 #include "write_bytes.h"
50
51 #define SK_PBKDF2_ITERATIONS 10000
52
53 #ifdef _NEW_BUILD_
54 # include "lgss_utils.h"
55 #else
56 # include "gss_util.h"
57 # include "gss_oids.h"
58 # include "err_util.h"
59 #endif
60
#ifdef _ERR_UTIL_H_
/**
 * Set up logging for the calling program.  This is a thin forwarder to
 * initerr() and is only built when err_util.h has been included.
 *
 * \param[in]	program	Program name to output
 * \param[in]	verbose	Verbose flag
 * \param[in]	fg	Whether or not to run in foreground
 */
void sk_init_logging(char *program, int verbose, int fg)
{
	/* all state lives in err_util; nothing to keep here */
	initerr(program, verbose, fg);
}
#endif
74
75 /**
76  * Loads the key from \a filename and returns the struct sk_keyfile_config.
77  * It should be freed by the caller.
78  *
79  * \param[in]   filename                Disk or key payload data
80  *
81  * \return      sk_keyfile_config       sucess
82  * \return      NULL                    failure
83  */
84 struct sk_keyfile_config *sk_read_file(char *filename)
85 {
86         struct sk_keyfile_config *config;
87         char *ptr;
88         size_t rc;
89         size_t remain;
90         int fd;
91
92         config = malloc(sizeof(*config));
93         if (!config) {
94                 printerr(0, "Failed to allocate memory for config\n");
95                 return NULL;
96         }
97
98         /* allow standard input override */
99         if (strcmp(filename, "-") == 0)
100                 fd = STDIN_FILENO;
101         else
102                 fd = open(filename, O_RDONLY);
103
104         if (fd == -1) {
105                 printerr(0, "Error opening key file '%s': %s\n", filename,
106                          strerror(errno));
107                 goto out_free;
108         } else if (fd != STDIN_FILENO) {
109                 struct stat st;
110
111                 rc = fstat(fd, &st);
112                 if (rc == 0 && (st.st_mode & ~(S_IFREG | 0600)))
113                         fprintf(stderr, "warning: "
114                                 "secret key '%s' has insecure file mode %#o\n",
115                                 filename, st.st_mode);
116         }
117
118         ptr = (char *)config;
119         remain = sizeof(*config);
120         while (remain > 0) {
121                 rc = read(fd, ptr, remain);
122                 if (rc == -1) {
123                         if (errno == EINTR)
124                                 continue;
125                         printerr(0, "read() failed on %s: %s\n", filename,
126                                  strerror(errno));
127                         goto out_close;
128                 } else if (rc == 0) {
129                         printerr(0, "File %s does not have a complete key\n",
130                                  filename);
131                         goto out_close;
132                 }
133                 ptr += rc;
134                 remain -= rc;
135         }
136
137         if (fd != STDIN_FILENO)
138                 close(fd);
139         sk_config_disk_to_cpu(config);
140         return config;
141
142 out_close:
143         close(fd);
144 out_free:
145         free(config);
146         return NULL;
147 }
148
149 /**
150  * Checks if a key matching \a description is found in the keyring for
151  * logging purposes and then attempts to load \a payload of \a psize into a key
152  * with \a description.
153  *
154  * \param[in]   payload         Key payload
155  * \param[in]   psize           Payload size
156  * \param[in]   description     Description used for key in keyring
157  *
158  * \return      0       sucess
159  * \return      -1      failure
160  */
161 static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
162                                 const char *description)
163 {
164         struct sk_keyfile_config payload;
165         key_serial_t key;
166
167         memcpy(&payload, skc, sizeof(*skc));
168
169         /* In the keyring use the disk layout so keyctl pipe can be used */
170         sk_config_cpu_to_disk(&payload);
171
172         /* Check to see if a key is already loaded matching description */
173         key = keyctl_search(KEY_SPEC_USER_KEYRING, "user", description, 0);
174         if (key != -1)
175                 printerr(2, "Key %d found in session keyring, replacing\n",
176                          key);
177
178         key = add_key("user", description, &payload, sizeof(payload),
179                       KEY_SPEC_USER_KEYRING);
180         if (key != -1) {
181                 key_perm_t perm = KEY_POS_ALL | KEY_USR_ALL |
182                         KEY_GRP_ALL | KEY_OTH_ALL;
183
184                 if (keyctl_setperm(key, perm) < 0)
185                         printerr(2, "Failed to set perm 0x%x on key %d\n",
186                                  perm, key);
187                 printerr(2, "Added key %d with description %s\n", key,
188                          description);
189         } else {
190                 printerr(0, "Failed to add key with %s\n", description);
191         }
192
193         return key;
194 }
195
196 /**
197  * Reads the key from \a path, verifies it and loads into the session keyring
198  * using a description determined by the the \a type.  Existing keys with the
199  * same description are replaced.
200  *
201  * \param[in]   path    Path to key file
202  * \param[in]   type    Type of key to load which determines the description
203  *
204  * \return      0       sucess
205  * \return      -1      failure
206  */
207 int sk_load_keyfile(char *path)
208 {
209         struct sk_keyfile_config *config;
210         char description[SK_DESCRIPTION_SIZE + 1];
211         struct stat buf;
212         int i;
213         int rc;
214         int rc2 = -1;
215
216         rc = stat(path, &buf);
217         if (rc == -1) {
218                 printerr(0, "stat() failed for file %s: %s\n", path,
219                          strerror(errno));
220                 return rc2;
221         }
222
223         config = sk_read_file(path);
224         if (!config)
225                 return rc2;
226
227         /* Similar to ssh, require adequate care of key files */
228         if (buf.st_mode & (S_IRGRP | S_IWGRP | S_IWOTH | S_IXOTH)) {
229                 printerr(0, "Shared key files must be read/writeable only by "
230                          "owner\n");
231                 return -1;
232         }
233
234         if (sk_validate_config(config))
235                 goto out;
236
237         /* The server side can have multiple key files per file system so
238          * the nodemap name is appended to the key description to uniquely
239          * identify it */
240         if (config->skc_type & SK_TYPE_MGS) {
241                 /* Any key can be an MGS key as long as we are told to use it */
242                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
243                               config->skc_nodemap);
244                 if (rc >= SK_DESCRIPTION_SIZE)
245                         goto out;
246                 if (sk_load_key(config, description) == -1)
247                         goto out;
248         }
249         if (config->skc_type & SK_TYPE_SERVER) {
250                 /* Server keys need to have the file system name in the key */
251                 if (!config->skc_fsname) {
252                         printerr(0, "Key configuration has no file system "
253                                  "attribute.  Can't load as server type\n");
254                         goto out;
255                 }
256                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s:%s",
257                               config->skc_fsname, config->skc_nodemap);
258                 if (rc >= SK_DESCRIPTION_SIZE)
259                         goto out;
260                 if (sk_load_key(config, description) == -1)
261                         goto out;
262         }
263         if (config->skc_type & SK_TYPE_CLIENT) {
264                 /* Load client file system key */
265                 if (config->skc_fsname) {
266                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
267                                       "lustre:%s", config->skc_fsname);
268                         if (rc >= SK_DESCRIPTION_SIZE)
269                                 goto out;
270                         if (sk_load_key(config, description) == -1)
271                                 goto out;
272                 }
273
274                 /* Load client MGC keys */
275                 for (i = 0; i < MAX_MGSNIDS; i++) {
276                         if (config->skc_mgsnids[i] == LNET_NID_ANY)
277                                 continue;
278                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
279                                       "lustre:MGC%s",
280                                       libcfs_nid2str(config->skc_mgsnids[i]));
281                         if (rc >= SK_DESCRIPTION_SIZE)
282                                 goto out;
283                         if (sk_load_key(config, description) == -1)
284                                 goto out;
285                 }
286         }
287
288         rc2 = 0;
289
290 out:
291         free(config);
292         return rc2;
293 }
294
295 /**
296  * Byte swaps config from cpu format to disk
297  *
298  * \param[in,out]       config          sk_keyfile_config to swap
299  */
300 void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
301 {
302         int i;
303
304         if (!config)
305                 return;
306
307         config->skc_version = htobe32(config->skc_version);
308         config->skc_hmac_alg = htobe16(config->skc_hmac_alg);
309         config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
310         config->skc_expire = htobe32(config->skc_expire);
311         config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
312         config->skc_prime_bits = htobe32(config->skc_prime_bits);
313
314         for (i = 0; i < MAX_MGSNIDS; i++)
315                 config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
316 }
317
318 /**
319  * Byte swaps config from disk format to cpu
320  *
321  * \param[in,out]       config          sk_keyfile_config to swap
322  */
323 void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
324 {
325         int i;
326
327         if (!config)
328                 return;
329
330         config->skc_version = be32toh(config->skc_version);
331         config->skc_hmac_alg = be16toh(config->skc_hmac_alg);
332         config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
333         config->skc_expire = be32toh(config->skc_expire);
334         config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
335         config->skc_prime_bits = be32toh(config->skc_prime_bits);
336
337         for (i = 0; i < MAX_MGSNIDS; i++)
338                 config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
339 }
340
341 /**
342  * Verifies the on key payload format is valid
343  *
344  * \param[in]   config          sk_keyfile_config
345  *
346  * \return      -1      failure
347  * \return      0       success
348  */
349 int sk_validate_config(const struct sk_keyfile_config *config)
350 {
351         int i;
352
353         if (!config) {
354                 printerr(0, "Null configuration passed\n");
355                 return -1;
356         }
357
358         if (config->skc_version != SK_CONF_VERSION) {
359                 printerr(0, "Invalid version\n");
360                 return -1;
361         }
362
363         if (config->skc_hmac_alg == SK_HMAC_INVALID) {
364                 printerr(0, "Invalid HMAC algorithm\n");
365                 return -1;
366         }
367
368         if (config->skc_crypt_alg == SK_CRYPT_INVALID) {
369                 printerr(0, "Invalid crypt algorithm\n");
370                 return -1;
371         }
372
373         if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
374                 /* Try to limit key expiration to some reasonable minimum and
375                  * also prevent values over INT_MAX because there appears
376                  * to be a type conversion issue */
377                 printerr(0, "Invalid expiration time should be between %d "
378                          "and %d\n", 60, INT_MAX);
379                 return -1;
380         }
381         if (config->skc_prime_bits % 8 != 0 ||
382             config->skc_prime_bits > SK_MAX_P_BYTES * 8) {
383                 printerr(0, "Invalid session key length must be a multiple of 8"
384                          " and less then %d bits\n",
385                          SK_MAX_P_BYTES * 8);
386                 return -1;
387         }
388         if (config->skc_shared_keylen % 8 != 0 ||
389             config->skc_shared_keylen > SK_MAX_KEYLEN_BYTES * 8){
390                 printerr(0, "Invalid shared key max length must be a multiple "
391                          "of 8 and less then %d bits\n",
392                          SK_MAX_KEYLEN_BYTES * 8);
393                 return -1;
394         }
395
396         /* Check for terminating nulls on strings */
397         for (i = 0; i < sizeof(config->skc_fsname) &&
398              config->skc_fsname[i] != '\0';  i++)
399                 ; /* empty loop */
400         if (i == sizeof(config->skc_fsname)) {
401                 printerr(0, "File system name not null terminated\n");
402                 return -1;
403         }
404
405         for (i = 0; i < sizeof(config->skc_nodemap) &&
406              config->skc_nodemap[i] != '\0';  i++)
407                 ; /* empty loop */
408         if (i == sizeof(config->skc_nodemap)) {
409                 printerr(0, "Nodemap name not null terminated\n");
410                 return -1;
411         }
412
413         if (config->skc_type == SK_TYPE_INVALID) {
414                 printerr(0, "Invalid key type\n");
415                 return -1;
416         }
417
418         return 0;
419 }
420
421 /**
422  * Hashes \a string and places the hash in \a hash
423  * at \a hash
424  *
425  * \param[in]           string          Null terminated string to hash
426  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
427  * \param[in,out]       hash            gss_buffer_desc to hold the result
428  *
429  * \return      -1      failure
430  * \return      0       success
431  */
432 static int sk_hash_string(const char *string, const EVP_MD *hash_alg,
433                           gss_buffer_desc *hash)
434 {
435         EVP_MD_CTX *ctx = EVP_MD_CTX_create();
436         size_t len = strlen(string);
437         unsigned int hashlen;
438
439         if (!hash->value || hash->length < EVP_MD_size(hash_alg))
440                 goto out_err;
441         if (!EVP_DigestInit_ex(ctx, hash_alg, NULL))
442                 goto out_err;
443         if (!EVP_DigestUpdate(ctx, string, len))
444                 goto out_err;
445         if (!EVP_DigestFinal_ex(ctx, hash->value, &hashlen))
446                 goto out_err;
447
448         EVP_MD_CTX_destroy(ctx);
449         hash->length = hashlen;
450         return 0;
451
452 out_err:
453         EVP_MD_CTX_destroy(ctx);
454         return -1;
455 }
456
457 /**
458  * Hashes \a string and verifies the resulting hash matches the value
459  * in \a current_hash
460  *
461  * \param[in]           string          Null terminated string to hash
462  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
463  * \param[in,out]       current_hash    gss_buffer_desc to compare to
464  *
465  * \return      gss error       failure
466  * \return      GSS_S_COMPLETE  success
467  */
468 uint32_t sk_verify_hash(const char *string, const EVP_MD *hash_alg,
469                         const gss_buffer_desc *current_hash)
470 {
471         gss_buffer_desc hash;
472         unsigned char hashbuf[EVP_MAX_MD_SIZE];
473
474         hash.value = hashbuf;
475         hash.length = sizeof(hashbuf);
476
477         if (sk_hash_string(string, hash_alg, &hash))
478                 return GSS_S_FAILURE;
479         if (current_hash->length != hash.length)
480                 return GSS_S_DEFECTIVE_TOKEN;
481         if (memcmp(current_hash->value, hash.value, hash.length))
482                 return GSS_S_BAD_SIG;
483
484         return GSS_S_COMPLETE;
485 }
486
487 static inline int sk_config_has_mgsnid(struct sk_keyfile_config *config,
488                                        const char *mgsnid)
489 {
490         lnet_nid_t nid;
491         int i;
492
493         nid = libcfs_str2nid(mgsnid);
494         if (nid == LNET_NID_ANY)
495                 return 0;
496
497         for (i = 0; i < MAX_MGSNIDS; i++)
498                 if  (config->skc_mgsnids[i] == nid)
499                         return 1;
500         return 0;
501 }
502
503 /**
504  * Create an sk_cred structure populated with initial configuration info and the
505  * key.  \a tgt and \a nodemap are used in determining the expected key
506  * description so the key can be found by searching the keyring.
507  * This is done because there is no easy way to pass keys from the mount command
508  * all the way to the request_key call.  In addition any keys can be dynamically
509  * added to the keyrings and still found.  The keyring that needs to be used
510  * must be the session keyring.
511  *
512  * \param[in]   tgt             Target file system
513  * \param[in]   nodemap         Cluster name for the key.  This correlates to
514  *                              the nodemap name and is used by the server side.
515  *                              For the client this will be NULL.
516  * \param[in]   flags           Flags for the credentials
517  *
518  * \return      sk_cred Allocated struct sk_cred on success
519  * \return      NULL    failure
520  */
521 struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
522                                const uint32_t flags)
523 {
524         struct sk_keyfile_config *config;
525         struct sk_kernel_ctx *kctx;
526         struct sk_cred *skc = NULL;
527         char description[SK_DESCRIPTION_SIZE + 1];
528         char fsname[MTI_NAME_MAXLEN + 1];
529         const char *mgsnid = NULL;
530         char *ptr;
531         long sk_key;
532         int keylen;
533         int len;
534         int rc;
535
536         printerr(2, "Creating credentials for target: %s with nodemap: %s\n",
537                  tgt, nodemap);
538
539         memset(description, 0, sizeof(description));
540         memset(fsname, 0, sizeof(fsname));
541
542         /* extract the file system name from target */
543         ptr = index(tgt, '-');
544         if (!ptr) {
545                 len = strlen(tgt);
546
547                 /* This must be an MGC target */
548                 if (strncmp(tgt, "MGC", 3) || len <= 3) {
549                         printerr(0, "Invalid target name\n");
550                         return NULL;
551                 }
552                 mgsnid = tgt + 3;
553         } else {
554                 len = ptr - tgt;
555         }
556
557         if (len > MTI_NAME_MAXLEN) {
558                 printerr(0, "Invalid target name\n");
559                 return NULL;
560         }
561         memcpy(fsname, tgt, len);
562
563         if (nodemap) {
564                 if (mgsnid)
565                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
566                                       "lustre:MGS:%s", nodemap);
567                 else
568                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
569                                       "lustre:%s:%s", fsname, nodemap);
570         } else {
571                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s",
572                               fsname);
573         }
574
575         if (rc >= SK_DESCRIPTION_SIZE) {
576                 printerr(0, "Invalid key description\n");
577                 return NULL;
578         }
579
580         /* It may be a good idea to move Lustre keys to the gss_keyring
581          * (lgssc) type so that they expire when Lustre modules are removed.
582          * Unfortunately it can't be done at mount time because the mount
583          * syscall could trigger the Lustre modules to load and until that
584          * point we don't have a lgssc key type.
585          *
586          * TODO: Query the community for a consensus here  */
587         printerr(2, "Searching for key with description: %s\n", description);
588         sk_key = keyctl_search(KEY_SPEC_USER_KEYRING, "user",
589                                description, 0);
590         if (sk_key == -1) {
591                 printerr(1, "No key found for %s\n", description);
592                 return NULL;
593         }
594
595         keylen = keyctl_read_alloc(sk_key, (void **)&config);
596         if (keylen == -1) {
597                 printerr(0, "keyctl_read() failed for key %ld: %s\n", sk_key,
598                          strerror(errno));
599                 return NULL;
600         } else if (keylen != sizeof(*config)) {
601                 printerr(0, "Unexpected key size: %d returned for key %ld, "
602                          "expected %zu bytes\n",
603                          keylen, sk_key, sizeof(*config));
604                 goto out_err;
605         }
606
607         sk_config_disk_to_cpu(config);
608
609         if (sk_validate_config(config)) {
610                 printerr(0, "Invalid key configuration for key: %ld\n", sk_key);
611                 goto out_err;
612         }
613
614         if (mgsnid && !sk_config_has_mgsnid(config, mgsnid)) {
615                 printerr(0, "Target name does not match key's MGS NIDs\n");
616                 goto out_err;
617         }
618
619         if (!mgsnid && strcmp(fsname, config->skc_fsname)) {
620                 printerr(0, "Target name does not match key's file system\n");
621                 goto out_err;
622         }
623
624         skc = malloc(sizeof(*skc));
625         if (!skc) {
626                 printerr(0, "Failed to allocate memory for sk_cred\n");
627                 goto out_err;
628         }
629
630         /* this initializes all gss_buffer_desc to empty as well */
631         memset(skc, 0, sizeof(*skc));
632
633         skc->sc_flags = flags;
634         skc->sc_tgt.length = strlen(tgt) + 1;
635         skc->sc_tgt.value = malloc(skc->sc_tgt.length);
636         if (!skc->sc_tgt.value) {
637                 printerr(0, "Failed to allocate memory for target\n");
638                 goto out_err;
639         }
640         memcpy(skc->sc_tgt.value, tgt, skc->sc_tgt.length);
641
642         skc->sc_nodemap_hash.length = EVP_MD_size(EVP_sha256());
643         skc->sc_nodemap_hash.value = malloc(skc->sc_nodemap_hash.length);
644         if (!skc->sc_nodemap_hash.value) {
645                 printerr(0, "Failed to allocate memory for nodemap hash\n");
646                 goto out_err;
647         }
648
649         if (sk_hash_string(config->skc_nodemap, EVP_sha256(),
650                            &skc->sc_nodemap_hash)) {
651                 printerr(0, "Failed to generate hash for nodemap name\n");
652                 goto out_err;
653         }
654
655         kctx = &skc->sc_kctx;
656         kctx->skc_version = config->skc_version;
657         strcpy(kctx->skc_hmac_alg, sk_hmac2name(config->skc_hmac_alg));
658         strcpy(kctx->skc_crypt_alg, sk_crypt2name(config->skc_crypt_alg));
659         kctx->skc_expire = config->skc_expire;
660
661         /* key payload format is in bits, convert to bytes */
662         kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
663         kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
664         if (!kctx->skc_shared_key.value) {
665                 printerr(0, "Failed to allocate memory for shared key\n");
666                 goto out_err;
667         }
668         memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
669                kctx->skc_shared_key.length);
670
671         skc->sc_p.length = config->skc_prime_bits / 8;
672         skc->sc_p.value = malloc(skc->sc_p.length);
673         if (!skc->sc_p.value) {
674                 printerr(0, "Failed to allocate p\n");
675                 goto out_err;
676         }
677         memcpy(skc->sc_p.value, config->skc_p, skc->sc_p.length);
678
679         free(config);
680
681         return skc;
682
683 out_err:
684         sk_free_cred(skc);
685
686         free(config);
687         return NULL;
688 }
689
/* Diffie-Hellman generator (g) expected in SSK DH parameters */
#define SK_GENERATOR 2
/* Upper bound on the Miller-Rabin rounds tried by sk_speedtest_dh_valid() */
#define DH_NUMBER_ITERATIONS_FOR_PRIME 64

/* OpenSSL 1.1.1c increased the number of rounds used for Miller-Rabin testing
 * of the prime provided as input parameter to DH_check(). This makes the check
 * roughly 10x slower, and causes request timeouts when an SSK flavor is being
 * used.
 * Instead, use a dynamic number of Miller-Rabin rounds based on the speed of
 * the check on the current system, evaluated when the lsvcgssd daemon starts,
 * but at least as many as OpenSSL 1.1.1b used for the same key size. If the
 * default DH_check() duration is OK, use it directly instead of limiting the
 * rounds.
 * If \a num_rounds == 0, we just call the original DH_check() directly.
 *
 * OpenSSL v3 internally forces a minimum of 64 rounds when checking a prime,
 * so it is no longer possible to test prime check speed with fewer rounds. In
 * this case, do not bother and directly call EVP_PKEY_param_check().
 */
707 #ifdef HAVE_OPENSSL_EVP_PKEY
708 static bool sk_is_dh_valid(EVP_PKEY_CTX *ctx)
709 {
710         if (EVP_PKEY_param_check(ctx) != 1) {
711                 printerr(0, "EVP_PKEY_param_check failed\n");
712                 ERR_print_errors_fp(stderr);
713                 return false;
714         }
715         return true;
716 }
717 #else
718 static inline bool sk_check_dh(const DH *dh, int num_rounds, bool fullcheck)
719 {
720         const BIGNUM *p = NULL, *g = NULL;
721         BN_ULONG word;
722         BN_CTX *ctx;
723         BIGNUM *r;
724         bool valid = false;
725         int rc;
726
727         DH_get0_pqg(dh, &p, NULL, &g);
728         if (!p || !g)
729                 return false;
730
731         if (!BN_is_word(g, SK_GENERATOR)) {
732                 printerr(0, "%s: Diffie-Hellman generator is not %u\n",
733                          program_invocation_short_name, SK_GENERATOR);
734                 return false;
735         }
736
737         word = BN_mod_word(p, 24);
738         /* OpenSSL v3 changed the way the prime is generated,
739          * using p mod 24 == 23.
740          * So we must accept word == 23 if the prime was generated
741          * by a client with OpenSSL v3.
742          */
743         if ((word != 11) && (word != 23)) {
744                 printerr(0, "%s: Diffie-Hellman prime modulo=%lu unsuitable\n",
745                          program_invocation_short_name, word);
746                 return false;
747         }
748
749         if (!fullcheck)
750                 return true;
751
752         ctx = BN_CTX_new();
753         if (ctx == NULL) {
754                 printerr(0, "%s: Diffie-Hellman error allocating context\n",
755                          program_invocation_short_name);
756                 return false;
757         }
758         BN_CTX_start(ctx);
759         r = BN_CTX_get(ctx); /* must be called before "ctx" used elsewhere */
760
761         rc = BN_is_prime_ex(p, num_rounds, ctx, NULL);
762         if (rc == 0)
763                 printerr(0, "%s: Diffie-Hellman 'p' not prime in %u rounds\n",
764                          program_invocation_short_name, num_rounds);
765         if (rc <= 0)
766                 goto out_free;
767
768         if (!BN_rshift1(r, p)) {
769                 printerr(0, "%s: error shifting BigNum 'r' by 'p'\n",
770                          program_invocation_short_name);
771                 goto out_free;
772         }
773         rc = BN_is_prime_ex(r, num_rounds, ctx, NULL);
774         if (rc == 0)
775                 printerr(0, "%s: Diffie-Hellman 'r' not prime in %u rounds\n",
776                          program_invocation_short_name, num_rounds);
777         if (rc <= 0)
778                 goto out_free;
779
780         valid = true;
781
782 out_free:
783         BN_CTX_end(ctx);
784         BN_CTX_free(ctx);
785
786         return valid;
787 }
788
789 static bool sk_is_dh_valid(const DH *dh, int num_rounds)
790 {
791         int rc;
792
793         if (num_rounds == 0) {
794                 int codes = 0;
795
796                 rc = DH_check(dh, &codes);
797                 if (codes == DH_NOT_SUITABLE_GENERATOR &&
798                     sk_check_dh(dh, num_rounds, false))
799                         return true;
800                 if (rc != 1 || codes) {
801                         printerr(0, "DH_check(0) failed: codes=%#x: rc=%d\n",
802                                  codes, rc);
803                         return false;
804                 }
805                 return true;
806         }
807
808         return sk_check_dh(dh, num_rounds, true);
809 }
810 #endif
811
812 #ifndef HAVE_OPENSSL_EVP_PKEY
/* Size in bytes of test_prime (256 bytes = 2048 bits) */
#define VALUE_LENGTH 256
/*
 * Fixed 2048-bit prime used only as input when timing prime checks in
 * sk_speedtest_dh_valid(); it is loaded with BN_bin2bn(), so the bytes are
 * in big-endian order.  The literal fills the array exactly, so the implicit
 * trailing NUL is dropped, which C permits; the data is binary anyway.
 */
static unsigned char test_prime[VALUE_LENGTH] =
	"\xf7\xfa\x49\xd8\xec\xb1\x3b\xff\x26\x10\x3f\xc5\x3a\xc5\xcc\x40"
	"\x4f\xbf\x92\xe1\x8b\x83\xe7\xa2\xba\x0f\x51\x5a\x91\x48\xe0\xa3"
	"\xf1\x4d\xbc\xbb\x8a\x28\x14\xac\x02\x23\x76\x42\x17\x4d\x3c\xdc"
	"\x5e\x4f\x80\x1f\xd7\x54\x1c\x50\xac\x3b\x28\x68\x8d\x71\x41\x7f"
	"\xa7\x1c\x2f\x22\xd3\xa8\x91\xb2\x64\xb6\x84\xa6\xcf\x06\x16\x91"
	"\x2f\xb8\xb4\x42\x1d\x3a\x4e\x3a\x0c\x7f\x04\x69\x78\xb5\x8f\x92"
	"\x07\x89\xac\x24\x06\x53\x2c\x23\xec\xaa\x5c\xb4\x7b\x49\xbc\xf4"
	"\x90\x67\x71\x9c\x24\x2c\x1d\x8d\x76\xc8\x85\x4e\x19\xf1\xf9\x33"
	"\x45\xbd\x9f\x7d\x0a\x08\x8c\x22\xcc\x35\xf3\x5b\xab\x3f\x24\x9d"
	"\x61\x70\x86\xbb\xbe\xd8\xb0\xf8\x34\xfa\xeb\x5b\x8e\xf2\x62\x23"
	"\xd1\xfb\xbb\xb8\x21\x71\x1e\x39\x39\x59\xe0\x82\x98\x41\x84\x40"
	"\x1f\xd3\x9b\xa3\x73\xdb\xec\xe0\xc0\xde\x2d\x1c\xea\x43\x40\x93"
	"\x98\x38\x03\x36\x1e\xe1\xe7\x39\x7b\x35\x92\x4a\x51\xa5\x91\x63"
	"\xd5\x31\x98\x3d\x89\x27\x6f\xcc\x69\xff\xbe\x31\x13\xdc\x2f\x72"
	"\x2d\xab\x6a\xb7\x13\xd3\x47\xda\xaa\xf3\x3c\xa0\xfd\xaa\x0f\x02"
	"\x96\x81\x1a\x26\xe8\xf7\x25\x65\x33\x78\xd9\x6b\x6d\xb0\xd9\xfb";
831
832 /**
833  * Measure time taken by prime testing routine for a 2048 bit long prime,
834  * depending on the number of check rounds.
835  *
836  * \param[in]   usec_check_max    max time allowed for DH_check completion
837  *
838  * \retval      max number of rounds to keep prime testing under usec_check_max
839  *              return 0 if we should use the default DH_check rounds
840  */
841 int sk_speedtest_dh_valid(unsigned int usec_check_max)
842 {
843         DH *dh;
844         BIGNUM *p, *g;
845         int num_rounds, prev_rounds = 0;
846
847         dh = DH_new();
848         if (!dh)
849                 return 0;
850
851         p = BN_bin2bn(test_prime, VALUE_LENGTH, NULL);
852         if (!p)
853                 goto free_dh;
854
855         g = BN_new();
856         if (!g)
857                 goto free_p;
858
859         if (!BN_set_word(g, SK_GENERATOR))
860                 goto free_g;
861
862         /* "dh" takes over freeing of 'p' and 'g' if this succeeds */
863         if (!DH_set0_pqg(dh, p, NULL, g)) {
864         free_g:
865                 BN_free(g);
866         free_p:
867                 BN_free(p);
868                 goto free_dh;
869         }
870
871         for (num_rounds = 0;
872              num_rounds <= DH_NUMBER_ITERATIONS_FOR_PRIME;
873              num_rounds += (num_rounds <= 4 ? 4 : 8)) {
874                 unsigned int usec_this;
875                 int j;
876
877                 /* get max duration of 4 runs at current number of rounds */
878                 usec_this = 0;
879                 for (j = 0; j < 4; j++) {
880                         struct timeval now, prev;
881                         unsigned int usec_curr;
882
883                         gettimeofday(&prev, NULL);
884                         if (!sk_is_dh_valid(dh, num_rounds)) {
885                                 /* if test_prime is found bad, use default */
886                                 prev_rounds = 0;
887                                 goto free_dh;
888                         }
889                         gettimeofday(&now, NULL);
890                         usec_curr = (now.tv_sec - prev.tv_sec) * 1000000 +
891                                     now.tv_usec - prev.tv_usec;
892                         if (usec_curr > usec_this)
893                                 usec_this = usec_curr;
894                 }
895                 printerr(2, "%s: %d rounds: %d usec\n",
896                          program_invocation_short_name, num_rounds, usec_this);
897                 if (num_rounds == 0) {
898                         if (usec_this <= usec_check_max)
899                         /* using original check rounds as implemented in
900                          * DH_check() took less time than the max allowed,
901                          * so just use original DH_check()
902                          */
903                                 break;
904                 } else if (usec_this >= usec_check_max) {
905                         break;
906                 }
907                 prev_rounds = num_rounds;
908         }
909
910 free_dh:
911         DH_free(dh);
912
913         return prev_rounds;
914 }
915 #endif /* !HAVE_OPENSSL_EVP_PKEY */
916
917 #ifdef HAVE_OPENSSL_EVP_PKEY
918 static uint32_t __sk_gen_params(struct sk_cred *skc, BIGNUM *p, BIGNUM *g,
919                                 int num_rounds)
920 {
921         EVP_PKEY_CTX *ctx = NULL, *ctx_from_key = NULL;
922         OSSL_PARAM_BLD *tmpl = NULL;
923         OSSL_PARAM *params = NULL;
924         EVP_PKEY *key = NULL;
925         uint32_t rc = GSS_S_FAILURE;
926
927         tmpl = OSSL_PARAM_BLD_new();
928         if (!tmpl ||
929             !OSSL_PARAM_BLD_push_BN(tmpl, OSSL_PKEY_PARAM_FFC_P, p) ||
930             !OSSL_PARAM_BLD_push_BN(tmpl, OSSL_PKEY_PARAM_FFC_G, g)) {
931                 printerr(0, "error: params cannot be pushed\n");
932                 goto err;
933         }
934         params = OSSL_PARAM_BLD_to_param(tmpl);
935         if (!params) {
936                 printerr(0, "error: params cannot be allocated\n");
937                 goto err;
938         }
939
940         ctx = EVP_PKEY_CTX_new_from_name(NULL, "DH", NULL);
941         if (!ctx ||
942             EVP_PKEY_fromdata_init(ctx) != 1 ||
943             EVP_PKEY_fromdata(ctx, &key,
944                               EVP_PKEY_KEY_PARAMETERS, params) != 1) {
945                 printerr(0, "error: params cannot be set\n");
946                 goto err;
947         }
948
949         ctx_from_key = EVP_PKEY_CTX_new_from_pkey(NULL, key, NULL);
950         if (!ctx_from_key) {
951                 printerr(0, "error: ctx_from_key cannot be allocated\n");
952                 goto err;
953         }
954
955         /* Verify that we have a safe prime and valid generator */
956         if (!sk_is_dh_valid(ctx_from_key))
957                 goto err;
958
959         skc->sc_params = NULL;
960         if (EVP_PKEY_keygen_init(ctx_from_key) != 1 ||
961             EVP_PKEY_keygen(ctx_from_key, &skc->sc_params) != 1) {
962                 printerr(0, "Failed to generate public DH key: %s\n",
963                          ERR_error_string(ERR_get_error(), NULL));
964                 goto err;
965         }
966
967         /* skc->sc_pub_key.value is allocated by
968          * EVP_PKEY_get1_encoded_public_key
969          */
970         skc->sc_pub_key.length =
971           EVP_PKEY_get1_encoded_public_key(skc->sc_params,
972                                       (unsigned char **)&skc->sc_pub_key.value);
973         if (skc->sc_pub_key.length == 0) {
974                 printerr(0, "error: cannot get pub key\n");
975                 skc->sc_pub_key.value = NULL;
976                 goto err;
977         }
978         rc = GSS_S_COMPLETE;
979
980 err:
981         EVP_PKEY_CTX_free(ctx_from_key);
982         EVP_PKEY_free(key);
983         EVP_PKEY_CTX_free(ctx);
984         OSSL_PARAM_free(params);
985         OSSL_PARAM_BLD_free(tmpl);
986         BN_free(g);
987         BN_free(p);
988         return rc;
989 }
990 #else /* !HAVE_OPENSSL_EVP_PKEY */
991 static uint32_t __sk_gen_params(struct sk_cred *skc, BIGNUM *p, BIGNUM *g,
992                                 int num_rounds)
993 {
994         const BIGNUM *pub_key;
995
996         /* Populate DH parameters */
997         /* "dh" takes over freeing of 'p' and 'g' if this succeeds */
998         skc->sc_params = DH_new();
999         if (!skc->sc_params || !DH_set0_pqg(skc->sc_params, p, NULL, g)) {
1000                 printerr(0, "Failed to set pqg\n");
1001                 BN_free(g);
1002                 BN_free(p);
1003                 return GSS_S_FAILURE;
1004         }
1005
1006         /* Verify that we have a safe prime and valid generator */
1007         if (!sk_is_dh_valid(skc->sc_params, num_rounds))
1008                 return GSS_S_FAILURE;
1009
1010         if (DH_generate_key(skc->sc_params) != 1) {
1011                 printerr(0, "Failed to generate public DH key: %s\n",
1012                          ERR_error_string(ERR_get_error(), NULL));
1013                 return GSS_S_FAILURE;
1014         }
1015
1016         DH_get0_key(skc->sc_params, &pub_key, NULL);
1017         skc->sc_pub_key.length = BN_num_bytes(pub_key);
1018         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
1019         if (!skc->sc_pub_key.value) {
1020                 printerr(0, "Failed to allocate memory for public key\n");
1021                 return GSS_S_FAILURE;
1022         }
1023
1024         BN_bn2bin(pub_key, skc->sc_pub_key.value);
1025
1026         return GSS_S_COMPLETE;
1027 }
1028 #endif /* HAVE_OPENSSL_EVP_PKEY */
1029
1030 /**
1031  * Populates the DH parameters for the DHKE
1032  *
1033  * \param[in,out]       skc             Shared key credentials structure to
1034  *                                      populate with DH parameters
1035  *
1036  * \retval      GSS_S_COMPLETE  success
1037  * \retval      GSS_S_FAILURE   failure
1038  */
1039 uint32_t sk_gen_params(struct sk_cred *skc, int num_rounds)
1040 {
1041         uint32_t random;
1042         BIGNUM *p, *g;
1043
1044         /* Random value used by both the request and response as part of the
1045          * key binding material.  This also should ensure we have unqiue
1046          * tokens that are sent to the remote server which is important because
1047          * the token is hashed for the sunrpc cache lookups and a failure there
1048          * would cause connection attempts to fail indefinitely due to the large
1049          * timeout value on the server side.
1050          */
1051         if (RAND_bytes((unsigned char *)&random, sizeof(random)) != 1) {
1052                 printerr(0, "Failed to get data for random parameter: %s\n",
1053                          ERR_error_string(ERR_get_error(), NULL));
1054                 return GSS_S_FAILURE;
1055         }
1056
1057         /* The random value will always be used in byte range operations
1058          * so we keep it as big endian from this point on.
1059          */
1060         skc->sc_kctx.skc_host_random = random;
1061
1062         p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
1063         if (!p) {
1064                 printerr(0, "Failed to convert binary to BIGNUM\n");
1065                 return GSS_S_FAILURE;
1066         }
1067
1068         /* We use a static generator for shared key */
1069         g = BN_new();
1070         if (!g) {
1071                 printerr(0, "Failed to allocate new BIGNUM\n");
1072                 goto free_p;
1073         }
1074         if (BN_set_word(g, SK_GENERATOR) != 1) {
1075                 printerr(0, "Failed to set g value for DH params\n");
1076                 goto free_g;
1077         }
1078
1079         return __sk_gen_params(skc, p, g, num_rounds);
1080
1081 free_g:
1082         BN_free(g);
1083 free_p:
1084         BN_free(p);
1085
1086         return GSS_S_FAILURE;
1087 }
1088
1089 /**
1090  * Convert SK hash algorithm into openssl message digest
1091  *
1092  * \param[in,out]       alg             SK hash algorithm
1093  *
1094  * \retval              EVP_MD
1095  */
1096 static inline const EVP_MD *sk_hash_to_evp_md(enum cfs_crypto_hash_alg alg)
1097 {
1098         switch (alg) {
1099         case CFS_HASH_ALG_SHA256:
1100                 return EVP_sha256();
1101         case CFS_HASH_ALG_SHA512:
1102                 return EVP_sha512();
1103         default:
1104                 return EVP_md_null();
1105         }
1106 }
1107
1108 /**
1109  * Signs (via HMAC) the parameters used only in the key initialization protocol.
1110  *
1111  * \param[in]           key             Key to use for HMAC
1112  * \param[in]           bufs            Array of gss_buffer_desc to generate
1113  *                                      HMAC for
1114  * \param[in]           numbufs         Number of buffers in array
1115  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
1116  * \param[in,out]       hmac            HMAC of buffers is allocated and placed
1117  *                                      in this gss_buffer_desc.  Caller must
1118  *                                      free this.
1119  *
1120  * \retval      0       success
1121  * \retval      -1      failure
1122  */
1123 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
1124                  const EVP_MD *hash_alg, gss_buffer_desc *hmac)
1125 {
1126         unsigned int hashlen = EVP_MD_size(hash_alg);
1127         EVP_MAC_CTX *ctx = NULL;
1128         EVP_MAC *mac = NULL;
1129         size_t len = 0;
1130         int i, rc = -1;
1131         DECLARE_EVP_MD(subalg, hash_alg);
1132
1133         if (hash_alg == EVP_md_null()) {
1134                 printerr(0, "Invalid hash algorithm\n");
1135                 return -1;
1136         }
1137
1138         hmac->length = hashlen;
1139         hmac->value = malloc(hashlen);
1140         if (!hmac->value) {
1141                 printerr(0, "Failed to allocate memory for HMAC\n");
1142                 goto out;
1143         }
1144
1145         mac = EVP_MAC_fetch(NULL, "HMAC", NULL);
1146         if (!mac) {
1147                 printerr(0, "Failed to fetch HMAC\n");
1148                 goto out;
1149         }
1150
1151         ctx = EVP_MAC_CTX_new(mac);
1152         if (!ctx) {
1153                 printerr(0, "Failed to init HMAC ctx\n");
1154                 goto out;
1155         }
1156
1157         if (EVP_MAC_init(ctx, key->value, key->length, subalg) != 1) {
1158                 printerr(0, "Failed to init HMAC\n");
1159                 goto out;
1160         }
1161
1162         for (i = 0; i < numbufs; i++) {
1163                 if (EVP_MAC_update(ctx, bufs[i].value, bufs[i].length) != 1) {
1164                         printerr(0, "Failed to update HMAC\n");
1165                         goto out;
1166                 }
1167         }
1168
1169         /* The result gets populated in hmac */
1170         if (EVP_MAC_final(ctx, hmac->value, &len, hashlen) != 1) {
1171                 printerr(0, "Failed to finalize HMAC\n");
1172                 goto out;
1173         }
1174         if (hmac->length != len) {
1175                 printerr(0, "HMAC size %zu does not match expected %zu\n",
1176                          len, hmac->length);
1177                 goto out;
1178         }
1179
1180         rc = 0;
1181 out:
1182         EVP_MAC_CTX_free(ctx);
1183         EVP_MAC_free(mac);
1184         return rc;
1185 }
1186
1187 /**
1188  * Generates an HMAC for gss_buffer_desc array in \a bufs of \a numbufs
1189  * and verifies against \a hmac.
1190  *
1191  * \param[in]   skc             Shared key credentials
1192  * \param[in]   bufs            Array of gss_buffer_desc to generate HMAC for
1193  * \param[in]   numbufs         Number of buffers in array
1194  * \param[in]   hash_alg        OpenSSL EVP_MD to use for hash
1195  * \param[in]   hmac            HMAC to verify against
1196  *
1197  * \retval      GSS_S_COMPLETE  success (match)
1198  * \retval      gss error       failure
1199  */
1200 uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
1201                         const int numbufs, const EVP_MD *hash_alg,
1202                         gss_buffer_desc *hmac)
1203 {
1204         gss_buffer_desc bufs_hmac;
1205         int rc;
1206
1207         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, numbufs, hash_alg,
1208                          &bufs_hmac)) {
1209                 printerr(0, "Failed to sign buffers to verify HMAC\n");
1210                 if (bufs_hmac.value)
1211                         free(bufs_hmac.value);
1212                 return GSS_S_FAILURE;
1213         }
1214
1215         if (hmac->length != bufs_hmac.length) {
1216                 printerr(0, "Invalid HMAC size\n");
1217                 free(bufs_hmac.value);
1218                 return GSS_S_BAD_SIG;
1219         }
1220
1221         rc = memcmp(hmac->value, bufs_hmac.value, bufs_hmac.length);
1222         free(bufs_hmac.value);
1223
1224         if (rc)
1225                 return GSS_S_BAD_SIG;
1226
1227         return GSS_S_COMPLETE;
1228 }
1229
1230 /**
1231  * Cleanup an sk_cred freeing any resources
1232  *
1233  * \param[in,out]       skc     Shared key credentials to free
1234  */
1235 void sk_free_cred(struct sk_cred *skc)
1236 {
1237         if (!skc)
1238                 return;
1239
1240         if (skc->sc_p.value)
1241                 free(skc->sc_p.value);
1242         if (skc->sc_pub_key.value)
1243                 free(skc->sc_pub_key.value);
1244         if (skc->sc_tgt.value)
1245                 free(skc->sc_tgt.value);
1246         if (skc->sc_nodemap_hash.value)
1247                 free(skc->sc_nodemap_hash.value);
1248         if (skc->sc_hmac.value)
1249                 free(skc->sc_hmac.value);
1250
1251         /* Overwrite keys and IV before freeing */
1252         if (skc->sc_dh_shared_key.value) {
1253                 memset(skc->sc_dh_shared_key.value, 0,
1254                        skc->sc_dh_shared_key.length);
1255                 free(skc->sc_dh_shared_key.value);
1256         }
1257         if (skc->sc_kctx.skc_hmac_key.value) {
1258                 memset(skc->sc_kctx.skc_hmac_key.value, 0,
1259                        skc->sc_kctx.skc_hmac_key.length);
1260                 free(skc->sc_kctx.skc_hmac_key.value);
1261         }
1262         if (skc->sc_kctx.skc_encrypt_key.value) {
1263                 memset(skc->sc_kctx.skc_encrypt_key.value, 0,
1264                        skc->sc_kctx.skc_encrypt_key.length);
1265                 free(skc->sc_kctx.skc_encrypt_key.value);
1266         }
1267         if (skc->sc_kctx.skc_shared_key.value) {
1268                 memset(skc->sc_kctx.skc_shared_key.value, 0,
1269                        skc->sc_kctx.skc_shared_key.length);
1270                 free(skc->sc_kctx.skc_shared_key.value);
1271         }
1272         if (skc->sc_kctx.skc_session_key.value) {
1273                 memset(skc->sc_kctx.skc_session_key.value, 0,
1274                        skc->sc_kctx.skc_session_key.length);
1275                 free(skc->sc_kctx.skc_session_key.value);
1276         }
1277
1278         if (skc->sc_params) {
1279                 EVP_PKEY_free(skc->sc_params);
1280                 skc->sc_params = NULL;
1281         }
1282
1283         free(skc);
1284         skc = NULL;
1285 }
1286
1287 /* This function handles key derivation using the hash algorithm specified in
1288  * \a hash_alg, buffers in \a key_binding_bufs, and original key in
1289  * \a origin_key to produce a \a derived_key.  The first element of the
1290  * key_binding_bufs array is reserved for the counter used in the KDF.  The
1291  * derived key in \a derived_key could differ in size from \a origin_key and
1292  * must be populated with the expected size and a valid buffer to hold the
1293  * contents.
1294  *
1295  * If the derived key size is greater than the HMAC algorithm size it will be
1296  * a done using several iterations of a counter and the key binding bufs.
1297  *
1298  * If the size is smaller it will take copy the first N bytes necessary to
1299  * fill the derived key. */
1300 int sk_kdf(gss_buffer_desc *derived_key , gss_buffer_desc *origin_key,
1301            gss_buffer_desc *key_binding_bufs, int numbufs,
1302            enum cfs_crypto_hash_alg hmac_alg)
1303 {
1304         size_t remain;
1305         size_t bytes;
1306         uint32_t counter;
1307         char *keydata;
1308         gss_buffer_desc tmp_hash;
1309         int i;
1310         int rc;
1311
1312         if (numbufs < 1)
1313                 return -EINVAL;
1314
1315         /* Use a counter as the first buffer followed by the key binding
1316          * buffers in the event we need more than one a single cycle to
1317          * produced a symmetric key large enough in size */
1318         key_binding_bufs[0].value = &counter;
1319         key_binding_bufs[0].length = sizeof(counter);
1320
1321         remain = derived_key->length;
1322         keydata = derived_key->value;
1323         i = 0;
1324         while (remain > 0) {
1325                 counter = htobe32(i++);
1326                 rc = sk_sign_bufs(origin_key, key_binding_bufs, numbufs,
1327                                   sk_hash_to_evp_md(hmac_alg), &tmp_hash);
1328                 if (rc) {
1329                         if (tmp_hash.value)
1330                                 free(tmp_hash.value);
1331                         return rc;
1332                 }
1333
1334                 if (cfs_crypto_hash_digestsize(hmac_alg) != tmp_hash.length) {
1335                         free(tmp_hash.value);
1336                         return -EINVAL;
1337                 }
1338
1339                 bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
1340                 memcpy(keydata, tmp_hash.value, bytes);
1341                 free(tmp_hash.value);
1342                 remain -= bytes;
1343                 keydata += bytes;
1344         }
1345
1346         return 0;
1347 }
1348
1349 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
1350  * based on the recommendations in NIST Special Publication SP 800-56B Rev 1
1351  * (Sep 2014) Section 5.5.1
1352  *
1353  * \param[in,out]       skc             Shared key credentials structure with
1354  *
1355  * \return      -1              failure
1356  * \return      0               success
1357  */
1358 int sk_session_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
1359                    gss_buffer_desc *client_token, gss_buffer_desc *server_token)
1360 {
1361         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1362         gss_buffer_desc *session_key = &kctx->skc_session_key;
1363         gss_buffer_desc bufs[5];
1364         enum cfs_crypto_crypt_alg crypt_alg;
1365         int rc = -1;
1366
1367         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1368         session_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1369         session_key->value = malloc(session_key->length);
1370         if (!session_key->value) {
1371                 printerr(0, "Failed to allocate memory for session key\n");
1372                 return rc;
1373         }
1374
1375         /* Key binding info ordering
1376          * 1. Reserved for counter
1377          * 1. DH shared key
1378          * 2. Client's NIDs
1379          * 3. Client's token
1380          * 4. Server's token */
1381         bufs[0].value = NULL;
1382         bufs[0].length = 0;
1383         bufs[1] = skc->sc_dh_shared_key;
1384         bufs[2].value = &client_nid;
1385         bufs[2].length = sizeof(client_nid);
1386         bufs[3] = *client_token;
1387         bufs[4] = *server_token;
1388
1389         return sk_kdf(&kctx->skc_session_key, &kctx->skc_shared_key, bufs,
1390                       5, cfs_crypto_hash_alg(kctx->skc_hmac_alg));
1391 }
1392
1393 /* Uses the session key to create an HMAC key and encryption key.  In
1394  * integrity mode the session key used to generate the HMAC key uses
1395  * session information which is available on the wire but by creating
1396  * a session based HMAC key we can prevent potential replay as both the
1397  * client and server have random numbers used as part of the key creation.
1398  *
1399  * The keys used for integrity and privacy are formulated as below using
1400  * the session key that is the output of the key derivation function.  The
1401  * HMAC algorithm is determined by the shared key algorithm selected in the
1402  * key file.
1403  *
1404  * For ski mode:
1405  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1406  *
1407  * For skpi mode:
1408  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1409  * Session Encryption Key = PBKDF2("Encrypt", KDF derived Session Key)
1410  *
1411  * \param[in,out]       skc             Shared key credentials structure with
1412  *
1413  * \return      -1              failure
1414  * \return      0               success
1415  */
1416 int sk_compute_keys(struct sk_cred *skc)
1417 {
1418         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1419         gss_buffer_desc *session_key = &kctx->skc_session_key;
1420         gss_buffer_desc *hmac_key = &kctx->skc_hmac_key;
1421         gss_buffer_desc *encrypt_key = &kctx->skc_encrypt_key;
1422         enum cfs_crypto_hash_alg hmac_alg;
1423         enum cfs_crypto_crypt_alg crypt_alg;
1424         char *encrypt = "Encrypt";
1425         char *integrity = "Integrity";
1426         int rc;
1427
1428         hmac_alg = cfs_crypto_hash_alg(kctx->skc_hmac_alg);
1429         hmac_key->length = cfs_crypto_hash_digestsize(hmac_alg);
1430         hmac_key->value = malloc(hmac_key->length);
1431         if (!hmac_key->value)
1432                 return -ENOMEM;
1433
1434         rc = PKCS5_PBKDF2_HMAC(integrity, -1, session_key->value,
1435                                session_key->length, SK_PBKDF2_ITERATIONS,
1436                                sk_hash_to_evp_md(hmac_alg),
1437                                hmac_key->length, hmac_key->value);
1438         if (rc == 0)
1439                 return -EINVAL;
1440
1441         /* Encryption key is only populated in privacy mode */
1442         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1443                 return 0;
1444
1445         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1446         encrypt_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1447         encrypt_key->value = malloc(encrypt_key->length);
1448         if (!encrypt_key->value)
1449                 return -ENOMEM;
1450
1451         rc = PKCS5_PBKDF2_HMAC(encrypt, -1, session_key->value,
1452                                session_key->length, SK_PBKDF2_ITERATIONS,
1453                                sk_hash_to_evp_md(hmac_alg),
1454                                encrypt_key->length, encrypt_key->value);
1455         if (rc == 0)
1456                 return -EINVAL;
1457
1458         return 0;
1459 }
1460
1461 uint32_t __sk_compute_dh_key(struct sk_cred *skc,
1462                              const gss_buffer_desc *pub_key,
1463                              size_t *expected_len)
1464 {
1465         gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
1466         uint32_t rc = GSS_S_FAILURE;
1467 #ifdef HAVE_OPENSSL_EVP_PKEY
1468         EVP_PKEY_CTX *ctx = NULL;
1469         EVP_PKEY *peerkey = NULL;
1470
1471         peerkey = EVP_PKEY_new();
1472         if (!peerkey ||
1473             EVP_PKEY_copy_parameters(peerkey, skc->sc_params) != 1) {
1474                 printerr(0, "error: peerkey cannot be init\n");
1475                 goto out_err;
1476         }
1477
1478         if (EVP_PKEY_set1_encoded_public_key(peerkey,
1479                                              pub_key->value,
1480                                              pub_key->length) != 1) {
1481                 printerr(0, "error: peerkey cannot be set\n");
1482                 goto out_err;
1483         }
1484
1485         ctx = EVP_PKEY_CTX_new_from_pkey(NULL, skc->sc_params, NULL);
1486         if (!ctx) {
1487                 printerr(0, "error: ctx cannot be allocated\n");
1488                 goto out_err;
1489         }
1490
1491         if (EVP_PKEY_derive_init(ctx) != 1 ||
1492             EVP_PKEY_derive_set_peer(ctx, peerkey) != 1) {
1493                 printerr(0, "error: ctx cannot be init\n");
1494                 goto out_err;
1495         }
1496
1497         if (EVP_PKEY_derive(ctx, NULL, expected_len) != 1) {
1498                 printerr(0, "error: cannot get dh length\n");
1499                 goto out_err;
1500         }
1501
1502         dh_shared->length = *expected_len;
1503         dh_shared->value = malloc(*expected_len);
1504         if (!dh_shared->value) {
1505                 printerr(0, "error: cannot allocate memory for shared key\n");
1506                 goto out_err;
1507         }
1508
1509         if (EVP_PKEY_derive(ctx, dh_shared->value, &dh_shared->length) != 1) {
1510                 printerr(0, "error: cannot derive dh key\n");
1511                 ERR_print_errors_fp(stderr);
1512                 goto out_err;
1513         }
1514
1515         rc = GSS_S_COMPLETE;
1516 out_err:
1517         EVP_PKEY_CTX_free(ctx);
1518         EVP_PKEY_free(peerkey);
1519 #else /* !HAVE_OPENSSL_EVP_PKEY */
1520         BIGNUM *remote_pub_key;
1521
1522         remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
1523         if (!remote_pub_key) {
1524                 printerr(0, "Failed to convert binary to BIGNUM\n");
1525                 return rc;
1526         }
1527
1528         *expected_len = DH_size(skc->sc_params);
1529         dh_shared->length = *expected_len;
1530         dh_shared->value = malloc(*expected_len);
1531         if (!dh_shared->value) {
1532                 printerr(0,
1533                          "Failed to allocate memory for computed shared secret key\n");
1534                 goto out_err;
1535         }
1536
1537         /* This computes the shared key from the DHKE */
1538         dh_shared->length = DH_compute_key(dh_shared->value, remote_pub_key,
1539                                            skc->sc_params);
1540         if (dh_shared->length == -1) {
1541                 printerr(0, "DH key derivation failed: %s\n",
1542                          ERR_error_string(ERR_get_error(), NULL));
1543                 goto out_err;
1544         }
1545
1546         rc = GSS_S_COMPLETE;
1547 out_err:
1548         BN_free(remote_pub_key);
1549 #endif /* HAVE_OPENSSL_EVP_PKEY */
1550         return rc;
1551 }
1552
1553 /**
1554  * Computes a session key based on the DH parameters from the host and its peer
1555  *
1556  * \param[in,out]       skc             Shared key credentials structure with
1557  *                                      the session key populated with the
1558  *                                      compute key
1559  * \param[in]           pub_key         Public key returned from peer in
1560  *                                      gss_buffer_desc
1561  * \return      gss error               failure
1562  * \return      GSS_S_COMPLETE          success
1563  */
1564 uint32_t sk_compute_dh_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
1565 {
1566         size_t expected_len;
1567         uint32_t rc;
1568
1569         rc = __sk_compute_dh_key(skc, pub_key, &expected_len);
1570         if (rc != GSS_S_COMPLETE)
1571                 return rc;
1572
1573         if (skc->sc_dh_shared_key.length < expected_len) {
1574                 /* there is around 1 chance out of 256 that the returned
1575                  * shared key is shorter than expected
1576                  */
1577                 if (skc->sc_dh_shared_key.length >= expected_len - 2) {
1578                         int shift = expected_len - skc->sc_dh_shared_key.length;
1579
1580                         /* if the key is short by only 1 or 2 bytes, just
1581                          * prepend it with 0s
1582                          */
1583                         memmove((void *)(skc->sc_dh_shared_key.value + shift),
1584                                 skc->sc_dh_shared_key.value,
1585                                 skc->sc_dh_shared_key.length);
1586                         memset(skc->sc_dh_shared_key.value, 0, shift);
1587                 } else {
1588                         /* if the key is really too short, return GSS_S_BAD_QOP
1589                          * so that the caller can retry to generate
1590                          */
1591                         printerr(0,
1592                                  "DH derivation returned a short key of %zu bytes, expected: %zu\n",
1593                                  skc->sc_dh_shared_key.length, expected_len);
1594                         rc = GSS_S_BAD_QOP;
1595                 }
1596         }
1597         return rc;
1598 }
1599
1600 /**
1601  * Creates a serialized buffer for the kernel in the order of struct
1602  * sk_kernel_ctx.
1603  *
1604  * \param[in,out]       skc             Shared key credentials structure
1605  * \param[in,out]       ctx_token       Serialized buffer for kernel.
1606  *                                      Caller must free this buffer.
1607  *
1608  * \return      0       success
1609  * \return      -1      failure
1610  */
1611 int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
1612 {
1613         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1614         char *p, *end;
1615         size_t bufsize;
1616
1617         bufsize = sizeof(*kctx) + kctx->skc_hmac_key.length +
1618                   kctx->skc_encrypt_key.length;
1619
1620         ctx_token->value = malloc(bufsize);
1621         if (!ctx_token->value)
1622                 return -1;
1623         ctx_token->length = bufsize;
1624
1625         p = ctx_token->value;
1626         end = p + ctx_token->length;
1627
1628         if (WRITE_BYTES(&p, end, kctx->skc_version))
1629                 return -1;
1630         if (WRITE_BYTES(&p, end, kctx->skc_hmac_alg))
1631                 return -1;
1632         if (WRITE_BYTES(&p, end, kctx->skc_crypt_alg))
1633                 return -1;
1634         if (WRITE_BYTES(&p, end, kctx->skc_expire))
1635                 return -1;
1636         if (WRITE_BYTES(&p, end, kctx->skc_host_random))
1637                 return -1;
1638         if (WRITE_BYTES(&p, end, kctx->skc_peer_random))
1639                 return -1;
1640         if (write_buffer(&p, end, &kctx->skc_hmac_key))
1641                 return -1;
1642         if (write_buffer(&p, end, &kctx->skc_encrypt_key))
1643                 return -1;
1644
1645         printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
1646
1647         return 0;
1648 }
1649
1650 /**
1651  * Decodes a netstring \a ns into array of gss_buffer_descs at \a bufs
1652  * up to \a numbufs.  Memory is allocated for each value and length
1653  * will be populated with the length
1654  *
1655  * \param[in,out]       bufs    Array of gss_buffer_descs
1656  * \param[in,out]       numbufs number of gss_buffer_desc in array
1657  * \param[in]           ns      netstring to decode
1658  *
1659  * \return      buffers populated       success
1660  * \return      -1                      failure
1661  */
1662 int sk_decode_netstring(gss_buffer_desc *bufs, int numbufs, gss_buffer_desc *ns)
1663 {
1664         char *ptr = ns->value;
1665         size_t remain = ns->length;
1666         unsigned int size;
1667         int digits;
1668         int sep;
1669         int rc;
1670         int i;
1671
1672         for (i = 0; i < numbufs; i++) {
1673                 /* read the size of first buffer */
1674                 rc = sscanf(ptr, "%9u", &size);
1675                 if (rc < 1)
1676                         goto out_err;
1677                 digits = (size) ? ceil(log10(size + 1)) : 1;
1678
1679                 /* sep of current string */
1680                 sep = size + digits + 2;
1681
1682                 /* check to make sure it's valid */
1683                 if (remain < sep || ptr[digits] != ':' ||
1684                     ptr[sep - 1] != ',')
1685                         goto out_err;
1686
1687                 bufs[i].length = size;
1688                 if (size == 0) {
1689                         bufs[i].value = NULL;
1690                 } else {
1691                         bufs[i].value = malloc(size);
1692                         if (!bufs[i].value)
1693                                 goto out_err;
1694                         memcpy(bufs[i].value, &ptr[digits + 1], size);
1695                 }
1696
1697                 remain -= sep;
1698                 ptr += sep;
1699         }
1700
1701         printerr(2, "Decoded netstring of %zu bytes\n", ns->length);
1702         return i;
1703
1704 out_err:
1705         while (i-- > 0) {
1706                 if (bufs[i].value)
1707                         free(bufs[i].value);
1708                 bufs[i].length = 0;
1709         }
1710         return -1;
1711 }
1712
1713 /**
1714  * Creates a netstring in a gss_buffer_desc that consists of all
1715  * the gss_buffer_desc found in \a bufs.  The netstring should be treated
1716  * as binary as it can contain null characters.
1717  *
1718  * \param[in]           bufs            Array of gss_buffer_desc to use as input
1719  * \param[in]           numbufs         Number of buffers in array
1720  * \param[in,out]       ns              Destination gss_buffer_desc to hold
1721  *                                      netstring
1722  *
1723  * \return      -1      failure
1724  * \return      0       success
1725  */
1726 int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
1727                         gss_buffer_desc *ns)
1728 {
1729         unsigned char *ptr;
1730         int size = 0;
1731         int rc;
1732         int i;
1733
1734         /* size of string in decimal, string size, colon, and comma */
1735         for (i = 0; i < numbufs; i++) {
1736
1737                 if (bufs[i].length == 0)
1738                         size += 3;
1739                 else
1740                         size += ceil(log10(bufs[i].length + 1)) +
1741                                 bufs[i].length + 2;
1742         }
1743
1744         ns->length = size;
1745         ns->value = malloc(ns->length);
1746         if (!ns->value) {
1747                 ns->length = 0;
1748                 return -1;
1749         }
1750
1751         ptr = ns->value;
1752         for (i = 0; i < numbufs; i++) {
1753                 /* size */
1754                 rc = scnprintf((char *) ptr, size, "%zu:", bufs[i].length);
1755                 ptr += rc;
1756
1757                 /* contents */
1758                 memcpy(ptr, bufs[i].value, bufs[i].length);
1759                 ptr += bufs[i].length;
1760
1761                 /* delimeter */
1762                 *ptr++ = ',';
1763
1764                 size -= bufs[i].length + rc + 1;
1765
1766                 /* should not happen */
1767                 if (size < 0)
1768                         abort();
1769         }
1770
1771         printerr(2, "Encoded netstring of %zu bytes\n", ns->length);
1772         return 0;
1773 }