Whamcloud - gitweb
LU-8901 misc: update Intel copyright messages for 2016
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Copyright (c) 2016, Intel Corporation.
26  *
27  * Author: Jeremy Filizetti <jfilizet@iu.edu>
28  */
29
30 #include <fcntl.h>
31 #include <limits.h>
32 #include <math.h>
33 #include <string.h>
34 #include <stdbool.h>
35 #include <unistd.h>
36 #include <openssl/dh.h>
37 #include <openssl/engine.h>
38 #include <openssl/err.h>
39 #include <openssl/hmac.h>
40 #include <sys/types.h>
41 #include <sys/stat.h>
42 #include <lnet/nidstr.h>
43
44 #include "sk_utils.h"
45 #include "write_bytes.h"
46
47 #define SK_PBKDF2_ITERATIONS 10000
48
49 static struct sk_crypt_type sk_crypt_types[] = {
50         [SK_CRYPT_AES256_CTR] = {
51                 .sct_name = "ctr(aes)",
52                 .sct_bytes = 32,
53         },
54 };
55
56 static struct sk_hmac_type sk_hmac_types[] = {
57         [SK_HMAC_SHA256] = {
58                 .sht_name = "hmac(sha256)",
59                 .sht_bytes = 32,
60         },
61         [SK_HMAC_SHA512] = {
62                 .sht_name = "hmac(sha512)",
63                 .sht_bytes = 64,
64         },
65 };
66
67 #ifdef _NEW_BUILD_
68 # include "lgss_utils.h"
69 #else
70 # include "gss_util.h"
71 # include "gss_oids.h"
72 # include "err_util.h"
73 #endif
74
75 #ifdef _ERR_UTIL_H_
76 /**
77  * Initializes logging
78  * \param[in]   program         Program name to output
79  * \param[in]   verbose         Verbose flag
80  * \param[in]   fg              Whether or not to run in foreground
81  *
82  */
83 void sk_init_logging(char *program, int verbose, int fg)
84 {
85         initerr(program, verbose, fg);
86 }
87 #endif
88
89 /**
90  * Loads the key from \a filename and returns the struct sk_keyfile_config.
91  * It should be freed by the caller.
92  *
93  * \param[in]   filename                Disk or key payload data
94  *
95  * \return      sk_keyfile_config       sucess
96  * \return      NULL                    failure
97  */
98 struct sk_keyfile_config *sk_read_file(char *filename)
99 {
100         struct sk_keyfile_config *config;
101         char *ptr;
102         size_t rc;
103         size_t remain;
104         int fd;
105
106         config = malloc(sizeof(*config));
107         if (!config) {
108                 printerr(0, "Failed to allocate memory for config\n");
109                 return NULL;
110         }
111
112         /* allow standard input override */
113         if (strcmp(filename, "-") == 0)
114                 fd = STDIN_FILENO;
115         else
116                 fd = open(filename, O_RDONLY);
117
118         if (fd == -1) {
119                 printerr(0, "Error opening key file '%s': %s\n", filename,
120                          strerror(errno));
121                 goto out_free;
122         } else if (fd != STDIN_FILENO) {
123                 struct stat st;
124
125                 rc = fstat(fd, &st);
126                 if (rc == 0 && (st.st_mode & ~0600))
127                         fprintf(stderr, "warning: "
128                                 "secret key '%s' has insecure file mode %#o\n",
129                                 filename, st.st_mode);
130         }
131
132         ptr = (char *)config;
133         remain = sizeof(*config);
134         while (remain > 0) {
135                 rc = read(fd, ptr, remain);
136                 if (rc == -1) {
137                         if (errno == EINTR)
138                                 continue;
139                         printerr(0, "read() failed on %s: %s\n", filename,
140                                  strerror(errno));
141                         goto out_close;
142                 } else if (rc == 0) {
143                         printerr(0, "File %s does not have a complete key\n",
144                                  filename);
145                         goto out_close;
146                 }
147                 ptr += rc;
148                 remain -= rc;
149         }
150
151         if (fd != STDIN_FILENO)
152                 close(fd);
153         sk_config_disk_to_cpu(config);
154         return config;
155
156 out_close:
157         close(fd);
158 out_free:
159         free(config);
160         return NULL;
161 }
162
163 /**
164  * Checks if a key matching \a description is found in the keyring for
165  * logging purposes and then attempts to load \a payload of \a psize into a key
166  * with \a description.
167  *
168  * \param[in]   payload         Key payload
169  * \param[in]   psize           Payload size
170  * \param[in]   description     Description used for key in keyring
171  *
172  * \return      0       sucess
173  * \return      -1      failure
174  */
175 static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
176                                 const char *description)
177 {
178         struct sk_keyfile_config payload;
179         key_serial_t key;
180
181         memcpy(&payload, skc, sizeof(*skc));
182
183         /* In the keyring use the disk layout so keyctl pipe can be used */
184         sk_config_cpu_to_disk(&payload);
185
186         /* Check to see if a key is already loaded matching description */
187         key = keyctl_search(KEY_SPEC_USER_KEYRING, "user", description, 0);
188         if (key != -1)
189                 printerr(2, "Key %d found in session keyring, replacing\n",
190                          key);
191
192         key = add_key("user", description, &payload, sizeof(payload),
193                       KEY_SPEC_USER_KEYRING);
194         if (key != -1)
195                 printerr(2, "Added key %d with description %s\n", key,
196                          description);
197         else
198                 printerr(0, "Failed to add key with %s\n", description);
199
200         return key;
201 }
202
203 /**
204  * Reads the key from \a path, verifies it and loads into the session keyring
205  * using a description determined by the the \a type.  Existing keys with the
206  * same description are replaced.
207  *
208  * \param[in]   path    Path to key file
209  * \param[in]   type    Type of key to load which determines the description
210  *
211  * \return      0       sucess
212  * \return      -1      failure
213  */
214 int sk_load_keyfile(char *path)
215 {
216         struct sk_keyfile_config *config;
217         char description[SK_DESCRIPTION_SIZE + 1];
218         struct stat buf;
219         int i;
220         int rc;
221         int rc2 = -1;
222
223         rc = stat(path, &buf);
224         if (rc == -1) {
225                 printerr(0, "stat() failed for file %s: %s\n", path,
226                          strerror(errno));
227                 return rc2;
228         }
229
230         config = sk_read_file(path);
231         if (!config)
232                 return rc2;
233
234         /* Similar to ssh, require adequate care of key files */
235         if (buf.st_mode & (S_IRGRP | S_IWGRP | S_IWOTH | S_IXOTH)) {
236                 printerr(0, "Shared key files must be read/writeable only by "
237                          "owner\n");
238                 return -1;
239         }
240
241         if (sk_validate_config(config))
242                 goto out;
243
244         /* The server side can have multiple key files per file system so
245          * the nodemap name is appended to the key description to uniquely
246          * identify it */
247         if (config->skc_type & SK_TYPE_MGS) {
248                 /* Any key can be an MGS key as long as we are told to use it */
249                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
250                               config->skc_nodemap);
251                 if (rc >= SK_DESCRIPTION_SIZE)
252                         goto out;
253                 if (sk_load_key(config, description) == -1)
254                         goto out;
255         }
256         if (config->skc_type & SK_TYPE_SERVER) {
257                 /* Server keys need to have the file system name in the key */
258                 if (!config->skc_fsname) {
259                         printerr(0, "Key configuration has no file system "
260                                  "attribute.  Can't load as server type\n");
261                         goto out;
262                 }
263                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s:%s",
264                               config->skc_fsname, config->skc_nodemap);
265                 if (rc >= SK_DESCRIPTION_SIZE)
266                         goto out;
267                 if (sk_load_key(config, description) == -1)
268                         goto out;
269         }
270         if (config->skc_type & SK_TYPE_CLIENT) {
271                 /* Load client file system key */
272                 if (config->skc_fsname) {
273                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
274                                       "lustre:%s", config->skc_fsname);
275                         if (rc >= SK_DESCRIPTION_SIZE)
276                                 goto out;
277                         if (sk_load_key(config, description) == -1)
278                                 goto out;
279                 }
280
281                 /* Load client MGC keys */
282                 for (i = 0; i < MAX_MGSNIDS; i++) {
283                         if (config->skc_mgsnids[i] == LNET_NID_ANY)
284                                 continue;
285                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
286                                       "lustre:MGC%s",
287                                       libcfs_nid2str(config->skc_mgsnids[i]));
288                         if (rc >= SK_DESCRIPTION_SIZE)
289                                 goto out;
290                         if (sk_load_key(config, description) == -1)
291                                 goto out;
292                 }
293         }
294
295         rc2 = 0;
296
297 out:
298         free(config);
299         return rc2;
300 }
301
302 /**
303  * Byte swaps config from cpu format to disk
304  *
305  * \param[in,out]       config          sk_keyfile_config to swap
306  */
307 void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
308 {
309         int i;
310
311         if (!config)
312                 return;
313
314         config->skc_version = htobe32(config->skc_version);
315         config->skc_hmac_alg = htobe16(config->skc_hmac_alg);
316         config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
317         config->skc_expire = htobe32(config->skc_expire);
318         config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
319         config->skc_prime_bits = htobe32(config->skc_prime_bits);
320
321         for (i = 0; i < MAX_MGSNIDS; i++)
322                 config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
323
324         return;
325 }
326
327 /**
328  * Byte swaps config from disk format to cpu
329  *
330  * \param[in,out]       config          sk_keyfile_config to swap
331  */
332 void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
333 {
334         int i;
335
336         if (!config)
337                 return;
338
339         config->skc_version = be32toh(config->skc_version);
340         config->skc_hmac_alg = be16toh(config->skc_hmac_alg);
341         config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
342         config->skc_expire = be32toh(config->skc_expire);
343         config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
344         config->skc_prime_bits = be32toh(config->skc_prime_bits);
345
346         for (i = 0; i < MAX_MGSNIDS; i++)
347                 config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
348
349         return;
350 }
351
352 /**
353  * Verifies the on key payload format is valid
354  *
355  * \param[in]   config          sk_keyfile_config
356  *
357  * \return      -1      failure
358  * \return      0       success
359  */
360 int sk_validate_config(const struct sk_keyfile_config *config)
361 {
362         int i;
363
364         if (!config) {
365                 printerr(0, "Null configuration passed\n");
366                 return -1;
367         }
368         if (config->skc_version != SK_CONF_VERSION) {
369                 printerr(0, "Invalid version\n");
370                 return -1;
371         }
372         if (config->skc_hmac_alg >= SK_HMAC_MAX) {
373                 printerr(0, "Invalid HMAC algorithm\n");
374                 return -1;
375         }
376         if (config->skc_crypt_alg >= SK_CRYPT_MAX) {
377                 printerr(0, "Invalid crypt algorithm\n");
378                 return -1;
379         }
380         if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
381                 /* Try to limit key expiration to some reasonable minimum and
382                  * also prevent values over INT_MAX because there appears
383                  * to be a type conversion issue */
384                 printerr(0, "Invalid expiration time should be between %d "
385                          "and %d\n", 60, INT_MAX);
386                 return -1;
387         }
388         if (config->skc_prime_bits % 8 != 0 ||
389             config->skc_prime_bits > SK_MAX_P_BYTES * 8) {
390                 printerr(0, "Invalid session key length must be a multiple of 8"
391                          " and less then %d bits\n",
392                          SK_MAX_P_BYTES * 8);
393                 return -1;
394         }
395         if (config->skc_shared_keylen % 8 != 0 ||
396             config->skc_shared_keylen > SK_MAX_KEYLEN_BYTES * 8){
397                 printerr(0, "Invalid shared key max length must be a multiple "
398                          "of 8 and less then %d bits\n",
399                          SK_MAX_KEYLEN_BYTES * 8);
400                 return -1;
401         }
402
403         /* Check for terminating nulls on strings */
404         for (i = 0; i < sizeof(config->skc_fsname) &&
405              config->skc_fsname[i] != '\0';  i++)
406                 ; /* empty loop */
407         if (i == sizeof(config->skc_fsname)) {
408                 printerr(0, "File system name not null terminated\n");
409                 return -1;
410         }
411
412         for (i = 0; i < sizeof(config->skc_nodemap) &&
413              config->skc_nodemap[i] != '\0';  i++)
414                 ; /* empty loop */
415         if (i == sizeof(config->skc_nodemap)) {
416                 printerr(0, "Nodemap name not null terminated\n");
417                 return -1;
418         }
419
420         if (config->skc_type == SK_TYPE_INVALID) {
421                 printerr(0, "Invalid key type\n");
422                 return -1;
423         }
424
425         return 0;
426 }
427
428 /**
429  * Hashes \a string and places the hash in \a hash
430  * at \a hash
431  *
432  * \param[in]           string          Null terminated string to hash
433  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
434  * \param[in,out]       hash            gss_buffer_desc to hold the result
435  *
436  * \return      -1      failure
437  * \return      0       success
438  */
439 static int sk_hash_string(const char *string, const EVP_MD *hash_alg,
440                           gss_buffer_desc *hash)
441 {
442         EVP_MD_CTX *ctx = EVP_MD_CTX_create();
443         size_t len = strlen(string);
444         unsigned int hashlen;
445
446         if (!hash->value || hash->length < EVP_MD_size(hash_alg))
447                 goto out_err;
448         if (!EVP_DigestInit_ex(ctx, hash_alg, NULL))
449                 goto out_err;
450         if (!EVP_DigestUpdate(ctx, string, len))
451                 goto out_err;
452         if (!EVP_DigestFinal_ex(ctx, hash->value, &hashlen))
453                 goto out_err;
454
455         EVP_MD_CTX_destroy(ctx);
456         hash->length = hashlen;
457         return 0;
458
459 out_err:
460         EVP_MD_CTX_destroy(ctx);
461         return -1;
462 }
463
464 /**
465  * Hashes \a string and verifies the resulting hash matches the value
466  * in \a current_hash
467  *
468  * \param[in]           string          Null terminated string to hash
469  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
470  * \param[in,out]       current_hash    gss_buffer_desc to compare to
471  *
472  * \return      gss error       failure
473  * \return      GSS_S_COMPLETE  success
474  */
475 uint32_t sk_verify_hash(const char *string, const EVP_MD *hash_alg,
476                         const gss_buffer_desc *current_hash)
477 {
478         gss_buffer_desc hash;
479         unsigned char hashbuf[EVP_MAX_MD_SIZE];
480
481         hash.value = hashbuf;
482         hash.length = sizeof(hashbuf);
483
484         if (sk_hash_string(string, hash_alg, &hash))
485                 return GSS_S_FAILURE;
486         if (current_hash->length != hash.length)
487                 return GSS_S_DEFECTIVE_TOKEN;
488         if (memcmp(current_hash->value, hash.value, hash.length))
489                 return GSS_S_BAD_SIG;
490
491         return GSS_S_COMPLETE;
492 }
493
494 static inline int sk_config_has_mgsnid(struct sk_keyfile_config *config,
495                                        const char *mgsnid)
496 {
497         lnet_nid_t nid;
498         int i;
499
500         nid = libcfs_str2nid(mgsnid);
501         if (nid == LNET_NID_ANY)
502                 return 0;
503
504         for (i = 0; i < MAX_MGSNIDS; i++)
505                 if  (config->skc_mgsnids[i] == nid)
506                         return 1;
507         return 0;
508 }
509
510 /**
511  * Create an sk_cred structure populated with initial configuration info and the
512  * key.  \a tgt and \a nodemap are used in determining the expected key
513  * description so the key can be found by searching the keyring.
514  * This is done because there is no easy way to pass keys from the mount command
515  * all the way to the request_key call.  In addition any keys can be dynamically
516  * added to the keyrings and still found.  The keyring that needs to be used
517  * must be the session keyring.
518  *
519  * \param[in]   tgt             Target file system
520  * \param[in]   nodemap         Cluster name for the key.  This correlates to
521  *                              the nodemap name and is used by the server side.
522  *                              For the client this will be NULL.
523  * \param[in]   flags           Flags for the credentials
524  *
525  * \return      sk_cred Allocated struct sk_cred on success
526  * \return      NULL    failure
527  */
528 struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
529                                const uint32_t flags)
530 {
531         struct sk_keyfile_config *config;
532         struct sk_kernel_ctx *kctx;
533         struct sk_cred *skc = NULL;
534         char description[SK_DESCRIPTION_SIZE + 1];
535         char fsname[MTI_NAME_MAXLEN + 1];
536         const char *mgsnid = NULL;
537         char *ptr;
538         long sk_key;
539         int keylen;
540         int len;
541         int rc;
542
543         printerr(2, "Creating credentials for target: %s with nodemap: %s\n",
544                  tgt, nodemap);
545
546         memset(description, 0, sizeof(description));
547         memset(fsname, 0, sizeof(fsname));
548
549         /* extract the file system name from target */
550         ptr = index(tgt, '-');
551         if (!ptr) {
552                 len = strlen(tgt);
553
554                 /* This must be an MGC target */
555                 if (strncmp(tgt, "MGC", 3) || len <= 3) {
556                         printerr(0, "Invalid target name\n");
557                         return NULL;
558                 }
559                 mgsnid = tgt + 3;
560         } else {
561                 len = ptr - tgt;
562         }
563
564         if (len > MTI_NAME_MAXLEN) {
565                 printerr(0, "Invalid target name\n");
566                 return NULL;
567         }
568         memcpy(fsname, tgt, len);
569
570         if (nodemap) {
571                 if (mgsnid)
572                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
573                                       "lustre:MGS:%s", nodemap);
574                 else
575                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
576                                       "lustre:%s:%s", fsname, nodemap);
577         } else {
578                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s",
579                               fsname);
580         }
581
582         if (rc >= SK_DESCRIPTION_SIZE) {
583                 printerr(0, "Invalid key description\n");
584                 return NULL;
585         }
586
587         /* It may be a good idea to move Lustre keys to the gss_keyring
588          * (lgssc) type so that they expire when Lustre modules are removed.
589          * Unfortunately it can't be done at mount time because the mount
590          * syscall could trigger the Lustre modules to load and until that
591          * point we don't have a lgssc key type.
592          *
593          * TODO: Query the community for a consensus here  */
594         printerr(2, "Searching for key with description: %s\n", description);
595         sk_key = keyctl_search(KEY_SPEC_USER_KEYRING, "user",
596                                description, 0);
597         if (sk_key == -1) {
598                 printerr(1, "No key found for %s\n", description);
599                 return NULL;
600         }
601
602         keylen = keyctl_read_alloc(sk_key, (void **)&config);
603         if (keylen == -1) {
604                 printerr(0, "keyctl_read() failed for key %ld: %s\n", sk_key,
605                          strerror(errno));
606                 return NULL;
607         } else if (keylen != sizeof(*config)) {
608                 printerr(0, "Unexpected key size: %d returned for key %ld, "
609                          "expected %zu bytes\n",
610                          keylen, sk_key, sizeof(*config));
611                 goto out_err;
612         }
613
614         sk_config_disk_to_cpu(config);
615
616         if (sk_validate_config(config)) {
617                 printerr(0, "Invalid key configuration for key: %ld\n", sk_key);
618                 goto out_err;
619         }
620
621         if (mgsnid && !sk_config_has_mgsnid(config, mgsnid)) {
622                 printerr(0, "Target name does not match key's MGS NIDs\n");
623                 goto out_err;
624         }
625
626         if (!mgsnid && strcmp(fsname, config->skc_fsname)) {
627                 printerr(0, "Target name does not match key's file system\n");
628                 goto out_err;
629         }
630
631         skc = malloc(sizeof(*skc));
632         if (!skc) {
633                 printerr(0, "Failed to allocate memory for sk_cred\n");
634                 goto out_err;
635         }
636
637         /* this initializes all gss_buffer_desc to empty as well */
638         memset(skc, 0, sizeof(*skc));
639
640         skc->sc_flags = flags;
641         skc->sc_tgt.length = strlen(tgt) + 1;
642         skc->sc_tgt.value = malloc(skc->sc_tgt.length);
643         if (!skc->sc_tgt.value) {
644                 printerr(0, "Failed to allocate memory for target\n");
645                 goto out_err;
646         }
647         memcpy(skc->sc_tgt.value, tgt, skc->sc_tgt.length);
648
649         skc->sc_nodemap_hash.length = EVP_MD_size(EVP_sha256());
650         skc->sc_nodemap_hash.value = malloc(skc->sc_nodemap_hash.length);
651         if (!skc->sc_nodemap_hash.value) {
652                 printerr(0, "Failed to allocate memory for nodemap hash\n");
653                 goto out_err;
654         }
655
656         if (sk_hash_string(config->skc_nodemap, EVP_sha256(),
657                            &skc->sc_nodemap_hash)) {
658                 printerr(0, "Failed to generate hash for nodemap name\n");
659                 goto out_err;
660         }
661
662         kctx = &skc->sc_kctx;
663         kctx->skc_version = config->skc_version;
664         kctx->skc_hmac_alg = config->skc_hmac_alg;
665         kctx->skc_crypt_alg = config->skc_crypt_alg;
666         kctx->skc_expire = config->skc_expire;
667
668         /* key payload format is in bits, convert to bytes */
669         kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
670         kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
671         if (!kctx->skc_shared_key.value) {
672                 printerr(0, "Failed to allocate memory for shared key\n");
673                 goto out_err;
674         }
675         memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
676                kctx->skc_shared_key.length);
677
678         skc->sc_p.length = config->skc_prime_bits / 8;
679         skc->sc_p.value = malloc(skc->sc_p.length);
680         if (!skc->sc_p.value) {
681                 printerr(0, "Failed to allocate p\n");
682                 goto out_err;
683         }
684         memcpy(skc->sc_p.value, config->skc_p, skc->sc_p.length);
685
686         free(config);
687
688         return skc;
689
690 out_err:
691         sk_free_cred(skc);
692
693         free(config);
694         return NULL;
695 }
696
697 /**
698  * Populates the DH parameters for the DHKE
699  *
700  * \param[in,out]       skc             Shared key credentials structure to
701  *                                      populate with DH parameters
702  *
703  * \retval      GSS_S_COMPLETE  success
704  * \retval      GSS_S_FAILURE   failure
705  */
706 uint32_t sk_gen_params(struct sk_cred *skc)
707 {
708         uint32_t random;
709         int rc;
710
711         /* Random value used by both the request and response as part of the
712          * key binding material.  This also should ensure we have unqiue
713          * tokens that are sent to the remote server which is important because
714          * the token is hashed for the sunrpc cache lookups and a failure there
715          * would cause connection attempts to fail indefinitely due to the large
716          * timeout value on the server side */
717         if (RAND_bytes((unsigned char *)&random, sizeof(random)) != 1) {
718                 printerr(0, "Failed to get data for random parameter: %s\n",
719                          ERR_error_string(ERR_get_error(), NULL));
720                 return GSS_S_FAILURE;
721         }
722
723         /* The random value will always be used in byte range operations
724          * so we keep it as big endian from this point on */
725         skc->sc_kctx.skc_host_random = random;
726
727         /* Populate DH parameters */
728         skc->sc_params = DH_new();
729         if (!skc->sc_params) {
730                 printerr(0, "Failed to allocate DH\n");
731                 return GSS_S_FAILURE;
732         }
733
734         skc->sc_params->p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
735         if (!skc->sc_params->p) {
736                 printerr(0, "Failed to convert binary to BIGNUM\n");
737                 return GSS_S_FAILURE;
738         }
739
740         /* We use a static generator for shared key */
741         skc->sc_params->g = BN_new();
742         if (!skc->sc_params->g) {
743                 printerr(0, "Failed to allocate new BIGNUM\n");
744                 return GSS_S_FAILURE;
745         }
746         if (BN_set_word(skc->sc_params->g, SK_GENERATOR) != 1) {
747                 printerr(0, "Failed to set g value for DH params\n");
748                 return GSS_S_FAILURE;
749         }
750
751         /* Verify that we have a safe prime and valid generator */
752         if (DH_check(skc->sc_params, &rc) != 1) {
753                 printerr(0, "DH_check() failed: %d\n", rc);
754                 return GSS_S_FAILURE;
755         } else if (rc) {
756                 printerr(0, "DH_check() returned error codes: 0x%x\n", rc);
757                 return GSS_S_FAILURE;
758         }
759
760         if (DH_generate_key(skc->sc_params) != 1) {
761                 printerr(0, "Failed to generate public DH key: %s\n",
762                          ERR_error_string(ERR_get_error(), NULL));
763                 return GSS_S_FAILURE;
764         }
765
766         skc->sc_pub_key.length = BN_num_bytes(skc->sc_params->pub_key);
767         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
768         if (!skc->sc_pub_key.value) {
769                 printerr(0, "Failed to allocate memory for public key\n");
770                 return GSS_S_FAILURE;
771         }
772
773         BN_bn2bin(skc->sc_params->pub_key, skc->sc_pub_key.value);
774
775         return GSS_S_COMPLETE;
776 }
777
778 /**
779  * Convert SK hash algorithm into openssl message digest
780  *
781  * \param[in,out]       alg             SK hash algorithm
782  *
783  * \retval              EVP_MD
784  */
785 static inline const EVP_MD *sk_hash_to_evp_md(enum sk_hmac_alg alg)
786 {
787         switch (alg) {
788         case SK_HMAC_SHA256:
789                 return EVP_sha256();
790         case SK_HMAC_SHA512:
791                 return EVP_sha512();
792         default:
793                 return EVP_md_null();
794         }
795 }
796
797 /**
798  * Signs (via HMAC) the parameters used only in the key initialization protocol.
799  *
800  * \param[in]           key             Key to use for HMAC
801  * \param[in]           bufs            Array of gss_buffer_desc to generate
802  *                                      HMAC for
803  * \param[in]           numbufs         Number of buffers in array
804  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
805  * \param[in,out]       hmac            HMAC of buffers is allocated and placed
806  *                                      in this gss_buffer_desc.  Caller must
807  *                                      free this.
808  *
809  * \retval      0       success
810  * \retval      -1      failure
811  */
812 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
813                  const EVP_MD *hash_alg, gss_buffer_desc *hmac)
814 {
815         HMAC_CTX hctx;
816         unsigned int hashlen = EVP_MD_size(hash_alg);
817         int i;
818         int rc = -1;
819
820         if (hash_alg == EVP_md_null()) {
821                 printerr(0, "Invalid hash algorithm\n");
822                 return -1;
823         }
824
825         HMAC_CTX_init(&hctx);
826
827         hmac->length = hashlen;
828         hmac->value = malloc(hashlen);
829         if (!hmac->value) {
830                 printerr(0, "Failed to allocate memory for HMAC\n");
831                 goto out;
832         }
833
834         if (HMAC_Init_ex(&hctx, key->value, key->length, hash_alg, NULL) != 1) {
835                 printerr(0, "Failed to init HMAC\n");
836                 goto out;
837         }
838
839         for (i = 0; i < numbufs; i++) {
840                 if (HMAC_Update(&hctx, bufs[i].value, bufs[i].length) != 1) {
841                         printerr(0, "Failed to update HMAC\n");
842                         goto out;
843                 }
844         }
845
846         /* The result gets populated in hmac */
847         if (HMAC_Final(&hctx, hmac->value, &hashlen) != 1) {
848                 printerr(0, "Failed to finalize HMAC\n");
849                 goto out;
850         }
851
852         if (hmac->length != hashlen) {
853                 printerr(0, "HMAC size does not match expected\n");
854                 goto out;
855         }
856
857         rc = 0;
858 out:
859         HMAC_CTX_cleanup(&hctx);
860         return rc;
861 }
862
863 /**
864  * Generates an HMAC for gss_buffer_desc array in \a bufs of \a numbufs
865  * and verifies against \a hmac.
866  *
867  * \param[in]   skc             Shared key credentials
868  * \param[in]   bufs            Array of gss_buffer_desc to generate HMAC for
869  * \param[in]   numbufs         Number of buffers in array
870  * \param[in]   hash_alg        OpenSSL EVP_MD to use for hash
871  * \param[in]   hmac            HMAC to verify against
872  *
873  * \retval      GSS_S_COMPLETE  success (match)
874  * \retval      gss error       failure
875  */
876 uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
877                         const int numbufs, const EVP_MD *hash_alg,
878                         gss_buffer_desc *hmac)
879 {
880         gss_buffer_desc bufs_hmac;
881         int rc;
882
883         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, numbufs, hash_alg,
884                          &bufs_hmac)) {
885                 printerr(0, "Failed to sign buffers to verify HMAC\n");
886                 if (bufs_hmac.value)
887                         free(bufs_hmac.value);
888                 return GSS_S_FAILURE;
889         }
890
891         if (hmac->length != bufs_hmac.length) {
892                 printerr(0, "Invalid HMAC size\n");
893                 free(bufs_hmac.value);
894                 return GSS_S_BAD_SIG;
895         }
896
897         rc = memcmp(hmac->value, bufs_hmac.value, bufs_hmac.length);
898         free(bufs_hmac.value);
899
900         if (rc)
901                 return GSS_S_BAD_SIG;
902
903         return GSS_S_COMPLETE;
904 }
905
906 /**
907  * Cleanup an sk_cred freeing any resources
908  *
909  * \param[in,out]       skc     Shared key credentials to free
910  */
911 void sk_free_cred(struct sk_cred *skc)
912 {
913         if (!skc)
914                 return;
915
916         if (skc->sc_p.value)
917                 free(skc->sc_p.value);
918         if (skc->sc_pub_key.value)
919                 free(skc->sc_pub_key.value);
920         if (skc->sc_tgt.value)
921                 free(skc->sc_tgt.value);
922         if (skc->sc_nodemap_hash.value)
923                 free(skc->sc_nodemap_hash.value);
924         if (skc->sc_hmac.value)
925                 free(skc->sc_hmac.value);
926
927         /* Overwrite keys and IV before freeing */
928         if (skc->sc_dh_shared_key.value) {
929                 memset(skc->sc_dh_shared_key.value, 0,
930                        skc->sc_dh_shared_key.length);
931                 free(skc->sc_dh_shared_key.value);
932         }
933         if (skc->sc_kctx.skc_hmac_key.value) {
934                 memset(skc->sc_kctx.skc_hmac_key.value, 0,
935                        skc->sc_kctx.skc_hmac_key.length);
936                 free(skc->sc_kctx.skc_hmac_key.value);
937         }
938         if (skc->sc_kctx.skc_encrypt_key.value) {
939                 memset(skc->sc_kctx.skc_encrypt_key.value, 0,
940                        skc->sc_kctx.skc_encrypt_key.length);
941                 free(skc->sc_kctx.skc_encrypt_key.value);
942         }
943         if (skc->sc_kctx.skc_shared_key.value) {
944                 memset(skc->sc_kctx.skc_shared_key.value, 0,
945                        skc->sc_kctx.skc_shared_key.length);
946                 free(skc->sc_kctx.skc_shared_key.value);
947         }
948         if (skc->sc_kctx.skc_session_key.value) {
949                 memset(skc->sc_kctx.skc_session_key.value, 0,
950                        skc->sc_kctx.skc_session_key.length);
951                 free(skc->sc_kctx.skc_session_key.value);
952         }
953
954         if (skc->sc_params)
955                 DH_free(skc->sc_params);
956
957         free(skc);
958         skc = NULL;
959 }
960
961 /* This function handles key derivation using the hash algorithm specified in
962  * \a hash_alg, buffers in \a key_binding_bufs, and original key in
963  * \a origin_key to produce a \a derived_key.  The first element of the
964  * key_binding_bufs array is reserved for the counter used in the KDF.  The
965  * derived key in \a derived_key could differ in size from \a origin_key and
966  * must be populated with the expected size and a valid buffer to hold the
967  * contents.
968  *
969  * If the derived key size is greater than the HMAC algorithm size it will be
970  * a done using several iterations of a counter and the key binding bufs.
971  *
972  * If the size is smaller it will take copy the first N bytes necessary to
973  * fill the derived key. */
974 int sk_kdf(gss_buffer_desc *derived_key , gss_buffer_desc *origin_key,
975            gss_buffer_desc *key_binding_bufs, int numbufs, int hmac_alg)
976 {
977         size_t remain;
978         size_t bytes;
979         uint32_t counter;
980         char *keydata;
981         gss_buffer_desc tmp_hash;
982         int i;
983         int rc;
984
985         if (numbufs < 1)
986                 return -EINVAL;
987
988         /* Use a counter as the first buffer followed by the key binding
989          * buffers in the event we need more than one a single cycle to
990          * produced a symmetric key large enough in size */
991         key_binding_bufs[0].value = &counter;
992         key_binding_bufs[0].length = sizeof(counter);
993
994         remain = derived_key->length;
995         keydata = derived_key->value;
996         i = 0;
997         while (remain > 0) {
998                 counter = htobe32(i++);
999                 rc = sk_sign_bufs(origin_key, key_binding_bufs, numbufs,
1000                                   sk_hash_to_evp_md(hmac_alg), &tmp_hash);
1001                 if (rc) {
1002                         if (tmp_hash.value)
1003                                 free(tmp_hash.value);
1004                         return rc;
1005                 }
1006
1007                 LASSERT(sk_hmac_types[hmac_alg].sht_bytes ==
1008                         tmp_hash.length);
1009
1010                 bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
1011                 memcpy(keydata, tmp_hash.value, bytes);
1012                 free(tmp_hash.value);
1013                 remain -= bytes;
1014                 keydata += bytes;
1015         }
1016
1017         return 0;
1018 }
1019
1020 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
1021  * based on the recommendations in NIST Special Publication SP 800-56B Rev 1
1022  * (Sep 2014) Section 5.5.1
1023  *
1024  * \param[in,out]       skc             Shared key credentials structure with
1025  *
1026  * \return      -1              failure
1027  * \return      0               success
1028  */
1029 int sk_session_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
1030                    gss_buffer_desc *client_token, gss_buffer_desc *server_token)
1031 {
1032         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1033         gss_buffer_desc *session_key = &kctx->skc_session_key;
1034         gss_buffer_desc bufs[5];
1035         int rc = -1;
1036
1037         session_key->length = sk_crypt_types[kctx->skc_crypt_alg].sct_bytes;
1038         session_key->value = malloc(session_key->length);
1039         if (!session_key->value) {
1040                 printerr(0, "Failed to allocate memory for session key\n");
1041                 return rc;
1042         }
1043
1044         /* Key binding info ordering
1045          * 1. Reserved for counter
1046          * 1. DH shared key
1047          * 2. Client's NIDs
1048          * 3. Client's token
1049          * 4. Server's token */
1050         bufs[0].value = NULL;
1051         bufs[0].length = 0;
1052         bufs[1] = skc->sc_dh_shared_key;
1053         bufs[2].value = &client_nid;
1054         bufs[2].length = sizeof(client_nid);
1055         bufs[3] = *client_token;
1056         bufs[4] = *server_token;
1057
1058         return sk_kdf(&kctx->skc_session_key, &kctx->skc_shared_key, bufs,
1059                       5, kctx->skc_hmac_alg);
1060 }
1061
1062 /* Uses the session key to create an HMAC key and encryption key.  In
1063  * integrity mode the session key used to generate the HMAC key uses
1064  * session information which is available on the wire but by creating
1065  * a session based HMAC key we can prevent potential replay as both the
1066  * client and server have random numbers used as part of the key creation.
1067  *
1068  * The keys used for integrity and privacy are formulated as below using
1069  * the session key that is the output of the key derivation function.  The
1070  * HMAC algorithm is determined by the shared key algorithm selected in the
1071  * key file.
1072  *
1073  * For ski mode:
1074  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1075  *
1076  * For skpi mode:
1077  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1078  * Session Encryption Key = PBKDF2("Encrypt", KDF derived Session Key)
1079  *
1080  * \param[in,out]       skc             Shared key credentials structure with
1081  *
1082  * \return      -1              failure
1083  * \return      0               success
1084  */
1085 int sk_compute_keys(struct sk_cred *skc)
1086 {
1087         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1088         gss_buffer_desc *session_key = &kctx->skc_session_key;
1089         gss_buffer_desc *hmac_key = &kctx->skc_hmac_key;
1090         gss_buffer_desc *encrypt_key = &kctx->skc_encrypt_key;
1091         char *encrypt = "Encrypt";
1092         char *integrity = "Integrity";
1093         int rc;
1094
1095         hmac_key->length = sk_hmac_types[kctx->skc_hmac_alg].sht_bytes;
1096         hmac_key->value = malloc(hmac_key->length);
1097         if (!hmac_key->value)
1098                 return -ENOMEM;
1099
1100         rc = PKCS5_PBKDF2_HMAC(integrity, -1, session_key->value,
1101                                session_key->length, SK_PBKDF2_ITERATIONS,
1102                                sk_hash_to_evp_md(kctx->skc_hmac_alg),
1103                                hmac_key->length, hmac_key->value);
1104         if (rc == 0)
1105                 return -EINVAL;
1106
1107         /* Encryption key is only populated in privacy mode */
1108         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1109                 return 0;
1110
1111         encrypt_key->length = sk_crypt_types[kctx->skc_crypt_alg].sct_bytes;
1112         encrypt_key->value = malloc(encrypt_key->length);
1113         if (!encrypt_key->value)
1114                 return -ENOMEM;
1115
1116         rc = PKCS5_PBKDF2_HMAC(encrypt, -1, session_key->value,
1117                                session_key->length, SK_PBKDF2_ITERATIONS,
1118                                sk_hash_to_evp_md(kctx->skc_hmac_alg),
1119                                encrypt_key->length, encrypt_key->value);
1120         if (rc == 0)
1121                 return -EINVAL;
1122
1123         return 0;
1124 }
1125
1126 /**
1127  * Computes a session key based on the DH parameters from the host and its peer
1128  *
1129  * \param[in,out]       skc             Shared key credentials structure with
1130  *                                      the session key populated with the
1131  *                                      compute key
1132  * \param[in]           pub_key         Public key returned from peer in
1133  *                                      gss_buffer_desc
1134  * \return      gss error               failure
1135  * \return      GSS_S_COMPLETE          success
1136  */
1137 uint32_t sk_compute_dh_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
1138 {
1139         gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
1140         BIGNUM *remote_pub_key;
1141         int status;
1142         uint32_t rc = GSS_S_FAILURE;
1143
1144         remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
1145         if (!remote_pub_key) {
1146                 printerr(0, "Failed to convert binary to BIGNUM\n");
1147                 return rc;
1148         }
1149
1150         dh_shared->length = DH_size(skc->sc_params);
1151         dh_shared->value = malloc(dh_shared->length);
1152         if (!dh_shared->value) {
1153                 printerr(0, "Failed to allocate memory for computed shared "
1154                          "secret key\n");
1155                 goto out_err;
1156         }
1157
1158         /* This compute the shared key from the DHKE */
1159         status = DH_compute_key(dh_shared->value, remote_pub_key,
1160                                 skc->sc_params);
1161         if (status == -1) {
1162                 printerr(0, "DH_compute_key() failed: %s\n",
1163                          ERR_error_string(ERR_get_error(), NULL));
1164                 goto out_err;
1165         } else if (status < dh_shared->length) {
1166                 printerr(0, "DH_compute_key() returned a short key of %d "
1167                          "bytes, expected: %zu\n", status, dh_shared->length);
1168                 rc = GSS_S_DEFECTIVE_TOKEN;
1169                 goto out_err;
1170         }
1171
1172         rc = GSS_S_COMPLETE;
1173
1174 out_err:
1175         BN_free(remote_pub_key);
1176         return rc;
1177 }
1178
1179 /**
1180  * Creates a serialized buffer for the kernel in the order of struct
1181  * sk_kernel_ctx.
1182  *
1183  * \param[in,out]       skc             Shared key credentials structure
1184  * \param[in,out]       ctx_token       Serialized buffer for kernel.
1185  *                                      Caller must free this buffer.
1186  *
1187  * \return      0       success
1188  * \return      -1      failure
1189  */
1190 int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
1191 {
1192         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1193         char *p, *end;
1194         size_t bufsize;
1195
1196         bufsize = sizeof(*kctx) + kctx->skc_hmac_key.length +
1197                   kctx->skc_encrypt_key.length;
1198
1199         ctx_token->value = malloc(bufsize);
1200         if (!ctx_token->value)
1201                 return -1;
1202         ctx_token->length = bufsize;
1203
1204         p = ctx_token->value;
1205         end = p + ctx_token->length;
1206
1207         if (WRITE_BYTES(&p, end, kctx->skc_version))
1208                 return -1;
1209         if (WRITE_BYTES(&p, end, kctx->skc_hmac_alg))
1210                 return -1;
1211         if (WRITE_BYTES(&p, end, kctx->skc_crypt_alg))
1212                 return -1;
1213         if (WRITE_BYTES(&p, end, kctx->skc_expire))
1214                 return -1;
1215         if (WRITE_BYTES(&p, end, kctx->skc_host_random))
1216                 return -1;
1217         if (WRITE_BYTES(&p, end, kctx->skc_peer_random))
1218                 return -1;
1219         if (write_buffer(&p, end, &kctx->skc_hmac_key))
1220                 return -1;
1221         if (write_buffer(&p, end, &kctx->skc_encrypt_key))
1222                 return -1;
1223
1224         printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
1225
1226         return 0;
1227 }
1228
1229 /**
1230  * Decodes a netstring \a ns into array of gss_buffer_descs at \a bufs
1231  * up to \a numbufs.  Memory is allocated for each value and length
1232  * will be populated with the length
1233  *
1234  * \param[in,out]       bufs    Array of gss_buffer_descs
1235  * \param[in,out]       numbufs number of gss_buffer_desc in array
1236  * \param[in]           ns      netstring to decode
1237  *
1238  * \return      buffers populated       success
1239  * \return      -1                      failure
1240  */
1241 int sk_decode_netstring(gss_buffer_desc *bufs, int numbufs, gss_buffer_desc *ns)
1242 {
1243         char *ptr = ns->value;
1244         size_t remain = ns->length;
1245         unsigned int size;
1246         int digits;
1247         int sep;
1248         int rc;
1249         int i;
1250
1251         for (i = 0; i < numbufs; i++) {
1252                 /* read the size of first buffer */
1253                 rc = sscanf(ptr, "%9u", &size);
1254                 if (rc < 1)
1255                         goto out_err;
1256                 digits = (size) ? ceil(log10(size + 1)) : 1;
1257
1258                 /* sep of current string */
1259                 sep = size + digits + 2;
1260
1261                 /* check to make sure it's valid */
1262                 if (remain < sep || ptr[digits] != ':' ||
1263                     ptr[sep - 1] != ',')
1264                         goto out_err;
1265
1266                 bufs[i].length = size;
1267                 if (size == 0) {
1268                         bufs[i].value = NULL;
1269                 } else {
1270                         bufs[i].value = malloc(size);
1271                         if (!bufs[i].value)
1272                                 goto out_err;
1273                         memcpy(bufs[i].value, &ptr[digits + 1], size);
1274                 }
1275
1276                 remain -= sep;
1277                 ptr += sep;
1278         }
1279
1280         printerr(2, "Decoded netstring of %zu bytes\n", ns->length);
1281         return i;
1282
1283 out_err:
1284         while (i-- > 0) {
1285                 if (bufs[i].value)
1286                         free(bufs[i].value);
1287                 bufs[i].length = 0;
1288         }
1289         return -1;
1290 }
1291
1292 /**
1293  * Creates a netstring in a gss_buffer_desc that consists of all
1294  * the gss_buffer_desc found in \a bufs.  The netstring should be treated
1295  * as binary as it can contain null characters.
1296  *
1297  * \param[in]           bufs            Array of gss_buffer_desc to use as input
1298  * \param[in]           numbufs         Number of buffers in array
1299  * \param[in,out]       ns              Destination gss_buffer_desc to hold
1300  *                                      netstring
1301  *
1302  * \return      -1      failure
1303  * \return      0       success
1304  */
1305 int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
1306                         gss_buffer_desc *ns)
1307 {
1308         unsigned char *ptr;
1309         int size = 0;
1310         int rc;
1311         int i;
1312
1313         /* size of string in decimal, string size, colon, and comma */
1314         for (i = 0; i < numbufs; i++) {
1315
1316                 if (bufs[i].length == 0)
1317                         size += 3;
1318                 else
1319                         size += ceil(log10(bufs[i].length + 1)) +
1320                                 bufs[i].length + 2;
1321         }
1322
1323         ns->length = size;
1324         ns->value = malloc(ns->length);
1325         if (!ns->value) {
1326                 ns->length = 0;
1327                 return -1;
1328         }
1329
1330         ptr = ns->value;
1331         for (i = 0; i < numbufs; i++) {
1332                 /* size */
1333                 rc = snprintf((char *) ptr, size, "%zu:", bufs[i].length);
1334                 ptr += rc;
1335
1336                 /* contents */
1337                 memcpy(ptr, bufs[i].value, bufs[i].length);
1338                 ptr += bufs[i].length;
1339
1340                 /* delimeter */
1341                 *ptr++ = ',';
1342
1343                 size -= bufs[i].length + rc + 1;
1344
1345                 /* should not happen */
1346                 if (size < 0)
1347                         abort();
1348         }
1349
1350         printerr(2, "Encoded netstring of %zu bytes\n", ns->length);
1351         return 0;
1352 }