Whamcloud - gitweb
LU-8602 gss: Support GSS on linux 4.6+ kernels
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Copyright (c) 2016, Intel Corporation.
26  *
27  * Author: Jeremy Filizetti <jfilizet@iu.edu>
28  */
29
30 #include <fcntl.h>
31 #include <limits.h>
32 #include <math.h>
33 #include <string.h>
34 #include <stdbool.h>
35 #include <unistd.h>
36 #include <openssl/dh.h>
37 #include <openssl/engine.h>
38 #include <openssl/err.h>
39 #include <openssl/hmac.h>
40 #include <sys/types.h>
41 #include <sys/stat.h>
42 #include <lnet/nidstr.h>
43
44 #include "sk_utils.h"
45 #include "write_bytes.h"
46
47 #define SK_PBKDF2_ITERATIONS 10000
48
49 static struct sk_crypt_type sk_crypt_types[] = {
50         [SK_CRYPT_AES256_CTR] = {
51                 .cht_name = "ctr(aes)",
52                 .cht_bytes = 32,
53         },
54 };
55
56 /*
57 static struct sk_hmac_type sk_hmac_types[] = {
58         [SK_HMAC_SHA256] = {
59                 .cht_name = "sha256",
60                 .cht_bytes = 32,
61         },
62         [SK_HMAC_SHA512] = {
63                 .cht_name = "sha512",
64                 .cht_bytes = 64,
65         },
66 };*/
67
68 #ifdef _NEW_BUILD_
69 # include "lgss_utils.h"
70 #else
71 # include "gss_util.h"
72 # include "gss_oids.h"
73 # include "err_util.h"
74 #endif
75
76 #ifdef _ERR_UTIL_H_
77 /**
78  * Initializes logging
79  * \param[in]   program         Program name to output
80  * \param[in]   verbose         Verbose flag
81  * \param[in]   fg              Whether or not to run in foreground
82  *
83  */
84 void sk_init_logging(char *program, int verbose, int fg)
85 {
86         initerr(program, verbose, fg);
87 }
88 #endif
89
90 /**
91  * Loads the key from \a filename and returns the struct sk_keyfile_config.
92  * It should be freed by the caller.
93  *
94  * \param[in]   filename                Disk or key payload data
95  *
96  * \return      sk_keyfile_config       sucess
97  * \return      NULL                    failure
98  */
99 struct sk_keyfile_config *sk_read_file(char *filename)
100 {
101         struct sk_keyfile_config *config;
102         char *ptr;
103         size_t rc;
104         size_t remain;
105         int fd;
106
107         config = malloc(sizeof(*config));
108         if (!config) {
109                 printerr(0, "Failed to allocate memory for config\n");
110                 return NULL;
111         }
112
113         /* allow standard input override */
114         if (strcmp(filename, "-") == 0)
115                 fd = STDIN_FILENO;
116         else
117                 fd = open(filename, O_RDONLY);
118
119         if (fd == -1) {
120                 printerr(0, "Error opening key file '%s': %s\n", filename,
121                          strerror(errno));
122                 goto out_free;
123         } else if (fd != STDIN_FILENO) {
124                 struct stat st;
125
126                 rc = fstat(fd, &st);
127                 if (rc == 0 && (st.st_mode & ~0600))
128                         fprintf(stderr, "warning: "
129                                 "secret key '%s' has insecure file mode %#o\n",
130                                 filename, st.st_mode);
131         }
132
133         ptr = (char *)config;
134         remain = sizeof(*config);
135         while (remain > 0) {
136                 rc = read(fd, ptr, remain);
137                 if (rc == -1) {
138                         if (errno == EINTR)
139                                 continue;
140                         printerr(0, "read() failed on %s: %s\n", filename,
141                                  strerror(errno));
142                         goto out_close;
143                 } else if (rc == 0) {
144                         printerr(0, "File %s does not have a complete key\n",
145                                  filename);
146                         goto out_close;
147                 }
148                 ptr += rc;
149                 remain -= rc;
150         }
151
152         if (fd != STDIN_FILENO)
153                 close(fd);
154         sk_config_disk_to_cpu(config);
155         return config;
156
157 out_close:
158         close(fd);
159 out_free:
160         free(config);
161         return NULL;
162 }
163
164 /**
165  * Checks if a key matching \a description is found in the keyring for
166  * logging purposes and then attempts to load \a payload of \a psize into a key
167  * with \a description.
168  *
169  * \param[in]   payload         Key payload
170  * \param[in]   psize           Payload size
171  * \param[in]   description     Description used for key in keyring
172  *
173  * \return      0       sucess
174  * \return      -1      failure
175  */
176 static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
177                                 const char *description)
178 {
179         struct sk_keyfile_config payload;
180         key_serial_t key;
181
182         memcpy(&payload, skc, sizeof(*skc));
183
184         /* In the keyring use the disk layout so keyctl pipe can be used */
185         sk_config_cpu_to_disk(&payload);
186
187         /* Check to see if a key is already loaded matching description */
188         key = keyctl_search(KEY_SPEC_USER_KEYRING, "user", description, 0);
189         if (key != -1)
190                 printerr(2, "Key %d found in session keyring, replacing\n",
191                          key);
192
193         key = add_key("user", description, &payload, sizeof(payload),
194                       KEY_SPEC_USER_KEYRING);
195         if (key != -1)
196                 printerr(2, "Added key %d with description %s\n", key,
197                          description);
198         else
199                 printerr(0, "Failed to add key with %s\n", description);
200
201         return key;
202 }
203
204 /**
205  * Reads the key from \a path, verifies it and loads into the session keyring
206  * using a description determined by the the \a type.  Existing keys with the
207  * same description are replaced.
208  *
209  * \param[in]   path    Path to key file
210  * \param[in]   type    Type of key to load which determines the description
211  *
212  * \return      0       sucess
213  * \return      -1      failure
214  */
215 int sk_load_keyfile(char *path)
216 {
217         struct sk_keyfile_config *config;
218         char description[SK_DESCRIPTION_SIZE + 1];
219         struct stat buf;
220         int i;
221         int rc;
222         int rc2 = -1;
223
224         rc = stat(path, &buf);
225         if (rc == -1) {
226                 printerr(0, "stat() failed for file %s: %s\n", path,
227                          strerror(errno));
228                 return rc2;
229         }
230
231         config = sk_read_file(path);
232         if (!config)
233                 return rc2;
234
235         /* Similar to ssh, require adequate care of key files */
236         if (buf.st_mode & (S_IRGRP | S_IWGRP | S_IWOTH | S_IXOTH)) {
237                 printerr(0, "Shared key files must be read/writeable only by "
238                          "owner\n");
239                 return -1;
240         }
241
242         if (sk_validate_config(config))
243                 goto out;
244
245         /* The server side can have multiple key files per file system so
246          * the nodemap name is appended to the key description to uniquely
247          * identify it */
248         if (config->skc_type & SK_TYPE_MGS) {
249                 /* Any key can be an MGS key as long as we are told to use it */
250                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
251                               config->skc_nodemap);
252                 if (rc >= SK_DESCRIPTION_SIZE)
253                         goto out;
254                 if (sk_load_key(config, description) == -1)
255                         goto out;
256         }
257         if (config->skc_type & SK_TYPE_SERVER) {
258                 /* Server keys need to have the file system name in the key */
259                 if (!config->skc_fsname) {
260                         printerr(0, "Key configuration has no file system "
261                                  "attribute.  Can't load as server type\n");
262                         goto out;
263                 }
264                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s:%s",
265                               config->skc_fsname, config->skc_nodemap);
266                 if (rc >= SK_DESCRIPTION_SIZE)
267                         goto out;
268                 if (sk_load_key(config, description) == -1)
269                         goto out;
270         }
271         if (config->skc_type & SK_TYPE_CLIENT) {
272                 /* Load client file system key */
273                 if (config->skc_fsname) {
274                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
275                                       "lustre:%s", config->skc_fsname);
276                         if (rc >= SK_DESCRIPTION_SIZE)
277                                 goto out;
278                         if (sk_load_key(config, description) == -1)
279                                 goto out;
280                 }
281
282                 /* Load client MGC keys */
283                 for (i = 0; i < MAX_MGSNIDS; i++) {
284                         if (config->skc_mgsnids[i] == LNET_NID_ANY)
285                                 continue;
286                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
287                                       "lustre:MGC%s",
288                                       libcfs_nid2str(config->skc_mgsnids[i]));
289                         if (rc >= SK_DESCRIPTION_SIZE)
290                                 goto out;
291                         if (sk_load_key(config, description) == -1)
292                                 goto out;
293                 }
294         }
295
296         rc2 = 0;
297
298 out:
299         free(config);
300         return rc2;
301 }
302
303 /**
304  * Byte swaps config from cpu format to disk
305  *
306  * \param[in,out]       config          sk_keyfile_config to swap
307  */
308 void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
309 {
310         int i;
311
312         if (!config)
313                 return;
314
315         config->skc_version = htobe32(config->skc_version);
316         config->skc_hmac_alg = htobe16(config->skc_hmac_alg);
317         config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
318         config->skc_expire = htobe32(config->skc_expire);
319         config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
320         config->skc_prime_bits = htobe32(config->skc_prime_bits);
321
322         for (i = 0; i < MAX_MGSNIDS; i++)
323                 config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
324
325         return;
326 }
327
328 /**
329  * Byte swaps config from disk format to cpu
330  *
331  * \param[in,out]       config          sk_keyfile_config to swap
332  */
333 void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
334 {
335         int i;
336
337         if (!config)
338                 return;
339
340         config->skc_version = be32toh(config->skc_version);
341         config->skc_hmac_alg = be16toh(config->skc_hmac_alg);
342         config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
343         config->skc_expire = be32toh(config->skc_expire);
344         config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
345         config->skc_prime_bits = be32toh(config->skc_prime_bits);
346
347         for (i = 0; i < MAX_MGSNIDS; i++)
348                 config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
349
350         return;
351 }
352
353 /**
354  * Verifies the on key payload format is valid
355  *
356  * \param[in]   config          sk_keyfile_config
357  *
358  * \return      -1      failure
359  * \return      0       success
360  */
361 int sk_validate_config(const struct sk_keyfile_config *config)
362 {
363         int i;
364
365         if (!config) {
366                 printerr(0, "Null configuration passed\n");
367                 return -1;
368         }
369         if (config->skc_version != SK_CONF_VERSION) {
370                 printerr(0, "Invalid version\n");
371                 return -1;
372         }
373         if ((config->skc_hmac_alg != CFS_HASH_ALG_SHA256) &&
374             (config->skc_hmac_alg != CFS_HASH_ALG_SHA512)) {
375                 printerr(0, "Invalid HMAC algorithm\n");
376                 return -1;
377         }
378         if (config->skc_crypt_alg >= SK_CRYPT_MAX) {
379                 printerr(0, "Invalid crypt algorithm\n");
380                 return -1;
381         }
382         if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
383                 /* Try to limit key expiration to some reasonable minimum and
384                  * also prevent values over INT_MAX because there appears
385                  * to be a type conversion issue */
386                 printerr(0, "Invalid expiration time should be between %d "
387                          "and %d\n", 60, INT_MAX);
388                 return -1;
389         }
390         if (config->skc_prime_bits % 8 != 0 ||
391             config->skc_prime_bits > SK_MAX_P_BYTES * 8) {
392                 printerr(0, "Invalid session key length must be a multiple of 8"
393                          " and less then %d bits\n",
394                          SK_MAX_P_BYTES * 8);
395                 return -1;
396         }
397         if (config->skc_shared_keylen % 8 != 0 ||
398             config->skc_shared_keylen > SK_MAX_KEYLEN_BYTES * 8){
399                 printerr(0, "Invalid shared key max length must be a multiple "
400                          "of 8 and less then %d bits\n",
401                          SK_MAX_KEYLEN_BYTES * 8);
402                 return -1;
403         }
404
405         /* Check for terminating nulls on strings */
406         for (i = 0; i < sizeof(config->skc_fsname) &&
407              config->skc_fsname[i] != '\0';  i++)
408                 ; /* empty loop */
409         if (i == sizeof(config->skc_fsname)) {
410                 printerr(0, "File system name not null terminated\n");
411                 return -1;
412         }
413
414         for (i = 0; i < sizeof(config->skc_nodemap) &&
415              config->skc_nodemap[i] != '\0';  i++)
416                 ; /* empty loop */
417         if (i == sizeof(config->skc_nodemap)) {
418                 printerr(0, "Nodemap name not null terminated\n");
419                 return -1;
420         }
421
422         if (config->skc_type == SK_TYPE_INVALID) {
423                 printerr(0, "Invalid key type\n");
424                 return -1;
425         }
426
427         return 0;
428 }
429
430 /**
431  * Hashes \a string and places the hash in \a hash
432  * at \a hash
433  *
434  * \param[in]           string          Null terminated string to hash
435  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
436  * \param[in,out]       hash            gss_buffer_desc to hold the result
437  *
438  * \return      -1      failure
439  * \return      0       success
440  */
441 static int sk_hash_string(const char *string, const EVP_MD *hash_alg,
442                           gss_buffer_desc *hash)
443 {
444         EVP_MD_CTX *ctx = EVP_MD_CTX_create();
445         size_t len = strlen(string);
446         unsigned int hashlen;
447
448         if (!hash->value || hash->length < EVP_MD_size(hash_alg))
449                 goto out_err;
450         if (!EVP_DigestInit_ex(ctx, hash_alg, NULL))
451                 goto out_err;
452         if (!EVP_DigestUpdate(ctx, string, len))
453                 goto out_err;
454         if (!EVP_DigestFinal_ex(ctx, hash->value, &hashlen))
455                 goto out_err;
456
457         EVP_MD_CTX_destroy(ctx);
458         hash->length = hashlen;
459         return 0;
460
461 out_err:
462         EVP_MD_CTX_destroy(ctx);
463         return -1;
464 }
465
466 /**
467  * Hashes \a string and verifies the resulting hash matches the value
468  * in \a current_hash
469  *
470  * \param[in]           string          Null terminated string to hash
471  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
472  * \param[in,out]       current_hash    gss_buffer_desc to compare to
473  *
474  * \return      gss error       failure
475  * \return      GSS_S_COMPLETE  success
476  */
477 uint32_t sk_verify_hash(const char *string, const EVP_MD *hash_alg,
478                         const gss_buffer_desc *current_hash)
479 {
480         gss_buffer_desc hash;
481         unsigned char hashbuf[EVP_MAX_MD_SIZE];
482
483         hash.value = hashbuf;
484         hash.length = sizeof(hashbuf);
485
486         if (sk_hash_string(string, hash_alg, &hash))
487                 return GSS_S_FAILURE;
488         if (current_hash->length != hash.length)
489                 return GSS_S_DEFECTIVE_TOKEN;
490         if (memcmp(current_hash->value, hash.value, hash.length))
491                 return GSS_S_BAD_SIG;
492
493         return GSS_S_COMPLETE;
494 }
495
496 static inline int sk_config_has_mgsnid(struct sk_keyfile_config *config,
497                                        const char *mgsnid)
498 {
499         lnet_nid_t nid;
500         int i;
501
502         nid = libcfs_str2nid(mgsnid);
503         if (nid == LNET_NID_ANY)
504                 return 0;
505
506         for (i = 0; i < MAX_MGSNIDS; i++)
507                 if  (config->skc_mgsnids[i] == nid)
508                         return 1;
509         return 0;
510 }
511
512 /**
513  * Create an sk_cred structure populated with initial configuration info and the
514  * key.  \a tgt and \a nodemap are used in determining the expected key
515  * description so the key can be found by searching the keyring.
516  * This is done because there is no easy way to pass keys from the mount command
517  * all the way to the request_key call.  In addition any keys can be dynamically
518  * added to the keyrings and still found.  The keyring that needs to be used
519  * must be the session keyring.
520  *
521  * \param[in]   tgt             Target file system
522  * \param[in]   nodemap         Cluster name for the key.  This correlates to
523  *                              the nodemap name and is used by the server side.
524  *                              For the client this will be NULL.
525  * \param[in]   flags           Flags for the credentials
526  *
527  * \return      sk_cred Allocated struct sk_cred on success
528  * \return      NULL    failure
529  */
530 struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
531                                const uint32_t flags)
532 {
533         struct sk_keyfile_config *config;
534         struct sk_kernel_ctx *kctx;
535         struct sk_cred *skc = NULL;
536         char description[SK_DESCRIPTION_SIZE + 1];
537         char fsname[MTI_NAME_MAXLEN + 1];
538         const char *mgsnid = NULL;
539         char *ptr;
540         long sk_key;
541         int keylen;
542         int len;
543         int rc;
544
545         printerr(2, "Creating credentials for target: %s with nodemap: %s\n",
546                  tgt, nodemap);
547
548         memset(description, 0, sizeof(description));
549         memset(fsname, 0, sizeof(fsname));
550
551         /* extract the file system name from target */
552         ptr = index(tgt, '-');
553         if (!ptr) {
554                 len = strlen(tgt);
555
556                 /* This must be an MGC target */
557                 if (strncmp(tgt, "MGC", 3) || len <= 3) {
558                         printerr(0, "Invalid target name\n");
559                         return NULL;
560                 }
561                 mgsnid = tgt + 3;
562         } else {
563                 len = ptr - tgt;
564         }
565
566         if (len > MTI_NAME_MAXLEN) {
567                 printerr(0, "Invalid target name\n");
568                 return NULL;
569         }
570         memcpy(fsname, tgt, len);
571
572         if (nodemap) {
573                 if (mgsnid)
574                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
575                                       "lustre:MGS:%s", nodemap);
576                 else
577                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
578                                       "lustre:%s:%s", fsname, nodemap);
579         } else {
580                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s",
581                               fsname);
582         }
583
584         if (rc >= SK_DESCRIPTION_SIZE) {
585                 printerr(0, "Invalid key description\n");
586                 return NULL;
587         }
588
589         /* It may be a good idea to move Lustre keys to the gss_keyring
590          * (lgssc) type so that they expire when Lustre modules are removed.
591          * Unfortunately it can't be done at mount time because the mount
592          * syscall could trigger the Lustre modules to load and until that
593          * point we don't have a lgssc key type.
594          *
595          * TODO: Query the community for a consensus here  */
596         printerr(2, "Searching for key with description: %s\n", description);
597         sk_key = keyctl_search(KEY_SPEC_USER_KEYRING, "user",
598                                description, 0);
599         if (sk_key == -1) {
600                 printerr(1, "No key found for %s\n", description);
601                 return NULL;
602         }
603
604         keylen = keyctl_read_alloc(sk_key, (void **)&config);
605         if (keylen == -1) {
606                 printerr(0, "keyctl_read() failed for key %ld: %s\n", sk_key,
607                          strerror(errno));
608                 return NULL;
609         } else if (keylen != sizeof(*config)) {
610                 printerr(0, "Unexpected key size: %d returned for key %ld, "
611                          "expected %zu bytes\n",
612                          keylen, sk_key, sizeof(*config));
613                 goto out_err;
614         }
615
616         sk_config_disk_to_cpu(config);
617
618         if (sk_validate_config(config)) {
619                 printerr(0, "Invalid key configuration for key: %ld\n", sk_key);
620                 goto out_err;
621         }
622
623         if (mgsnid && !sk_config_has_mgsnid(config, mgsnid)) {
624                 printerr(0, "Target name does not match key's MGS NIDs\n");
625                 goto out_err;
626         }
627
628         if (!mgsnid && strcmp(fsname, config->skc_fsname)) {
629                 printerr(0, "Target name does not match key's file system\n");
630                 goto out_err;
631         }
632
633         skc = malloc(sizeof(*skc));
634         if (!skc) {
635                 printerr(0, "Failed to allocate memory for sk_cred\n");
636                 goto out_err;
637         }
638
639         /* this initializes all gss_buffer_desc to empty as well */
640         memset(skc, 0, sizeof(*skc));
641
642         skc->sc_flags = flags;
643         skc->sc_tgt.length = strlen(tgt) + 1;
644         skc->sc_tgt.value = malloc(skc->sc_tgt.length);
645         if (!skc->sc_tgt.value) {
646                 printerr(0, "Failed to allocate memory for target\n");
647                 goto out_err;
648         }
649         memcpy(skc->sc_tgt.value, tgt, skc->sc_tgt.length);
650
651         skc->sc_nodemap_hash.length = EVP_MD_size(EVP_sha256());
652         skc->sc_nodemap_hash.value = malloc(skc->sc_nodemap_hash.length);
653         if (!skc->sc_nodemap_hash.value) {
654                 printerr(0, "Failed to allocate memory for nodemap hash\n");
655                 goto out_err;
656         }
657
658         if (sk_hash_string(config->skc_nodemap, EVP_sha256(),
659                            &skc->sc_nodemap_hash)) {
660                 printerr(0, "Failed to generate hash for nodemap name\n");
661                 goto out_err;
662         }
663
664         kctx = &skc->sc_kctx;
665         kctx->skc_version = config->skc_version;
666         kctx->skc_hmac_alg = config->skc_hmac_alg;
667         kctx->skc_crypt_alg = config->skc_crypt_alg;
668         kctx->skc_expire = config->skc_expire;
669
670         /* key payload format is in bits, convert to bytes */
671         kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
672         kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
673         if (!kctx->skc_shared_key.value) {
674                 printerr(0, "Failed to allocate memory for shared key\n");
675                 goto out_err;
676         }
677         memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
678                kctx->skc_shared_key.length);
679
680         skc->sc_p.length = config->skc_prime_bits / 8;
681         skc->sc_p.value = malloc(skc->sc_p.length);
682         if (!skc->sc_p.value) {
683                 printerr(0, "Failed to allocate p\n");
684                 goto out_err;
685         }
686         memcpy(skc->sc_p.value, config->skc_p, skc->sc_p.length);
687
688         free(config);
689
690         return skc;
691
692 out_err:
693         sk_free_cred(skc);
694
695         free(config);
696         return NULL;
697 }
698
699 /**
700  * Populates the DH parameters for the DHKE
701  *
702  * \param[in,out]       skc             Shared key credentials structure to
703  *                                      populate with DH parameters
704  *
705  * \retval      GSS_S_COMPLETE  success
706  * \retval      GSS_S_FAILURE   failure
707  */
708 uint32_t sk_gen_params(struct sk_cred *skc)
709 {
710         uint32_t random;
711         int rc;
712
713         /* Random value used by both the request and response as part of the
714          * key binding material.  This also should ensure we have unqiue
715          * tokens that are sent to the remote server which is important because
716          * the token is hashed for the sunrpc cache lookups and a failure there
717          * would cause connection attempts to fail indefinitely due to the large
718          * timeout value on the server side */
719         if (RAND_bytes((unsigned char *)&random, sizeof(random)) != 1) {
720                 printerr(0, "Failed to get data for random parameter: %s\n",
721                          ERR_error_string(ERR_get_error(), NULL));
722                 return GSS_S_FAILURE;
723         }
724
725         /* The random value will always be used in byte range operations
726          * so we keep it as big endian from this point on */
727         skc->sc_kctx.skc_host_random = random;
728
729         /* Populate DH parameters */
730         skc->sc_params = DH_new();
731         if (!skc->sc_params) {
732                 printerr(0, "Failed to allocate DH\n");
733                 return GSS_S_FAILURE;
734         }
735
736         skc->sc_params->p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
737         if (!skc->sc_params->p) {
738                 printerr(0, "Failed to convert binary to BIGNUM\n");
739                 return GSS_S_FAILURE;
740         }
741
742         /* We use a static generator for shared key */
743         skc->sc_params->g = BN_new();
744         if (!skc->sc_params->g) {
745                 printerr(0, "Failed to allocate new BIGNUM\n");
746                 return GSS_S_FAILURE;
747         }
748         if (BN_set_word(skc->sc_params->g, SK_GENERATOR) != 1) {
749                 printerr(0, "Failed to set g value for DH params\n");
750                 return GSS_S_FAILURE;
751         }
752
753         /* Verify that we have a safe prime and valid generator */
754         if (DH_check(skc->sc_params, &rc) != 1) {
755                 printerr(0, "DH_check() failed: %d\n", rc);
756                 return GSS_S_FAILURE;
757         } else if (rc) {
758                 printerr(0, "DH_check() returned error codes: 0x%x\n", rc);
759                 return GSS_S_FAILURE;
760         }
761
762         if (DH_generate_key(skc->sc_params) != 1) {
763                 printerr(0, "Failed to generate public DH key: %s\n",
764                          ERR_error_string(ERR_get_error(), NULL));
765                 return GSS_S_FAILURE;
766         }
767
768         skc->sc_pub_key.length = BN_num_bytes(skc->sc_params->pub_key);
769         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
770         if (!skc->sc_pub_key.value) {
771                 printerr(0, "Failed to allocate memory for public key\n");
772                 return GSS_S_FAILURE;
773         }
774
775         BN_bn2bin(skc->sc_params->pub_key, skc->sc_pub_key.value);
776
777         return GSS_S_COMPLETE;
778 }
779
780 /**
781  * Convert SK hash algorithm into openssl message digest
782  *
783  * \param[in,out]       alg             SK hash algorithm
784  *
785  * \retval              EVP_MD
786  */
787 static inline const EVP_MD *sk_hash_to_evp_md(enum cfs_crypto_hash_alg alg)
788 {
789         switch (alg) {
790         case CFS_HASH_ALG_SHA256:
791                 return EVP_sha256();
792         case CFS_HASH_ALG_SHA512:
793                 return EVP_sha512();
794         default:
795                 return EVP_md_null();
796         }
797 }
798
799 /**
800  * Signs (via HMAC) the parameters used only in the key initialization protocol.
801  *
802  * \param[in]           key             Key to use for HMAC
803  * \param[in]           bufs            Array of gss_buffer_desc to generate
804  *                                      HMAC for
805  * \param[in]           numbufs         Number of buffers in array
806  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
807  * \param[in,out]       hmac            HMAC of buffers is allocated and placed
808  *                                      in this gss_buffer_desc.  Caller must
809  *                                      free this.
810  *
811  * \retval      0       success
812  * \retval      -1      failure
813  */
814 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
815                  const EVP_MD *hash_alg, gss_buffer_desc *hmac)
816 {
817         HMAC_CTX hctx;
818         unsigned int hashlen = EVP_MD_size(hash_alg);
819         int i;
820         int rc = -1;
821
822         if (hash_alg == EVP_md_null()) {
823                 printerr(0, "Invalid hash algorithm\n");
824                 return -1;
825         }
826
827         HMAC_CTX_init(&hctx);
828
829         hmac->length = hashlen;
830         hmac->value = malloc(hashlen);
831         if (!hmac->value) {
832                 printerr(0, "Failed to allocate memory for HMAC\n");
833                 goto out;
834         }
835
836         if (HMAC_Init_ex(&hctx, key->value, key->length, hash_alg, NULL) != 1) {
837                 printerr(0, "Failed to init HMAC\n");
838                 goto out;
839         }
840
841         for (i = 0; i < numbufs; i++) {
842                 if (HMAC_Update(&hctx, bufs[i].value, bufs[i].length) != 1) {
843                         printerr(0, "Failed to update HMAC\n");
844                         goto out;
845                 }
846         }
847
848         /* The result gets populated in hmac */
849         if (HMAC_Final(&hctx, hmac->value, &hashlen) != 1) {
850                 printerr(0, "Failed to finalize HMAC\n");
851                 goto out;
852         }
853
854         if (hmac->length != hashlen) {
855                 printerr(0, "HMAC size does not match expected\n");
856                 goto out;
857         }
858
859         rc = 0;
860 out:
861         HMAC_CTX_cleanup(&hctx);
862         return rc;
863 }
864
865 /**
866  * Generates an HMAC for gss_buffer_desc array in \a bufs of \a numbufs
867  * and verifies against \a hmac.
868  *
869  * \param[in]   skc             Shared key credentials
870  * \param[in]   bufs            Array of gss_buffer_desc to generate HMAC for
871  * \param[in]   numbufs         Number of buffers in array
872  * \param[in]   hash_alg        OpenSSL EVP_MD to use for hash
873  * \param[in]   hmac            HMAC to verify against
874  *
875  * \retval      GSS_S_COMPLETE  success (match)
876  * \retval      gss error       failure
877  */
878 uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
879                         const int numbufs, const EVP_MD *hash_alg,
880                         gss_buffer_desc *hmac)
881 {
882         gss_buffer_desc bufs_hmac;
883         int rc;
884
885         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, numbufs, hash_alg,
886                          &bufs_hmac)) {
887                 printerr(0, "Failed to sign buffers to verify HMAC\n");
888                 if (bufs_hmac.value)
889                         free(bufs_hmac.value);
890                 return GSS_S_FAILURE;
891         }
892
893         if (hmac->length != bufs_hmac.length) {
894                 printerr(0, "Invalid HMAC size\n");
895                 free(bufs_hmac.value);
896                 return GSS_S_BAD_SIG;
897         }
898
899         rc = memcmp(hmac->value, bufs_hmac.value, bufs_hmac.length);
900         free(bufs_hmac.value);
901
902         if (rc)
903                 return GSS_S_BAD_SIG;
904
905         return GSS_S_COMPLETE;
906 }
907
908 /**
909  * Cleanup an sk_cred freeing any resources
910  *
911  * \param[in,out]       skc     Shared key credentials to free
912  */
913 void sk_free_cred(struct sk_cred *skc)
914 {
915         if (!skc)
916                 return;
917
918         if (skc->sc_p.value)
919                 free(skc->sc_p.value);
920         if (skc->sc_pub_key.value)
921                 free(skc->sc_pub_key.value);
922         if (skc->sc_tgt.value)
923                 free(skc->sc_tgt.value);
924         if (skc->sc_nodemap_hash.value)
925                 free(skc->sc_nodemap_hash.value);
926         if (skc->sc_hmac.value)
927                 free(skc->sc_hmac.value);
928
929         /* Overwrite keys and IV before freeing */
930         if (skc->sc_dh_shared_key.value) {
931                 memset(skc->sc_dh_shared_key.value, 0,
932                        skc->sc_dh_shared_key.length);
933                 free(skc->sc_dh_shared_key.value);
934         }
935         if (skc->sc_kctx.skc_hmac_key.value) {
936                 memset(skc->sc_kctx.skc_hmac_key.value, 0,
937                        skc->sc_kctx.skc_hmac_key.length);
938                 free(skc->sc_kctx.skc_hmac_key.value);
939         }
940         if (skc->sc_kctx.skc_encrypt_key.value) {
941                 memset(skc->sc_kctx.skc_encrypt_key.value, 0,
942                        skc->sc_kctx.skc_encrypt_key.length);
943                 free(skc->sc_kctx.skc_encrypt_key.value);
944         }
945         if (skc->sc_kctx.skc_shared_key.value) {
946                 memset(skc->sc_kctx.skc_shared_key.value, 0,
947                        skc->sc_kctx.skc_shared_key.length);
948                 free(skc->sc_kctx.skc_shared_key.value);
949         }
950         if (skc->sc_kctx.skc_session_key.value) {
951                 memset(skc->sc_kctx.skc_session_key.value, 0,
952                        skc->sc_kctx.skc_session_key.length);
953                 free(skc->sc_kctx.skc_session_key.value);
954         }
955
956         if (skc->sc_params)
957                 DH_free(skc->sc_params);
958
959         free(skc);
960         skc = NULL;
961 }
962
963 /* This function handles key derivation using the hash algorithm specified in
964  * \a hash_alg, buffers in \a key_binding_bufs, and original key in
965  * \a origin_key to produce a \a derived_key.  The first element of the
966  * key_binding_bufs array is reserved for the counter used in the KDF.  The
967  * derived key in \a derived_key could differ in size from \a origin_key and
968  * must be populated with the expected size and a valid buffer to hold the
969  * contents.
970  *
971  * If the derived key size is greater than the HMAC algorithm size it will be
972  * a done using several iterations of a counter and the key binding bufs.
973  *
974  * If the size is smaller it will take copy the first N bytes necessary to
975  * fill the derived key. */
976 int sk_kdf(gss_buffer_desc *derived_key , gss_buffer_desc *origin_key,
977            gss_buffer_desc *key_binding_bufs, int numbufs,
978            enum cfs_crypto_hash_alg hmac_alg)
979 {
980         size_t remain;
981         size_t bytes;
982         uint32_t counter;
983         char *keydata;
984         gss_buffer_desc tmp_hash;
985         int i;
986         int rc;
987
988         if (numbufs < 1)
989                 return -EINVAL;
990
991         /* Use a counter as the first buffer followed by the key binding
992          * buffers in the event we need more than one a single cycle to
993          * produced a symmetric key large enough in size */
994         key_binding_bufs[0].value = &counter;
995         key_binding_bufs[0].length = sizeof(counter);
996
997         remain = derived_key->length;
998         keydata = derived_key->value;
999         i = 0;
1000         while (remain > 0) {
1001                 counter = htobe32(i++);
1002                 rc = sk_sign_bufs(origin_key, key_binding_bufs, numbufs,
1003                                   sk_hash_to_evp_md(hmac_alg), &tmp_hash);
1004                 if (rc) {
1005                         if (tmp_hash.value)
1006                                 free(tmp_hash.value);
1007                         return rc;
1008                 }
1009
1010                 if (cfs_crypto_hash_digestsize(hmac_alg) != tmp_hash.length) {
1011                         free(tmp_hash.value);
1012                         return -EINVAL;
1013                 }
1014
1015                 bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
1016                 memcpy(keydata, tmp_hash.value, bytes);
1017                 free(tmp_hash.value);
1018                 remain -= bytes;
1019                 keydata += bytes;
1020         }
1021
1022         return 0;
1023 }
1024
1025 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
1026  * based on the recommendations in NIST Special Publication SP 800-56B Rev 1
1027  * (Sep 2014) Section 5.5.1
1028  *
1029  * \param[in,out]       skc             Shared key credentials structure with
1030  *
1031  * \return      -1              failure
1032  * \return      0               success
1033  */
1034 int sk_session_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
1035                    gss_buffer_desc *client_token, gss_buffer_desc *server_token)
1036 {
1037         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1038         gss_buffer_desc *session_key = &kctx->skc_session_key;
1039         gss_buffer_desc bufs[5];
1040         int rc = -1;
1041
1042         session_key->length = sk_crypt_types[kctx->skc_crypt_alg].cht_bytes;
1043         session_key->value = malloc(session_key->length);
1044         if (!session_key->value) {
1045                 printerr(0, "Failed to allocate memory for session key\n");
1046                 return rc;
1047         }
1048
1049         /* Key binding info ordering
1050          * 1. Reserved for counter
1051          * 1. DH shared key
1052          * 2. Client's NIDs
1053          * 3. Client's token
1054          * 4. Server's token */
1055         bufs[0].value = NULL;
1056         bufs[0].length = 0;
1057         bufs[1] = skc->sc_dh_shared_key;
1058         bufs[2].value = &client_nid;
1059         bufs[2].length = sizeof(client_nid);
1060         bufs[3] = *client_token;
1061         bufs[4] = *server_token;
1062
1063         return sk_kdf(&kctx->skc_session_key, &kctx->skc_shared_key, bufs,
1064                       5, kctx->skc_hmac_alg);
1065 }
1066
1067 /* Uses the session key to create an HMAC key and encryption key.  In
1068  * integrity mode the session key used to generate the HMAC key uses
1069  * session information which is available on the wire but by creating
1070  * a session based HMAC key we can prevent potential replay as both the
1071  * client and server have random numbers used as part of the key creation.
1072  *
1073  * The keys used for integrity and privacy are formulated as below using
1074  * the session key that is the output of the key derivation function.  The
1075  * HMAC algorithm is determined by the shared key algorithm selected in the
1076  * key file.
1077  *
1078  * For ski mode:
1079  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1080  *
1081  * For skpi mode:
1082  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1083  * Session Encryption Key = PBKDF2("Encrypt", KDF derived Session Key)
1084  *
1085  * \param[in,out]       skc             Shared key credentials structure with
1086  *
1087  * \return      -1              failure
1088  * \return      0               success
1089  */
1090 int sk_compute_keys(struct sk_cred *skc)
1091 {
1092         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1093         gss_buffer_desc *session_key = &kctx->skc_session_key;
1094         gss_buffer_desc *hmac_key = &kctx->skc_hmac_key;
1095         gss_buffer_desc *encrypt_key = &kctx->skc_encrypt_key;
1096         char *encrypt = "Encrypt";
1097         char *integrity = "Integrity";
1098         int rc;
1099
1100         hmac_key->length = cfs_crypto_hash_digestsize(kctx->skc_hmac_alg);
1101         hmac_key->value = malloc(hmac_key->length);
1102         if (!hmac_key->value)
1103                 return -ENOMEM;
1104
1105         rc = PKCS5_PBKDF2_HMAC(integrity, -1, session_key->value,
1106                                session_key->length, SK_PBKDF2_ITERATIONS,
1107                                sk_hash_to_evp_md(kctx->skc_hmac_alg),
1108                                hmac_key->length, hmac_key->value);
1109         if (rc == 0)
1110                 return -EINVAL;
1111
1112         /* Encryption key is only populated in privacy mode */
1113         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1114                 return 0;
1115
1116         encrypt_key->length = cfs_crypto_hash_digestsize(kctx->skc_hmac_alg);
1117         encrypt_key->value = malloc(encrypt_key->length);
1118         if (!encrypt_key->value)
1119                 return -ENOMEM;
1120
1121         rc = PKCS5_PBKDF2_HMAC(encrypt, -1, session_key->value,
1122                                session_key->length, SK_PBKDF2_ITERATIONS,
1123                                sk_hash_to_evp_md(kctx->skc_hmac_alg),
1124                                encrypt_key->length, encrypt_key->value);
1125         if (rc == 0)
1126                 return -EINVAL;
1127
1128         return 0;
1129 }
1130
1131 /**
1132  * Computes a session key based on the DH parameters from the host and its peer
1133  *
1134  * \param[in,out]       skc             Shared key credentials structure with
1135  *                                      the session key populated with the
1136  *                                      compute key
1137  * \param[in]           pub_key         Public key returned from peer in
1138  *                                      gss_buffer_desc
1139  * \return      gss error               failure
1140  * \return      GSS_S_COMPLETE          success
1141  */
1142 uint32_t sk_compute_dh_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
1143 {
1144         gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
1145         BIGNUM *remote_pub_key;
1146         int status;
1147         uint32_t rc = GSS_S_FAILURE;
1148
1149         remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
1150         if (!remote_pub_key) {
1151                 printerr(0, "Failed to convert binary to BIGNUM\n");
1152                 return rc;
1153         }
1154
1155         dh_shared->length = DH_size(skc->sc_params);
1156         dh_shared->value = malloc(dh_shared->length);
1157         if (!dh_shared->value) {
1158                 printerr(0, "Failed to allocate memory for computed shared "
1159                          "secret key\n");
1160                 goto out_err;
1161         }
1162
1163         /* This compute the shared key from the DHKE */
1164         status = DH_compute_key(dh_shared->value, remote_pub_key,
1165                                 skc->sc_params);
1166         if (status == -1) {
1167                 printerr(0, "DH_compute_key() failed: %s\n",
1168                          ERR_error_string(ERR_get_error(), NULL));
1169                 goto out_err;
1170         } else if (status < dh_shared->length) {
1171                 printerr(0, "DH_compute_key() returned a short key of %d "
1172                          "bytes, expected: %zu\n", status, dh_shared->length);
1173                 rc = GSS_S_DEFECTIVE_TOKEN;
1174                 goto out_err;
1175         }
1176
1177         rc = GSS_S_COMPLETE;
1178
1179 out_err:
1180         BN_free(remote_pub_key);
1181         return rc;
1182 }
1183
1184 /**
1185  * Creates a serialized buffer for the kernel in the order of struct
1186  * sk_kernel_ctx.
1187  *
1188  * \param[in,out]       skc             Shared key credentials structure
1189  * \param[in,out]       ctx_token       Serialized buffer for kernel.
1190  *                                      Caller must free this buffer.
1191  *
1192  * \return      0       success
1193  * \return      -1      failure
1194  */
1195 int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
1196 {
1197         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1198         char *p, *end;
1199         size_t bufsize;
1200
1201         bufsize = sizeof(*kctx) + kctx->skc_hmac_key.length +
1202                   kctx->skc_encrypt_key.length;
1203
1204         ctx_token->value = malloc(bufsize);
1205         if (!ctx_token->value)
1206                 return -1;
1207         ctx_token->length = bufsize;
1208
1209         p = ctx_token->value;
1210         end = p + ctx_token->length;
1211
1212         if (WRITE_BYTES(&p, end, kctx->skc_version))
1213                 return -1;
1214         if (WRITE_BYTES(&p, end, kctx->skc_hmac_alg))
1215                 return -1;
1216         if (WRITE_BYTES(&p, end, kctx->skc_crypt_alg))
1217                 return -1;
1218         if (WRITE_BYTES(&p, end, kctx->skc_expire))
1219                 return -1;
1220         if (WRITE_BYTES(&p, end, kctx->skc_host_random))
1221                 return -1;
1222         if (WRITE_BYTES(&p, end, kctx->skc_peer_random))
1223                 return -1;
1224         if (write_buffer(&p, end, &kctx->skc_hmac_key))
1225                 return -1;
1226         if (write_buffer(&p, end, &kctx->skc_encrypt_key))
1227                 return -1;
1228
1229         printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
1230
1231         return 0;
1232 }
1233
1234 /**
1235  * Decodes a netstring \a ns into array of gss_buffer_descs at \a bufs
1236  * up to \a numbufs.  Memory is allocated for each value and length
1237  * will be populated with the length
1238  *
1239  * \param[in,out]       bufs    Array of gss_buffer_descs
1240  * \param[in,out]       numbufs number of gss_buffer_desc in array
1241  * \param[in]           ns      netstring to decode
1242  *
1243  * \return      buffers populated       success
1244  * \return      -1                      failure
1245  */
1246 int sk_decode_netstring(gss_buffer_desc *bufs, int numbufs, gss_buffer_desc *ns)
1247 {
1248         char *ptr = ns->value;
1249         size_t remain = ns->length;
1250         unsigned int size;
1251         int digits;
1252         int sep;
1253         int rc;
1254         int i;
1255
1256         for (i = 0; i < numbufs; i++) {
1257                 /* read the size of first buffer */
1258                 rc = sscanf(ptr, "%9u", &size);
1259                 if (rc < 1)
1260                         goto out_err;
1261                 digits = (size) ? ceil(log10(size + 1)) : 1;
1262
1263                 /* sep of current string */
1264                 sep = size + digits + 2;
1265
1266                 /* check to make sure it's valid */
1267                 if (remain < sep || ptr[digits] != ':' ||
1268                     ptr[sep - 1] != ',')
1269                         goto out_err;
1270
1271                 bufs[i].length = size;
1272                 if (size == 0) {
1273                         bufs[i].value = NULL;
1274                 } else {
1275                         bufs[i].value = malloc(size);
1276                         if (!bufs[i].value)
1277                                 goto out_err;
1278                         memcpy(bufs[i].value, &ptr[digits + 1], size);
1279                 }
1280
1281                 remain -= sep;
1282                 ptr += sep;
1283         }
1284
1285         printerr(2, "Decoded netstring of %zu bytes\n", ns->length);
1286         return i;
1287
1288 out_err:
1289         while (i-- > 0) {
1290                 if (bufs[i].value)
1291                         free(bufs[i].value);
1292                 bufs[i].length = 0;
1293         }
1294         return -1;
1295 }
1296
1297 /**
1298  * Creates a netstring in a gss_buffer_desc that consists of all
1299  * the gss_buffer_desc found in \a bufs.  The netstring should be treated
1300  * as binary as it can contain null characters.
1301  *
1302  * \param[in]           bufs            Array of gss_buffer_desc to use as input
1303  * \param[in]           numbufs         Number of buffers in array
1304  * \param[in,out]       ns              Destination gss_buffer_desc to hold
1305  *                                      netstring
1306  *
1307  * \return      -1      failure
1308  * \return      0       success
1309  */
1310 int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
1311                         gss_buffer_desc *ns)
1312 {
1313         unsigned char *ptr;
1314         int size = 0;
1315         int rc;
1316         int i;
1317
1318         /* size of string in decimal, string size, colon, and comma */
1319         for (i = 0; i < numbufs; i++) {
1320
1321                 if (bufs[i].length == 0)
1322                         size += 3;
1323                 else
1324                         size += ceil(log10(bufs[i].length + 1)) +
1325                                 bufs[i].length + 2;
1326         }
1327
1328         ns->length = size;
1329         ns->value = malloc(ns->length);
1330         if (!ns->value) {
1331                 ns->length = 0;
1332                 return -1;
1333         }
1334
1335         ptr = ns->value;
1336         for (i = 0; i < numbufs; i++) {
1337                 /* size */
1338                 rc = snprintf((char *) ptr, size, "%zu:", bufs[i].length);
1339                 ptr += rc;
1340
1341                 /* contents */
1342                 memcpy(ptr, bufs[i].value, bufs[i].length);
1343                 ptr += bufs[i].length;
1344
1345                 /* delimeter */
1346                 *ptr++ = ',';
1347
1348                 size -= bufs[i].length + rc + 1;
1349
1350                 /* should not happen */
1351                 if (size < 0)
1352                         abort();
1353         }
1354
1355         printerr(2, "Encoded netstring of %zu bytes\n", ns->length);
1356         return 0;
1357 }