Whamcloud - gitweb
LU-8590 utils: fix minor issues in lgss_sk usage
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Author: Jeremy Filizetti <jfilizet@iu.edu>
26  */
27
28 #include <fcntl.h>
29 #include <limits.h>
30 #include <math.h>
31 #include <string.h>
32 #include <stdbool.h>
33 #include <unistd.h>
34 #include <openssl/dh.h>
35 #include <openssl/engine.h>
36 #include <openssl/err.h>
37 #include <openssl/hmac.h>
38 #include <sys/types.h>
39 #include <sys/stat.h>
40 #include <lnet/nidstr.h>
41
42 #include "sk_utils.h"
43 #include "write_bytes.h"
44
45 #define SK_PBKDF2_ITERATIONS 10000
46
47 static struct sk_crypt_type sk_crypt_types[] = {
48         [SK_CRYPT_AES256_CTR] = {
49                 .sct_name = "ctr(aes)",
50                 .sct_bytes = 32,
51         },
52 };
53
54 static struct sk_hmac_type sk_hmac_types[] = {
55         [SK_HMAC_SHA256] = {
56                 .sht_name = "hmac(sha256)",
57                 .sht_bytes = 32,
58         },
59         [SK_HMAC_SHA512] = {
60                 .sht_name = "hmac(sha512)",
61                 .sht_bytes = 64,
62         },
63 };
64
65 #ifdef _NEW_BUILD_
66 # include "lgss_utils.h"
67 #else
68 # include "gss_util.h"
69 # include "gss_oids.h"
70 # include "err_util.h"
71 #endif
72
73 #ifdef _ERR_UTIL_H_
74 /**
75  * Initializes logging
76  * \param[in]   program         Program name to output
77  * \param[in]   verbose         Verbose flag
78  * \param[in]   fg              Whether or not to run in foreground
79  *
80  */
81 void sk_init_logging(char *program, int verbose, int fg)
82 {
83         initerr(program, verbose, fg);
84 }
85 #endif
86
87 /**
88  * Loads the key from \a filename and returns the struct sk_keyfile_config.
89  * It should be freed by the caller.
90  *
91  * \param[in]   filename                Disk or key payload data
92  *
93  * \return      sk_keyfile_config       sucess
94  * \return      NULL                    failure
95  */
96 struct sk_keyfile_config *sk_read_file(char *filename)
97 {
98         struct sk_keyfile_config *config;
99         char *ptr;
100         size_t rc;
101         size_t remain;
102         int fd;
103
104         config = malloc(sizeof(*config));
105         if (!config) {
106                 printerr(0, "Failed to allocate memory for config\n");
107                 return NULL;
108         }
109
110         /* allow standard input override */
111         if (strcmp(filename, "-") == 0)
112                 fd = STDIN_FILENO;
113         else
114                 fd = open(filename, O_RDONLY);
115
116         if (fd == -1) {
117                 printerr(0, "Error opening key file '%s': %s\n", filename,
118                          strerror(errno));
119                 goto out_free;
120         } else if (fd != STDIN_FILENO) {
121                 struct stat st;
122
123                 rc = fstat(fd, &st);
124                 if (rc == 0 && (st.st_mode & ~0600))
125                         fprintf(stderr, "warning: "
126                                 "secret key '%s' has insecure file mode %#o\n",
127                                 filename, st.st_mode);
128         }
129
130         ptr = (char *)config;
131         remain = sizeof(*config);
132         while (remain > 0) {
133                 rc = read(fd, ptr, remain);
134                 if (rc == -1) {
135                         if (errno == EINTR)
136                                 continue;
137                         printerr(0, "read() failed on %s: %s\n", filename,
138                                  strerror(errno));
139                         goto out_close;
140                 } else if (rc == 0) {
141                         printerr(0, "File %s does not have a complete key\n",
142                                  filename);
143                         goto out_close;
144                 }
145                 ptr += rc;
146                 remain -= rc;
147         }
148
149         if (fd != STDIN_FILENO)
150                 close(fd);
151         sk_config_disk_to_cpu(config);
152         return config;
153
154 out_close:
155         close(fd);
156 out_free:
157         free(config);
158         return NULL;
159 }
160
161 /**
162  * Checks if a key matching \a description is found in the keyring for
163  * logging purposes and then attempts to load \a payload of \a psize into a key
164  * with \a description.
165  *
166  * \param[in]   payload         Key payload
167  * \param[in]   psize           Payload size
168  * \param[in]   description     Description used for key in keyring
169  *
170  * \return      0       sucess
171  * \return      -1      failure
172  */
173 static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
174                                 const char *description)
175 {
176         struct sk_keyfile_config payload;
177         key_serial_t key;
178
179         memcpy(&payload, skc, sizeof(*skc));
180
181         /* In the keyring use the disk layout so keyctl pipe can be used */
182         sk_config_cpu_to_disk(&payload);
183
184         /* Check to see if a key is already loaded matching description */
185         key = keyctl_search(KEY_SPEC_USER_KEYRING, "user", description, 0);
186         if (key != -1)
187                 printerr(2, "Key %d found in session keyring, replacing\n",
188                          key);
189
190         key = add_key("user", description, &payload, sizeof(payload),
191                       KEY_SPEC_USER_KEYRING);
192         if (key != -1)
193                 printerr(2, "Added key %d with description %s\n", key,
194                          description);
195         else
196                 printerr(0, "Failed to add key with %s\n", description);
197
198         return key;
199 }
200
201 /**
202  * Reads the key from \a path, verifies it and loads into the session keyring
203  * using a description determined by the the \a type.  Existing keys with the
204  * same description are replaced.
205  *
206  * \param[in]   path    Path to key file
207  * \param[in]   type    Type of key to load which determines the description
208  *
209  * \return      0       sucess
210  * \return      -1      failure
211  */
212 int sk_load_keyfile(char *path)
213 {
214         struct sk_keyfile_config *config;
215         char description[SK_DESCRIPTION_SIZE + 1];
216         struct stat buf;
217         int i;
218         int rc;
219         int rc2 = -1;
220
221         rc = stat(path, &buf);
222         if (rc == -1) {
223                 printerr(0, "stat() failed for file %s: %s\n", path,
224                          strerror(errno));
225                 return rc2;
226         }
227
228         config = sk_read_file(path);
229         if (!config)
230                 return rc2;
231
232         /* Similar to ssh, require adequate care of key files */
233         if (buf.st_mode & (S_IRGRP | S_IWGRP | S_IWOTH | S_IXOTH)) {
234                 printerr(0, "Shared key files must be read/writeable only by "
235                          "owner\n");
236                 return -1;
237         }
238
239         if (sk_validate_config(config))
240                 goto out;
241
242         /* The server side can have multiple key files per file system so
243          * the nodemap name is appended to the key description to uniquely
244          * identify it */
245         if (config->skc_type & SK_TYPE_MGS) {
246                 /* Any key can be an MGS key as long as we are told to use it */
247                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
248                               config->skc_nodemap);
249                 if (rc >= SK_DESCRIPTION_SIZE)
250                         goto out;
251                 if (sk_load_key(config, description) == -1)
252                         goto out;
253         }
254         if (config->skc_type & SK_TYPE_SERVER) {
255                 /* Server keys need to have the file system name in the key */
256                 if (!config->skc_fsname) {
257                         printerr(0, "Key configuration has no file system "
258                                  "attribute.  Can't load as server type\n");
259                         goto out;
260                 }
261                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s:%s",
262                               config->skc_fsname, config->skc_nodemap);
263                 if (rc >= SK_DESCRIPTION_SIZE)
264                         goto out;
265                 if (sk_load_key(config, description) == -1)
266                         goto out;
267         }
268         if (config->skc_type & SK_TYPE_CLIENT) {
269                 /* Load client file system key */
270                 if (config->skc_fsname) {
271                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
272                                       "lustre:%s", config->skc_fsname);
273                         if (rc >= SK_DESCRIPTION_SIZE)
274                                 goto out;
275                         if (sk_load_key(config, description) == -1)
276                                 goto out;
277                 }
278
279                 /* Load client MGC keys */
280                 for (i = 0; i < MAX_MGSNIDS; i++) {
281                         if (config->skc_mgsnids[i] == LNET_NID_ANY)
282                                 continue;
283                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
284                                       "lustre:MGC%s",
285                                       libcfs_nid2str(config->skc_mgsnids[i]));
286                         if (rc >= SK_DESCRIPTION_SIZE)
287                                 goto out;
288                         if (sk_load_key(config, description) == -1)
289                                 goto out;
290                 }
291         }
292
293         rc2 = 0;
294
295 out:
296         free(config);
297         return rc2;
298 }
299
300 /**
301  * Byte swaps config from cpu format to disk
302  *
303  * \param[in,out]       config          sk_keyfile_config to swap
304  */
305 void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
306 {
307         int i;
308
309         if (!config)
310                 return;
311
312         config->skc_version = htobe32(config->skc_version);
313         config->skc_hmac_alg = htobe16(config->skc_hmac_alg);
314         config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
315         config->skc_expire = htobe32(config->skc_expire);
316         config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
317         config->skc_prime_bits = htobe32(config->skc_prime_bits);
318
319         for (i = 0; i < MAX_MGSNIDS; i++)
320                 config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
321
322         return;
323 }
324
325 /**
326  * Byte swaps config from disk format to cpu
327  *
328  * \param[in,out]       config          sk_keyfile_config to swap
329  */
330 void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
331 {
332         int i;
333
334         if (!config)
335                 return;
336
337         config->skc_version = be32toh(config->skc_version);
338         config->skc_hmac_alg = be16toh(config->skc_hmac_alg);
339         config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
340         config->skc_expire = be32toh(config->skc_expire);
341         config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
342         config->skc_prime_bits = be32toh(config->skc_prime_bits);
343
344         for (i = 0; i < MAX_MGSNIDS; i++)
345                 config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
346
347         return;
348 }
349
350 /**
351  * Verifies the on key payload format is valid
352  *
353  * \param[in]   config          sk_keyfile_config
354  *
355  * \return      -1      failure
356  * \return      0       success
357  */
358 int sk_validate_config(const struct sk_keyfile_config *config)
359 {
360         int i;
361
362         if (!config) {
363                 printerr(0, "Null configuration passed\n");
364                 return -1;
365         }
366         if (config->skc_version != SK_CONF_VERSION) {
367                 printerr(0, "Invalid version\n");
368                 return -1;
369         }
370         if (config->skc_hmac_alg >= SK_HMAC_MAX) {
371                 printerr(0, "Invalid HMAC algorithm\n");
372                 return -1;
373         }
374         if (config->skc_crypt_alg >= SK_CRYPT_MAX) {
375                 printerr(0, "Invalid crypt algorithm\n");
376                 return -1;
377         }
378         if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
379                 /* Try to limit key expiration to some reasonable minimum and
380                  * also prevent values over INT_MAX because there appears
381                  * to be a type conversion issue */
382                 printerr(0, "Invalid expiration time should be between %d "
383                          "and %d\n", 60, INT_MAX);
384                 return -1;
385         }
386         if (config->skc_prime_bits % 8 != 0 ||
387             config->skc_prime_bits > SK_MAX_P_BYTES * 8) {
388                 printerr(0, "Invalid session key length must be a multiple of 8"
389                          " and less then %d bits\n",
390                          SK_MAX_P_BYTES * 8);
391                 return -1;
392         }
393         if (config->skc_shared_keylen % 8 != 0 ||
394             config->skc_shared_keylen > SK_MAX_KEYLEN_BYTES * 8){
395                 printerr(0, "Invalid shared key max length must be a multiple "
396                          "of 8 and less then %d bits\n",
397                          SK_MAX_KEYLEN_BYTES * 8);
398                 return -1;
399         }
400
401         /* Check for terminating nulls on strings */
402         for (i = 0; i < sizeof(config->skc_fsname) &&
403              config->skc_fsname[i] != '\0';  i++)
404                 ; /* empty loop */
405         if (i == sizeof(config->skc_fsname)) {
406                 printerr(0, "File system name not null terminated\n");
407                 return -1;
408         }
409
410         for (i = 0; i < sizeof(config->skc_nodemap) &&
411              config->skc_nodemap[i] != '\0';  i++)
412                 ; /* empty loop */
413         if (i == sizeof(config->skc_nodemap)) {
414                 printerr(0, "Nodemap name not null terminated\n");
415                 return -1;
416         }
417
418         if (config->skc_type == SK_TYPE_INVALID) {
419                 printerr(0, "Invalid key type\n");
420                 return -1;
421         }
422
423         return 0;
424 }
425
426 /**
427  * Hashes \a string and places the hash in \a hash
428  * at \a hash
429  *
430  * \param[in]           string          Null terminated string to hash
431  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
432  * \param[in,out]       hash            gss_buffer_desc to hold the result
433  *
434  * \return      -1      failure
435  * \return      0       success
436  */
437 static int sk_hash_string(const char *string, const EVP_MD *hash_alg,
438                           gss_buffer_desc *hash)
439 {
440         EVP_MD_CTX *ctx = EVP_MD_CTX_create();
441         size_t len = strlen(string);
442         unsigned int hashlen;
443
444         if (!hash->value || hash->length < EVP_MD_size(hash_alg))
445                 goto out_err;
446         if (!EVP_DigestInit_ex(ctx, hash_alg, NULL))
447                 goto out_err;
448         if (!EVP_DigestUpdate(ctx, string, len))
449                 goto out_err;
450         if (!EVP_DigestFinal_ex(ctx, hash->value, &hashlen))
451                 goto out_err;
452
453         EVP_MD_CTX_destroy(ctx);
454         hash->length = hashlen;
455         return 0;
456
457 out_err:
458         EVP_MD_CTX_destroy(ctx);
459         return -1;
460 }
461
462 /**
463  * Hashes \a string and verifies the resulting hash matches the value
464  * in \a current_hash
465  *
466  * \param[in]           string          Null terminated string to hash
467  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
468  * \param[in,out]       current_hash    gss_buffer_desc to compare to
469  *
470  * \return      gss error       failure
471  * \return      GSS_S_COMPLETE  success
472  */
473 uint32_t sk_verify_hash(const char *string, const EVP_MD *hash_alg,
474                         const gss_buffer_desc *current_hash)
475 {
476         gss_buffer_desc hash;
477         unsigned char hashbuf[EVP_MAX_MD_SIZE];
478
479         hash.value = hashbuf;
480         hash.length = sizeof(hashbuf);
481
482         if (sk_hash_string(string, hash_alg, &hash))
483                 return GSS_S_FAILURE;
484         if (current_hash->length != hash.length)
485                 return GSS_S_DEFECTIVE_TOKEN;
486         if (memcmp(current_hash->value, hash.value, hash.length))
487                 return GSS_S_BAD_SIG;
488
489         return GSS_S_COMPLETE;
490 }
491
492 static inline int sk_config_has_mgsnid(struct sk_keyfile_config *config,
493                                        const char *mgsnid)
494 {
495         lnet_nid_t nid;
496         int i;
497
498         nid = libcfs_str2nid(mgsnid);
499         if (nid == LNET_NID_ANY)
500                 return 0;
501
502         for (i = 0; i < MAX_MGSNIDS; i++)
503                 if  (config->skc_mgsnids[i] == nid)
504                         return 1;
505         return 0;
506 }
507
508 /**
509  * Create an sk_cred structure populated with initial configuration info and the
510  * key.  \a tgt and \a nodemap are used in determining the expected key
511  * description so the key can be found by searching the keyring.
512  * This is done because there is no easy way to pass keys from the mount command
513  * all the way to the request_key call.  In addition any keys can be dynamically
514  * added to the keyrings and still found.  The keyring that needs to be used
515  * must be the session keyring.
516  *
517  * \param[in]   tgt             Target file system
518  * \param[in]   nodemap         Cluster name for the key.  This correlates to
519  *                              the nodemap name and is used by the server side.
520  *                              For the client this will be NULL.
521  * \param[in]   flags           Flags for the credentials
522  *
523  * \return      sk_cred Allocated struct sk_cred on success
524  * \return      NULL    failure
525  */
526 struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
527                                const uint32_t flags)
528 {
529         struct sk_keyfile_config *config;
530         struct sk_kernel_ctx *kctx;
531         struct sk_cred *skc = NULL;
532         char description[SK_DESCRIPTION_SIZE + 1];
533         char fsname[MTI_NAME_MAXLEN + 1];
534         const char *mgsnid = NULL;
535         char *ptr;
536         long sk_key;
537         int keylen;
538         int len;
539         int rc;
540
541         printerr(2, "Creating credentials for target: %s with nodemap: %s\n",
542                  tgt, nodemap);
543
544         memset(description, 0, sizeof(description));
545         memset(fsname, 0, sizeof(fsname));
546
547         /* extract the file system name from target */
548         ptr = index(tgt, '-');
549         if (!ptr) {
550                 len = strlen(tgt);
551
552                 /* This must be an MGC target */
553                 if (strncmp(tgt, "MGC", 3) || len <= 3) {
554                         printerr(0, "Invalid target name\n");
555                         return NULL;
556                 }
557                 mgsnid = tgt + 3;
558         } else {
559                 len = ptr - tgt;
560         }
561
562         if (len > MTI_NAME_MAXLEN) {
563                 printerr(0, "Invalid target name\n");
564                 return NULL;
565         }
566         memcpy(fsname, tgt, len);
567
568         if (nodemap) {
569                 if (mgsnid)
570                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
571                                       "lustre:MGS:%s", nodemap);
572                 else
573                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
574                                       "lustre:%s:%s", fsname, nodemap);
575         } else {
576                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s",
577                               fsname);
578         }
579
580         if (rc >= SK_DESCRIPTION_SIZE) {
581                 printerr(0, "Invalid key description\n");
582                 return NULL;
583         }
584
585         /* It may be a good idea to move Lustre keys to the gss_keyring
586          * (lgssc) type so that they expire when Lustre modules are removed.
587          * Unfortunately it can't be done at mount time because the mount
588          * syscall could trigger the Lustre modules to load and until that
589          * point we don't have a lgssc key type.
590          *
591          * TODO: Query the community for a consensus here  */
592         printerr(2, "Searching for key with description: %s\n", description);
593         sk_key = keyctl_search(KEY_SPEC_USER_KEYRING, "user",
594                                description, 0);
595         if (sk_key == -1) {
596                 printerr(1, "No key found for %s\n", description);
597                 return NULL;
598         }
599
600         keylen = keyctl_read_alloc(sk_key, (void **)&config);
601         if (keylen == -1) {
602                 printerr(0, "keyctl_read() failed for key %ld: %s\n", sk_key,
603                          strerror(errno));
604                 return NULL;
605         } else if (keylen != sizeof(*config)) {
606                 printerr(0, "Unexpected key size: %d returned for key %ld, "
607                          "expected %zu bytes\n",
608                          keylen, sk_key, sizeof(*config));
609                 goto out_err;
610         }
611
612         sk_config_disk_to_cpu(config);
613
614         if (sk_validate_config(config)) {
615                 printerr(0, "Invalid key configuration for key: %ld\n", sk_key);
616                 goto out_err;
617         }
618
619         if (mgsnid && !sk_config_has_mgsnid(config, mgsnid)) {
620                 printerr(0, "Target name does not match key's MGS NIDs\n");
621                 goto out_err;
622         }
623
624         if (!mgsnid && strcmp(fsname, config->skc_fsname)) {
625                 printerr(0, "Target name does not match key's file system\n");
626                 goto out_err;
627         }
628
629         skc = malloc(sizeof(*skc));
630         if (!skc) {
631                 printerr(0, "Failed to allocate memory for sk_cred\n");
632                 goto out_err;
633         }
634
635         /* this initializes all gss_buffer_desc to empty as well */
636         memset(skc, 0, sizeof(*skc));
637
638         skc->sc_flags = flags;
639         skc->sc_tgt.length = strlen(tgt) + 1;
640         skc->sc_tgt.value = malloc(skc->sc_tgt.length);
641         if (!skc->sc_tgt.value) {
642                 printerr(0, "Failed to allocate memory for target\n");
643                 goto out_err;
644         }
645         memcpy(skc->sc_tgt.value, tgt, skc->sc_tgt.length);
646
647         skc->sc_nodemap_hash.length = EVP_MD_size(EVP_sha256());
648         skc->sc_nodemap_hash.value = malloc(skc->sc_nodemap_hash.length);
649         if (!skc->sc_nodemap_hash.value) {
650                 printerr(0, "Failed to allocate memory for nodemap hash\n");
651                 goto out_err;
652         }
653
654         if (sk_hash_string(config->skc_nodemap, EVP_sha256(),
655                            &skc->sc_nodemap_hash)) {
656                 printerr(0, "Failed to generate hash for nodemap name\n");
657                 goto out_err;
658         }
659
660         kctx = &skc->sc_kctx;
661         kctx->skc_version = config->skc_version;
662         kctx->skc_hmac_alg = config->skc_hmac_alg;
663         kctx->skc_crypt_alg = config->skc_crypt_alg;
664         kctx->skc_expire = config->skc_expire;
665
666         /* key payload format is in bits, convert to bytes */
667         kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
668         kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
669         if (!kctx->skc_shared_key.value) {
670                 printerr(0, "Failed to allocate memory for shared key\n");
671                 goto out_err;
672         }
673         memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
674                kctx->skc_shared_key.length);
675
676         skc->sc_p.length = config->skc_prime_bits / 8;
677         skc->sc_p.value = malloc(skc->sc_p.length);
678         if (!skc->sc_p.value) {
679                 printerr(0, "Failed to allocate p\n");
680                 goto out_err;
681         }
682         memcpy(skc->sc_p.value, config->skc_p, skc->sc_p.length);
683
684         free(config);
685
686         return skc;
687
688 out_err:
689         sk_free_cred(skc);
690
691         free(config);
692         return NULL;
693 }
694
695 /**
696  * Populates the DH parameters for the DHKE
697  *
698  * \param[in,out]       skc             Shared key credentials structure to
699  *                                      populate with DH parameters
700  *
701  * \retval      GSS_S_COMPLETE  success
702  * \retval      GSS_S_FAILURE   failure
703  */
704 uint32_t sk_gen_params(struct sk_cred *skc)
705 {
706         uint32_t random;
707         int rc;
708
709         /* Random value used by both the request and response as part of the
710          * key binding material.  This also should ensure we have unqiue
711          * tokens that are sent to the remote server which is important because
712          * the token is hashed for the sunrpc cache lookups and a failure there
713          * would cause connection attempts to fail indefinitely due to the large
714          * timeout value on the server side */
715         if (RAND_bytes((unsigned char *)&random, sizeof(random)) != 1) {
716                 printerr(0, "Failed to get data for random parameter: %s\n",
717                          ERR_error_string(ERR_get_error(), NULL));
718                 return GSS_S_FAILURE;
719         }
720
721         /* The random value will always be used in byte range operations
722          * so we keep it as big endian from this point on */
723         skc->sc_kctx.skc_host_random = random;
724
725         /* Populate DH parameters */
726         skc->sc_params = DH_new();
727         if (!skc->sc_params) {
728                 printerr(0, "Failed to allocate DH\n");
729                 return GSS_S_FAILURE;
730         }
731
732         skc->sc_params->p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
733         if (!skc->sc_params->p) {
734                 printerr(0, "Failed to convert binary to BIGNUM\n");
735                 return GSS_S_FAILURE;
736         }
737
738         /* We use a static generator for shared key */
739         skc->sc_params->g = BN_new();
740         if (!skc->sc_params->g) {
741                 printerr(0, "Failed to allocate new BIGNUM\n");
742                 return GSS_S_FAILURE;
743         }
744         if (BN_set_word(skc->sc_params->g, SK_GENERATOR) != 1) {
745                 printerr(0, "Failed to set g value for DH params\n");
746                 return GSS_S_FAILURE;
747         }
748
749         /* Verify that we have a safe prime and valid generator */
750         if (DH_check(skc->sc_params, &rc) != 1) {
751                 printerr(0, "DH_check() failed: %d\n", rc);
752                 return GSS_S_FAILURE;
753         } else if (rc) {
754                 printerr(0, "DH_check() returned error codes: 0x%x\n", rc);
755                 return GSS_S_FAILURE;
756         }
757
758         if (DH_generate_key(skc->sc_params) != 1) {
759                 printerr(0, "Failed to generate public DH key: %s\n",
760                          ERR_error_string(ERR_get_error(), NULL));
761                 return GSS_S_FAILURE;
762         }
763
764         skc->sc_pub_key.length = BN_num_bytes(skc->sc_params->pub_key);
765         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
766         if (!skc->sc_pub_key.value) {
767                 printerr(0, "Failed to allocate memory for public key\n");
768                 return GSS_S_FAILURE;
769         }
770
771         BN_bn2bin(skc->sc_params->pub_key, skc->sc_pub_key.value);
772
773         return GSS_S_COMPLETE;
774 }
775
776 /**
777  * Convert SK hash algorithm into openssl message digest
778  *
779  * \param[in,out]       alg             SK hash algorithm
780  *
781  * \retval              EVP_MD
782  */
783 static inline const EVP_MD *sk_hash_to_evp_md(enum sk_hmac_alg alg)
784 {
785         switch (alg) {
786         case SK_HMAC_SHA256:
787                 return EVP_sha256();
788         case SK_HMAC_SHA512:
789                 return EVP_sha512();
790         default:
791                 return EVP_md_null();
792         }
793 }
794
795 /**
796  * Signs (via HMAC) the parameters used only in the key initialization protocol.
797  *
798  * \param[in]           key             Key to use for HMAC
799  * \param[in]           bufs            Array of gss_buffer_desc to generate
800  *                                      HMAC for
801  * \param[in]           numbufs         Number of buffers in array
802  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
803  * \param[in,out]       hmac            HMAC of buffers is allocated and placed
804  *                                      in this gss_buffer_desc.  Caller must
805  *                                      free this.
806  *
807  * \retval      0       success
808  * \retval      -1      failure
809  */
810 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
811                  const EVP_MD *hash_alg, gss_buffer_desc *hmac)
812 {
813         HMAC_CTX hctx;
814         unsigned int hashlen = EVP_MD_size(hash_alg);
815         int i;
816         int rc = -1;
817
818         if (hash_alg == EVP_md_null()) {
819                 printerr(0, "Invalid hash algorithm\n");
820                 return -1;
821         }
822
823         HMAC_CTX_init(&hctx);
824
825         hmac->length = hashlen;
826         hmac->value = malloc(hashlen);
827         if (!hmac->value) {
828                 printerr(0, "Failed to allocate memory for HMAC\n");
829                 goto out;
830         }
831
832         if (HMAC_Init_ex(&hctx, key->value, key->length, hash_alg, NULL) != 1) {
833                 printerr(0, "Failed to init HMAC\n");
834                 goto out;
835         }
836
837         for (i = 0; i < numbufs; i++) {
838                 if (HMAC_Update(&hctx, bufs[i].value, bufs[i].length) != 1) {
839                         printerr(0, "Failed to update HMAC\n");
840                         goto out;
841                 }
842         }
843
844         /* The result gets populated in hmac */
845         if (HMAC_Final(&hctx, hmac->value, &hashlen) != 1) {
846                 printerr(0, "Failed to finalize HMAC\n");
847                 goto out;
848         }
849
850         if (hmac->length != hashlen) {
851                 printerr(0, "HMAC size does not match expected\n");
852                 goto out;
853         }
854
855         rc = 0;
856 out:
857         HMAC_CTX_cleanup(&hctx);
858         return rc;
859 }
860
861 /**
862  * Generates an HMAC for gss_buffer_desc array in \a bufs of \a numbufs
863  * and verifies against \a hmac.
864  *
865  * \param[in]   skc             Shared key credentials
866  * \param[in]   bufs            Array of gss_buffer_desc to generate HMAC for
867  * \param[in]   numbufs         Number of buffers in array
868  * \param[in]   hash_alg        OpenSSL EVP_MD to use for hash
869  * \param[in]   hmac            HMAC to verify against
870  *
871  * \retval      GSS_S_COMPLETE  success (match)
872  * \retval      gss error       failure
873  */
874 uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
875                         const int numbufs, const EVP_MD *hash_alg,
876                         gss_buffer_desc *hmac)
877 {
878         gss_buffer_desc bufs_hmac;
879         int rc;
880
881         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, numbufs, hash_alg,
882                          &bufs_hmac)) {
883                 printerr(0, "Failed to sign buffers to verify HMAC\n");
884                 if (bufs_hmac.value)
885                         free(bufs_hmac.value);
886                 return GSS_S_FAILURE;
887         }
888
889         if (hmac->length != bufs_hmac.length) {
890                 printerr(0, "Invalid HMAC size\n");
891                 free(bufs_hmac.value);
892                 return GSS_S_BAD_SIG;
893         }
894
895         rc = memcmp(hmac->value, bufs_hmac.value, bufs_hmac.length);
896         free(bufs_hmac.value);
897
898         if (rc)
899                 return GSS_S_BAD_SIG;
900
901         return GSS_S_COMPLETE;
902 }
903
904 /**
905  * Cleanup an sk_cred freeing any resources
906  *
907  * \param[in,out]       skc     Shared key credentials to free
908  */
909 void sk_free_cred(struct sk_cred *skc)
910 {
911         if (!skc)
912                 return;
913
914         if (skc->sc_p.value)
915                 free(skc->sc_p.value);
916         if (skc->sc_pub_key.value)
917                 free(skc->sc_pub_key.value);
918         if (skc->sc_tgt.value)
919                 free(skc->sc_tgt.value);
920         if (skc->sc_nodemap_hash.value)
921                 free(skc->sc_nodemap_hash.value);
922         if (skc->sc_hmac.value)
923                 free(skc->sc_hmac.value);
924
925         /* Overwrite keys and IV before freeing */
926         if (skc->sc_dh_shared_key.value) {
927                 memset(skc->sc_dh_shared_key.value, 0,
928                        skc->sc_dh_shared_key.length);
929                 free(skc->sc_dh_shared_key.value);
930         }
931         if (skc->sc_kctx.skc_hmac_key.value) {
932                 memset(skc->sc_kctx.skc_hmac_key.value, 0,
933                        skc->sc_kctx.skc_hmac_key.length);
934                 free(skc->sc_kctx.skc_hmac_key.value);
935         }
936         if (skc->sc_kctx.skc_encrypt_key.value) {
937                 memset(skc->sc_kctx.skc_encrypt_key.value, 0,
938                        skc->sc_kctx.skc_encrypt_key.length);
939                 free(skc->sc_kctx.skc_encrypt_key.value);
940         }
941         if (skc->sc_kctx.skc_shared_key.value) {
942                 memset(skc->sc_kctx.skc_shared_key.value, 0,
943                        skc->sc_kctx.skc_shared_key.length);
944                 free(skc->sc_kctx.skc_shared_key.value);
945         }
946         if (skc->sc_kctx.skc_session_key.value) {
947                 memset(skc->sc_kctx.skc_session_key.value, 0,
948                        skc->sc_kctx.skc_session_key.length);
949                 free(skc->sc_kctx.skc_session_key.value);
950         }
951
952         if (skc->sc_params)
953                 DH_free(skc->sc_params);
954
955         free(skc);
956         skc = NULL;
957 }
958
959 /* This function handles key derivation using the hash algorithm specified in
960  * \a hash_alg, buffers in \a key_binding_bufs, and original key in
961  * \a origin_key to produce a \a derived_key.  The first element of the
962  * key_binding_bufs array is reserved for the counter used in the KDF.  The
963  * derived key in \a derived_key could differ in size from \a origin_key and
964  * must be populated with the expected size and a valid buffer to hold the
965  * contents.
966  *
967  * If the derived key size is greater than the HMAC algorithm size it will be
968  * a done using several iterations of a counter and the key binding bufs.
969  *
970  * If the size is smaller it will take copy the first N bytes necessary to
971  * fill the derived key. */
972 int sk_kdf(gss_buffer_desc *derived_key , gss_buffer_desc *origin_key,
973            gss_buffer_desc *key_binding_bufs, int numbufs, int hmac_alg)
974 {
975         size_t remain;
976         size_t bytes;
977         uint32_t counter;
978         char *keydata;
979         gss_buffer_desc tmp_hash;
980         int i;
981         int rc;
982
983         if (numbufs < 1)
984                 return -EINVAL;
985
986         /* Use a counter as the first buffer followed by the key binding
987          * buffers in the event we need more than one a single cycle to
988          * produced a symmetric key large enough in size */
989         key_binding_bufs[0].value = &counter;
990         key_binding_bufs[0].length = sizeof(counter);
991
992         remain = derived_key->length;
993         keydata = derived_key->value;
994         i = 0;
995         while (remain > 0) {
996                 counter = htobe32(i++);
997                 rc = sk_sign_bufs(origin_key, key_binding_bufs, numbufs,
998                                   sk_hash_to_evp_md(hmac_alg), &tmp_hash);
999                 if (rc) {
1000                         if (tmp_hash.value)
1001                                 free(tmp_hash.value);
1002                         return rc;
1003                 }
1004
1005                 LASSERT(sk_hmac_types[hmac_alg].sht_bytes ==
1006                         tmp_hash.length);
1007
1008                 bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
1009                 memcpy(keydata, tmp_hash.value, bytes);
1010                 free(tmp_hash.value);
1011                 remain -= bytes;
1012                 keydata += bytes;
1013         }
1014
1015         return 0;
1016 }
1017
1018 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
1019  * based on the recommendations in NIST Special Publication SP 800-56B Rev 1
1020  * (Sep 2014) Section 5.5.1
1021  *
1022  * \param[in,out]       skc             Shared key credentials structure with
1023  *
1024  * \return      -1              failure
1025  * \return      0               success
1026  */
1027 int sk_session_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
1028                    gss_buffer_desc *client_token, gss_buffer_desc *server_token)
1029 {
1030         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1031         gss_buffer_desc *session_key = &kctx->skc_session_key;
1032         gss_buffer_desc bufs[5];
1033         int rc = -1;
1034
1035         session_key->length = sk_crypt_types[kctx->skc_crypt_alg].sct_bytes;
1036         session_key->value = malloc(session_key->length);
1037         if (!session_key->value) {
1038                 printerr(0, "Failed to allocate memory for session key\n");
1039                 return rc;
1040         }
1041
1042         /* Key binding info ordering
1043          * 1. Reserved for counter
1044          * 1. DH shared key
1045          * 2. Client's NIDs
1046          * 3. Client's token
1047          * 4. Server's token */
1048         bufs[0].value = NULL;
1049         bufs[0].length = 0;
1050         bufs[1] = skc->sc_dh_shared_key;
1051         bufs[2].value = &client_nid;
1052         bufs[2].length = sizeof(client_nid);
1053         bufs[3] = *client_token;
1054         bufs[4] = *server_token;
1055
1056         return sk_kdf(&kctx->skc_session_key, &kctx->skc_shared_key, bufs,
1057                       5, kctx->skc_hmac_alg);
1058 }
1059
1060 /* Uses the session key to create an HMAC key and encryption key.  In
1061  * integrity mode the session key used to generate the HMAC key uses
1062  * session information which is available on the wire but by creating
1063  * a session based HMAC key we can prevent potential replay as both the
1064  * client and server have random numbers used as part of the key creation.
1065  *
1066  * The keys used for integrity and privacy are formulated as below using
1067  * the session key that is the output of the key derivation function.  The
1068  * HMAC algorithm is determined by the shared key algorithm selected in the
1069  * key file.
1070  *
1071  * For ski mode:
1072  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1073  *
1074  * For skpi mode:
1075  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1076  * Session Encryption Key = PBKDF2("Encrypt", KDF derived Session Key)
1077  *
1078  * \param[in,out]       skc             Shared key credentials structure with
1079  *
1080  * \return      -1              failure
1081  * \return      0               success
1082  */
1083 int sk_compute_keys(struct sk_cred *skc)
1084 {
1085         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1086         gss_buffer_desc *session_key = &kctx->skc_session_key;
1087         gss_buffer_desc *hmac_key = &kctx->skc_hmac_key;
1088         gss_buffer_desc *encrypt_key = &kctx->skc_encrypt_key;
1089         char *encrypt = "Encrypt";
1090         char *integrity = "Integrity";
1091         int rc;
1092
1093         hmac_key->length = sk_hmac_types[kctx->skc_hmac_alg].sht_bytes;
1094         hmac_key->value = malloc(hmac_key->length);
1095         if (!hmac_key->value)
1096                 return -ENOMEM;
1097
1098         rc = PKCS5_PBKDF2_HMAC(integrity, -1, session_key->value,
1099                                session_key->length, SK_PBKDF2_ITERATIONS,
1100                                sk_hash_to_evp_md(kctx->skc_hmac_alg),
1101                                hmac_key->length, hmac_key->value);
1102         if (rc == 0)
1103                 return -EINVAL;
1104
1105         /* Encryption key is only populated in privacy mode */
1106         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1107                 return 0;
1108
1109         encrypt_key->length = sk_crypt_types[kctx->skc_crypt_alg].sct_bytes;
1110         encrypt_key->value = malloc(encrypt_key->length);
1111         if (!encrypt_key->value)
1112                 return -ENOMEM;
1113
1114         rc = PKCS5_PBKDF2_HMAC(encrypt, -1, session_key->value,
1115                                session_key->length, SK_PBKDF2_ITERATIONS,
1116                                sk_hash_to_evp_md(kctx->skc_hmac_alg),
1117                                encrypt_key->length, encrypt_key->value);
1118         if (rc == 0)
1119                 return -EINVAL;
1120
1121         return 0;
1122 }
1123
1124 /**
1125  * Computes a session key based on the DH parameters from the host and its peer
1126  *
1127  * \param[in,out]       skc             Shared key credentials structure with
1128  *                                      the session key populated with the
1129  *                                      compute key
1130  * \param[in]           pub_key         Public key returned from peer in
1131  *                                      gss_buffer_desc
1132  * \return      gss error               failure
1133  * \return      GSS_S_COMPLETE          success
1134  */
1135 uint32_t sk_compute_dh_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
1136 {
1137         gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
1138         BIGNUM *remote_pub_key;
1139         int status;
1140         uint32_t rc = GSS_S_FAILURE;
1141
1142         remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
1143         if (!remote_pub_key) {
1144                 printerr(0, "Failed to convert binary to BIGNUM\n");
1145                 return rc;
1146         }
1147
1148         dh_shared->length = DH_size(skc->sc_params);
1149         dh_shared->value = malloc(dh_shared->length);
1150         if (!dh_shared->value) {
1151                 printerr(0, "Failed to allocate memory for computed shared "
1152                          "secret key\n");
1153                 goto out_err;
1154         }
1155
1156         /* This compute the shared key from the DHKE */
1157         status = DH_compute_key(dh_shared->value, remote_pub_key,
1158                                 skc->sc_params);
1159         if (status == -1) {
1160                 printerr(0, "DH_compute_key() failed: %s\n",
1161                          ERR_error_string(ERR_get_error(), NULL));
1162                 goto out_err;
1163         } else if (status < dh_shared->length) {
1164                 printerr(0, "DH_compute_key() returned a short key of %d "
1165                          "bytes, expected: %zu\n", status, dh_shared->length);
1166                 rc = GSS_S_DEFECTIVE_TOKEN;
1167                 goto out_err;
1168         }
1169
1170         rc = GSS_S_COMPLETE;
1171
1172 out_err:
1173         BN_free(remote_pub_key);
1174         return rc;
1175 }
1176
1177 /**
1178  * Creates a serialized buffer for the kernel in the order of struct
1179  * sk_kernel_ctx.
1180  *
1181  * \param[in,out]       skc             Shared key credentials structure
1182  * \param[in,out]       ctx_token       Serialized buffer for kernel.
1183  *                                      Caller must free this buffer.
1184  *
1185  * \return      0       success
1186  * \return      -1      failure
1187  */
1188 int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
1189 {
1190         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1191         char *p, *end;
1192         size_t bufsize;
1193
1194         bufsize = sizeof(*kctx) + kctx->skc_hmac_key.length +
1195                   kctx->skc_encrypt_key.length;
1196
1197         ctx_token->value = malloc(bufsize);
1198         if (!ctx_token->value)
1199                 return -1;
1200         ctx_token->length = bufsize;
1201
1202         p = ctx_token->value;
1203         end = p + ctx_token->length;
1204
1205         if (WRITE_BYTES(&p, end, kctx->skc_version))
1206                 return -1;
1207         if (WRITE_BYTES(&p, end, kctx->skc_hmac_alg))
1208                 return -1;
1209         if (WRITE_BYTES(&p, end, kctx->skc_crypt_alg))
1210                 return -1;
1211         if (WRITE_BYTES(&p, end, kctx->skc_expire))
1212                 return -1;
1213         if (WRITE_BYTES(&p, end, kctx->skc_host_random))
1214                 return -1;
1215         if (WRITE_BYTES(&p, end, kctx->skc_peer_random))
1216                 return -1;
1217         if (write_buffer(&p, end, &kctx->skc_hmac_key))
1218                 return -1;
1219         if (write_buffer(&p, end, &kctx->skc_encrypt_key))
1220                 return -1;
1221
1222         printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
1223
1224         return 0;
1225 }
1226
1227 /**
1228  * Decodes a netstring \a ns into array of gss_buffer_descs at \a bufs
1229  * up to \a numbufs.  Memory is allocated for each value and length
1230  * will be populated with the length
1231  *
1232  * \param[in,out]       bufs    Array of gss_buffer_descs
1233  * \param[in,out]       numbufs number of gss_buffer_desc in array
1234  * \param[in]           ns      netstring to decode
1235  *
1236  * \return      buffers populated       success
1237  * \return      -1                      failure
1238  */
1239 int sk_decode_netstring(gss_buffer_desc *bufs, int numbufs, gss_buffer_desc *ns)
1240 {
1241         char *ptr = ns->value;
1242         size_t remain = ns->length;
1243         unsigned int size;
1244         int digits;
1245         int sep;
1246         int rc;
1247         int i;
1248
1249         for (i = 0; i < numbufs; i++) {
1250                 /* read the size of first buffer */
1251                 rc = sscanf(ptr, "%9u", &size);
1252                 if (rc < 1)
1253                         goto out_err;
1254                 digits = (size) ? ceil(log10(size + 1)) : 1;
1255
1256                 /* sep of current string */
1257                 sep = size + digits + 2;
1258
1259                 /* check to make sure it's valid */
1260                 if (remain < sep || ptr[digits] != ':' ||
1261                     ptr[sep - 1] != ',')
1262                         goto out_err;
1263
1264                 bufs[i].length = size;
1265                 if (size == 0) {
1266                         bufs[i].value = NULL;
1267                 } else {
1268                         bufs[i].value = malloc(size);
1269                         if (!bufs[i].value)
1270                                 goto out_err;
1271                         memcpy(bufs[i].value, &ptr[digits + 1], size);
1272                 }
1273
1274                 remain -= sep;
1275                 ptr += sep;
1276         }
1277
1278         printerr(2, "Decoded netstring of %zu bytes\n", ns->length);
1279         return i;
1280
1281 out_err:
1282         while (i-- > 0) {
1283                 if (bufs[i].value)
1284                         free(bufs[i].value);
1285                 bufs[i].length = 0;
1286         }
1287         return -1;
1288 }
1289
1290 /**
1291  * Creates a netstring in a gss_buffer_desc that consists of all
1292  * the gss_buffer_desc found in \a bufs.  The netstring should be treated
1293  * as binary as it can contain null characters.
1294  *
1295  * \param[in]           bufs            Array of gss_buffer_desc to use as input
1296  * \param[in]           numbufs         Number of buffers in array
1297  * \param[in,out]       ns              Destination gss_buffer_desc to hold
1298  *                                      netstring
1299  *
1300  * \return      -1      failure
1301  * \return      0       success
1302  */
1303 int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
1304                         gss_buffer_desc *ns)
1305 {
1306         unsigned char *ptr;
1307         int size = 0;
1308         int rc;
1309         int i;
1310
1311         /* size of string in decimal, string size, colon, and comma */
1312         for (i = 0; i < numbufs; i++) {
1313
1314                 if (bufs[i].length == 0)
1315                         size += 3;
1316                 else
1317                         size += ceil(log10(bufs[i].length + 1)) +
1318                                 bufs[i].length + 2;
1319         }
1320
1321         ns->length = size;
1322         ns->value = malloc(ns->length);
1323         if (!ns->value) {
1324                 ns->length = 0;
1325                 return -1;
1326         }
1327
1328         ptr = ns->value;
1329         for (i = 0; i < numbufs; i++) {
1330                 /* size */
1331                 rc = snprintf((char *) ptr, size, "%zu:", bufs[i].length);
1332                 ptr += rc;
1333
1334                 /* contents */
1335                 memcpy(ptr, bufs[i].value, bufs[i].length);
1336                 ptr += bufs[i].length;
1337
1338                 /* delimeter */
1339                 *ptr++ = ',';
1340
1341                 size -= bufs[i].length + rc + 1;
1342
1343                 /* should not happen */
1344                 if (size < 0)
1345                         abort();
1346         }
1347
1348         printerr(2, "Encoded netstring of %zu bytes\n", ns->length);
1349         return 0;
1350 }