Whamcloud - gitweb
LU-6245 libcfs: remove byteorder.h
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Author: Jeremy Filizetti <jfilizet@iu.edu>
26  */
27
28 #include <fcntl.h>
29 #include <limits.h>
30 #include <math.h>
31 #include <string.h>
32 #include <stdbool.h>
33 #include <unistd.h>
34 #include <openssl/dh.h>
35 #include <openssl/engine.h>
36 #include <openssl/err.h>
37 #include <openssl/hmac.h>
38 #include <sys/types.h>
39 #include <sys/stat.h>
40 #include <lnet/nidstr.h>
41
42 #include "sk_utils.h"
43 #include "write_bytes.h"
44
45 static struct sk_crypt_type sk_crypt_types[] = {
46         [SK_CRYPT_AES256_CTR] = {
47                 .sct_name = "ctr(aes)",
48                 .sct_bytes = 32,
49         },
50 };
51
52 static struct sk_hmac_type sk_hmac_types[] = {
53         [SK_HMAC_SHA256] = {
54                 .sht_name = "hmac(sha256)",
55                 .sht_bytes = 32,
56         },
57         [SK_HMAC_SHA512] = {
58                 .sht_name = "hmac(sha512)",
59                 .sht_bytes = 64,
60         },
61 };
62
63 #ifdef _NEW_BUILD_
64 # include "lgss_utils.h"
65 #else
66 # include "gss_util.h"
67 # include "gss_oids.h"
68 # include "err_util.h"
69 #endif
70
71 #ifdef _ERR_UTIL_H_
72 /**
73  * Initializes logging
74  * \param[in]   program         Program name to output
75  * \param[in]   verbose         Verbose flag
76  * \param[in]   fg              Whether or not to run in foreground
77  *
78  */
79 void sk_init_logging(char *program, int verbose, int fg)
80 {
81         initerr(program, verbose, fg);
82 }
83 #endif
84
85 /**
86  * Loads the key from \a filename and returns the struct sk_keyfile_config.
87  * It should be freed by the caller.
88  *
89  * \param[in]   filename                Disk or key payload data
90  *
91  * \return      sk_keyfile_config       sucess
92  * \return      NULL                    failure
93  */
94 struct sk_keyfile_config *sk_read_file(char *filename)
95 {
96         struct sk_keyfile_config *config;
97         char *ptr;
98         size_t rc;
99         size_t remain;
100         int fd;
101
102         config = malloc(sizeof(*config));
103         if (!config) {
104                 printerr(0, "Failed to allocate memory for config\n");
105                 return NULL;
106         }
107
108         /* allow standard input override */
109         if (strcmp(filename, "-") == 0)
110                 fd = dup(STDIN_FILENO);
111         else
112                 fd = open(filename, O_RDONLY);
113
114         if (fd == -1) {
115                 printerr(0, "Error opening file %s: %s\n", filename,
116                          strerror(errno));
117                 goto out_free;
118         }
119
120         ptr = (char *)config;
121         remain = sizeof(*config);
122         while (remain > 0) {
123                 rc = read(fd, ptr, remain);
124                 if (rc == -1) {
125                         if (errno == EINTR)
126                                 continue;
127                         printerr(0, "read() failed on %s: %s\n", filename,
128                                  strerror(errno));
129                         goto out_close;
130                 } else if (rc == 0) {
131                         printerr(0, "File %s does not have a complete key\n",
132                                  filename);
133                         goto out_close;
134                 }
135                 ptr += rc;
136                 remain -= rc;
137         }
138
139         close(fd);
140         sk_config_disk_to_cpu(config);
141         return config;
142
143 out_close:
144         close(fd);
145 out_free:
146         free(config);
147         return NULL;
148 }
149
150 /**
151  * Checks if a key matching \a description is found in the keyring for
152  * logging purposes and then attempts to load \a payload of \a psize into a key
153  * with \a description.
154  *
155  * \param[in]   payload         Key payload
156  * \param[in]   psize           Payload size
157  * \param[in]   description     Description used for key in keyring
158  *
159  * \return      0       sucess
160  * \return      -1      failure
161  */
162 static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
163                                 const char *description)
164 {
165         struct sk_keyfile_config payload;
166         key_serial_t key;
167
168         memcpy(&payload, skc, sizeof(*skc));
169
170         /* In the keyring use the disk layout so keyctl pipe can be used */
171         sk_config_cpu_to_disk(&payload);
172
173         /* Check to see if a key is already loaded matching description */
174         key = keyctl_search(KEY_SPEC_USER_KEYRING, "user", description, 0);
175         if (key != -1)
176                 printerr(2, "Key %d found in session keyring, replacing\n",
177                          key);
178
179         key = add_key("user", description, &payload, sizeof(payload),
180                       KEY_SPEC_USER_KEYRING);
181         if (key != -1)
182                 printerr(2, "Added key %d with description %s\n", key,
183                          description);
184         else
185                 printerr(0, "Failed to add key with %s\n", description);
186
187         return key;
188 }
189
190 /**
191  * Reads the key from \a path, verifies it and loads into the session keyring
192  * using a description determined by the the \a type.  Existing keys with the
193  * same description are replaced.
194  *
195  * \param[in]   path    Path to key file
196  * \param[in]   type    Type of key to load which determines the description
197  *
198  * \return      0       sucess
199  * \return      -1      failure
200  */
201 int sk_load_keyfile(char *path, int type)
202 {
203         struct sk_keyfile_config *config;
204         char description[SK_DESCRIPTION_SIZE + 1];
205         struct stat buf;
206         int i;
207         int rc;
208         int rc2 = -1;
209
210         rc = stat(path, &buf);
211         if (rc == -1) {
212                 printerr(0, "stat() failed for file %s: %s\n", path,
213                          strerror(errno));
214                 return rc2;
215         }
216
217         config = sk_read_file(path);
218         if (!config)
219                 return rc2;
220
221         /* Similar to ssh, require adequate care of key files */
222         if (buf.st_mode & (S_IRGRP | S_IWGRP | S_IWOTH | S_IXOTH)) {
223                 printerr(0, "Shared key files must be read/writeable only by "
224                          "owner\n");
225                 return -1;
226         }
227
228         if (sk_validate_config(config))
229                 goto out;
230
231         /* The server side can have multiple key files per file system so
232          * the nodemap name is appended to the key description to uniquely
233          * identify it */
234         if (type & SK_TYPE_MGS) {
235                 /* Any key can be an MGS key as long as we are told to use it */
236                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
237                               config->skc_nodemap);
238                 if (rc >= SK_DESCRIPTION_SIZE)
239                         goto out;
240                 if (sk_load_key(config, description) == -1)
241                         goto out;
242         }
243         if (type & SK_TYPE_SERVER) {
244                 /* Server keys need to have the file system name in the key */
245                 if (!config->skc_fsname) {
246                         printerr(0, "Key configuration has no file system "
247                                  "attribute.  Can't load as server type\n");
248                         goto out;
249                 }
250                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s:%s",
251                               config->skc_fsname, config->skc_nodemap);
252                 if (rc >= SK_DESCRIPTION_SIZE)
253                         goto out;
254                 if (sk_load_key(config, description) == -1)
255                         goto out;
256         }
257         if (type & SK_TYPE_CLIENT) {
258                 /* Load client file system key */
259                 if (config->skc_fsname) {
260                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
261                                       "lustre:%s", config->skc_fsname);
262                         if (rc >= SK_DESCRIPTION_SIZE)
263                                 goto out;
264                         if (sk_load_key(config, description) == -1)
265                                 goto out;
266                 }
267
268                 /* Load client MGC keys */
269                 for (i = 0; i < MAX_MGSNIDS; i++) {
270                         if (config->skc_mgsnids[i] == LNET_NID_ANY)
271                                 continue;
272                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
273                                       "lustre:MGC%s",
274                                       libcfs_nid2str(config->skc_mgsnids[i]));
275                         if (rc >= SK_DESCRIPTION_SIZE)
276                                 goto out;
277                         if (sk_load_key(config, description) == -1)
278                                 goto out;
279                 }
280         }
281
282         rc2 = 0;
283
284 out:
285         free(config);
286         return rc2;
287 }
288
289 /**
290  * Byte swaps config from cpu format to disk
291  *
292  * \param[in,out]       config          sk_keyfile_config to swap
293  */
294 void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
295 {
296         int i;
297
298         if (!config)
299                 return;
300
301         config->skc_version = htobe32(config->skc_version);
302         config->skc_hmac_alg = htobe16(config->skc_hmac_alg);
303         config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
304         config->skc_expire = htobe32(config->skc_expire);
305         config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
306         config->skc_session_keylen = htobe32(config->skc_session_keylen);
307
308         for (i = 0; i < MAX_MGSNIDS; i++)
309                 config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
310
311         return;
312 }
313
314 /**
315  * Byte swaps config from disk format to cpu
316  *
317  * \param[in,out]       config          sk_keyfile_config to swap
318  */
319 void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
320 {
321         int i;
322
323         if (!config)
324                 return;
325
326         config->skc_version = be32toh(config->skc_version);
327         config->skc_hmac_alg = be16toh(config->skc_hmac_alg);
328         config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
329         config->skc_expire = be32toh(config->skc_expire);
330         config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
331         config->skc_session_keylen = be32toh(config->skc_session_keylen);
332
333         for (i = 0; i < MAX_MGSNIDS; i++)
334                 config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
335
336         return;
337 }
338
339 /**
340  * Verifies the on key payload format is valid
341  *
342  * \param[in]   config          sk_keyfile_config
343  *
344  * \return      -1      failure
345  * \return      0       success
346  */
347 int sk_validate_config(const struct sk_keyfile_config *config)
348 {
349         int i;
350
351         if (!config) {
352                 printerr(0, "Null configuration passed\n");
353                 return -1;
354         }
355         if (config->skc_version != SK_CONF_VERSION) {
356                 printerr(0, "Invalid version\n");
357                 return -1;
358         }
359         if (config->skc_hmac_alg >= SK_HMAC_MAX) {
360                 printerr(0, "Invalid HMAC algorithm\n");
361                 return -1;
362         }
363         if (config->skc_crypt_alg >= SK_CRYPT_MAX) {
364                 printerr(0, "Invalid crypt algorithm\n");
365                 return -1;
366         }
367         if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
368                 /* Try to limit key expiration to some reasonable minimum and
369                  * also prevent values over INT_MAX because there appears
370                  * to be a type conversion issue */
371                 printerr(0, "Invalid expiration time should be between %d "
372                          "and %d\n", 60, INT_MAX);
373                 return -1;
374         }
375         if (config->skc_session_keylen % 8 != 0 ||
376             config->skc_session_keylen > SK_SESSION_MAX_KEYLEN_BYTES * 8) {
377                 printerr(0, "Invalid session key length must be a multiple of 8"
378                          " and less then %d bits\n",
379                          SK_SESSION_MAX_KEYLEN_BYTES * 8);
380                 return -1;
381         }
382         if (config->skc_shared_keylen % 8 != 0 ||
383             config->skc_shared_keylen > SK_MAX_KEYLEN_BYTES * 8){
384                 printerr(0, "Invalid shared key max length must be a multiple "
385                          "of 8 and less then %d bits\n",
386                          SK_MAX_KEYLEN_BYTES * 8);
387                 return -1;
388         }
389
390         /* Check for terminating nulls on strings */
391         for (i = 0; i < sizeof(config->skc_fsname) &&
392              config->skc_fsname[i] != '\0';  i++)
393                 ; /* empty loop */
394         if (i == sizeof(config->skc_fsname)) {
395                 printerr(0, "File system name not null terminated\n");
396                 return -1;
397         }
398
399         for (i = 0; i < sizeof(config->skc_nodemap) &&
400              config->skc_nodemap[i] != '\0';  i++)
401                 ; /* empty loop */
402         if (i == sizeof(config->skc_nodemap)) {
403                 printerr(0, "Nodemap name not null terminated\n");
404                 return -1;
405         }
406
407         return 0;
408 }
409
410 /**
411  * Hashes \a string and places the hash in \a hash
412  * at \a hash
413  *
414  * \param[in]           string          Null terminated string to hash
415  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
416  * \param[in,out]       hash            gss_buffer_desc to hold the result
417  *
418  * \return      -1      failure
419  * \return      0       success
420  */
421 static int sk_hash_string(const char *string, const EVP_MD *hash_alg,
422                           gss_buffer_desc *hash)
423 {
424         EVP_MD_CTX *ctx = EVP_MD_CTX_create();
425         size_t len = strlen(string);
426         unsigned int hashlen;
427
428         if (!hash->value || hash->length < EVP_MD_size(hash_alg))
429                 goto out_err;
430         if (!EVP_DigestInit_ex(ctx, hash_alg, NULL))
431                 goto out_err;
432         if (!EVP_DigestUpdate(ctx, string, len))
433                 goto out_err;
434         if (!EVP_DigestFinal_ex(ctx, hash->value, &hashlen))
435                 goto out_err;
436
437         EVP_MD_CTX_destroy(ctx);
438         hash->length = hashlen;
439         return 0;
440
441 out_err:
442         EVP_MD_CTX_destroy(ctx);
443         return -1;
444 }
445
446 /**
447  * Hashes \a string and verifies the resulting hash matches the value
448  * in \a current_hash
449  *
450  * \param[in]           string          Null terminated string to hash
451  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
452  * \param[in,out]       current_hash    gss_buffer_desc to compare to
453  *
454  * \return      gss error       failure
455  * \return      GSS_S_COMPLETE  success
456  */
457 uint32_t sk_verify_hash(const char *string, const EVP_MD *hash_alg,
458                         const gss_buffer_desc *current_hash)
459 {
460         gss_buffer_desc hash;
461         unsigned char hashbuf[EVP_MAX_MD_SIZE];
462
463         hash.value = hashbuf;
464         hash.length = sizeof(hashbuf);
465
466         if (sk_hash_string(string, hash_alg, &hash))
467                 return GSS_S_FAILURE;
468         if (current_hash->length != hash.length)
469                 return GSS_S_DEFECTIVE_TOKEN;
470         if (memcmp(current_hash->value, hash.value, hash.length))
471                 return GSS_S_BAD_SIG;
472
473         return GSS_S_COMPLETE;
474 }
475
476 static inline int sk_config_has_mgsnid(struct sk_keyfile_config *config,
477                                        const char *mgsnid)
478 {
479         lnet_nid_t nid;
480         int i;
481
482         nid = libcfs_str2nid(mgsnid);
483         if (nid == LNET_NID_ANY)
484                 return 0;
485
486         for (i = 0; i < MAX_MGSNIDS; i++)
487                 if  (config->skc_mgsnids[i] == nid)
488                         return 1;
489         return 0;
490 }
491
492 /**
493  * Create an sk_cred structure populated with initial configuration info and the
494  * key.  \a tgt and \a nodemap are used in determining the expected key
495  * description so the key can be found by searching the keyring.
496  * This is done because there is no easy way to pass keys from the mount command
497  * all the way to the request_key call.  In addition any keys can be dynamically
498  * added to the keyrings and still found.  The keyring that needs to be used
499  * must be the session keyring.
500  *
501  * \param[in]   tgt             Target file system
502  * \param[in]   nodemap         Cluster name for the key.  This correlates to
503  *                              the nodemap name and is used by the server side.
504  *                              For the client this will be NULL.
505  * \param[in]   flags           Flags for the credentials
506  *
507  * \return      sk_cred Allocated struct sk_cred on success
508  * \return      NULL    failure
509  */
510 struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
511                                const uint32_t flags)
512 {
513         struct sk_keyfile_config *config;
514         struct sk_kernel_ctx *kctx;
515         struct sk_cred *skc = NULL;
516         char description[SK_DESCRIPTION_SIZE + 1];
517         char fsname[MTI_NAME_MAXLEN + 1];
518         const char *mgsnid = NULL;
519         char *ptr;
520         long sk_key;
521         int keylen;
522         int len;
523         int rc;
524
525         printerr(2, "Creating credentials for target: %s with nodemap: %s\n",
526                  tgt, nodemap);
527
528         memset(description, 0, sizeof(description));
529         memset(fsname, 0, sizeof(fsname));
530
531         /* extract the file system name from target */
532         ptr = index(tgt, '-');
533         if (!ptr) {
534                 len = strlen(tgt);
535
536                 /* This must be an MGC target */
537                 if (strncmp(tgt, "MGC", 3) || len <= 3) {
538                         printerr(0, "Invalid target name\n");
539                         return NULL;
540                 }
541                 mgsnid = tgt + 3;
542         } else {
543                 len = ptr - tgt;
544         }
545
546         if (len > MTI_NAME_MAXLEN) {
547                 printerr(0, "Invalid target name\n");
548                 return NULL;
549         }
550         memcpy(fsname, tgt, len);
551
552         if (nodemap) {
553                 if (mgsnid)
554                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
555                                       "lustre:MGS:%s", nodemap);
556                 else
557                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
558                                       "lustre:%s:%s", fsname, nodemap);
559         } else {
560                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s",
561                               fsname);
562         }
563
564         if (rc >= SK_DESCRIPTION_SIZE) {
565                 printerr(0, "Invalid key description\n");
566                 return NULL;
567         }
568
569         /* It may be a good idea to move Lustre keys to the gss_keyring
570          * (lgssc) type so that they expire when Lustre modules are removed.
571          * Unfortunately it can't be done at mount time because the mount
572          * syscall could trigger the Lustre modules to load and until that
573          * point we don't have a lgssc key type.
574          *
575          * TODO: Query the community for a consensus here  */
576         printerr(2, "Searching for key with description: %s\n", description);
577         sk_key = keyctl_search(KEY_SPEC_USER_KEYRING, "user",
578                                description, 0);
579         if (sk_key == -1) {
580                 printerr(1, "No key found for %s\n", description);
581                 return NULL;
582         }
583
584         keylen = keyctl_read_alloc(sk_key, (void **)&config);
585         if (keylen == -1) {
586                 printerr(0, "keyctl_read() failed for key %ld: %s\n", sk_key,
587                          strerror(errno));
588                 return NULL;
589         } else if (keylen != sizeof(*config)) {
590                 printerr(0, "Unexpected key size: %d returned for key %ld, "
591                          "expected %zu bytes\n",
592                          keylen, sk_key, sizeof(*config));
593                 goto out_err;
594         }
595
596         sk_config_disk_to_cpu(config);
597
598         if (sk_validate_config(config)) {
599                 printerr(0, "Invalid key configuration for key: %ld\n", sk_key);
600                 goto out_err;
601         }
602
603         if (mgsnid && !sk_config_has_mgsnid(config, mgsnid)) {
604                 printerr(0, "Target name does not match key's MGS NIDs\n");
605                 goto out_err;
606         }
607
608         if (!mgsnid && strcmp(fsname, config->skc_fsname)) {
609                 printerr(0, "Target name does not match key's file system\n");
610                 goto out_err;
611         }
612
613         skc = malloc(sizeof(*skc));
614         if (!skc) {
615                 printerr(0, "Failed to allocate memory for sk_cred\n");
616                 goto out_err;
617         }
618
619         /* this initializes all gss_buffer_desc to empty as well */
620         memset(skc, 0, sizeof(*skc));
621
622         skc->sc_flags = flags;
623         skc->sc_tgt.length = strlen(tgt) + 1;
624         skc->sc_tgt.value = malloc(skc->sc_tgt.length);
625         if (!skc->sc_tgt.value) {
626                 printerr(0, "Failed to allocate memory for target\n");
627                 goto out_err;
628         }
629         memcpy(skc->sc_tgt.value, tgt, skc->sc_tgt.length);
630
631         skc->sc_nodemap_hash.length = EVP_MD_size(EVP_sha256());
632         skc->sc_nodemap_hash.value = malloc(skc->sc_nodemap_hash.length);
633         if (!skc->sc_nodemap_hash.value) {
634                 printerr(0, "Failed to allocate memory for nodemap hash\n");
635                 goto out_err;
636         }
637
638         if (sk_hash_string(config->skc_nodemap, EVP_sha256(),
639                            &skc->sc_nodemap_hash)) {
640                 printerr(0, "Failed to generate hash for nodemap name\n");
641                 goto out_err;
642         }
643
644         kctx = &skc->sc_kctx;
645         kctx->skc_version = config->skc_version;
646         kctx->skc_hmac_alg = config->skc_hmac_alg;
647         kctx->skc_crypt_alg = config->skc_crypt_alg;
648         kctx->skc_expire = config->skc_expire;
649
650         /* key payload format is in bits, convert to bytes */
651         skc->sc_session_keylen = config->skc_session_keylen / 8;
652         kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
653         kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
654         if (!kctx->skc_shared_key.value) {
655                 printerr(0, "Failed to allocate memory for shared key\n");
656                 goto out_err;
657         }
658         memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
659                kctx->skc_shared_key.length);
660
661         free(config);
662
663         return skc;
664
665 out_err:
666         if (skc)
667                 sk_free_cred(skc);
668
669         free(config);
670         return NULL;
671 }
672
673 /**
674  * Generates a public key and computes the private key for the DH key exchange.
675  * The parameters must be populated with the p and g from the peer.
676  *
677  * \param[in,out]       skc     Shared key credentials structure to populate
678  *                              with DH parameters
679  *
680  * \retval      GSS_S_COMPLETE  success
681  * \retval      GSS_S_FAILURE   failure
682  */
683 static uint32_t sk_gen_responder_params(struct sk_cred *skc)
684 {
685         int rc;
686
687         /* No keys to generate without privacy mode */
688         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
689                 return GSS_S_COMPLETE;
690
691         skc->sc_params = DH_new();
692         if (!skc->sc_params) {
693                 printerr(0, "Failed to allocate DH\n");
694                 return GSS_S_FAILURE;
695         }
696
697         /* responder should already have sc_p populated */
698         skc->sc_params->p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
699         if (!skc->sc_params->p) {
700                 printerr(0, "Failed to convert binary to BIGNUM\n");
701                 return GSS_S_FAILURE;
702         }
703
704         /* and we use a static generator for shared key */
705         skc->sc_params->g = BN_new();
706         if (!skc->sc_params->g) {
707                 printerr(0, "Failed to allocate new BIGNUM\n");
708                 return GSS_S_FAILURE;
709         }
710         if (BN_set_word(skc->sc_params->g, SK_GENERATOR) != 1) {
711                 printerr(0, "Failed to set g value for DH params\n");
712                 return GSS_S_FAILURE;
713         }
714
715         /* verify that we have a safe prime and valid generator */
716         if (DH_check(skc->sc_params, &rc) != 1) {
717                 printerr(0, "DH_check() failed: %d\n", rc);
718                 return GSS_S_FAILURE;
719         } else if (rc) {
720                 printerr(0, "DH_check() returned error codes: 0x%x\n", rc);
721                 return GSS_S_FAILURE;
722         }
723
724         if (DH_generate_key(skc->sc_params) != 1) {
725                 printerr(0, "Failed to generate public DH key: %s\n",
726                          ERR_error_string(ERR_get_error(), NULL));
727                 return GSS_S_FAILURE;
728         }
729
730         skc->sc_pub_key.length = BN_num_bytes(skc->sc_params->pub_key);
731         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
732         if (!skc->sc_pub_key.value) {
733                 printerr(0, "Failed to allocate memory for public key\n");
734                 return GSS_S_FAILURE;
735         }
736
737         BN_bn2bin(skc->sc_params->pub_key, skc->sc_pub_key.value);
738
739         return GSS_S_COMPLETE;
740 }
741
742 static void sk_free_parameters(struct sk_cred *skc)
743 {
744         if (skc->sc_params)
745                 DH_free(skc->sc_params);
746         if (skc->sc_p.value)
747                 free(skc->sc_p.value);
748         if (skc->sc_pub_key.value)
749                 free(skc->sc_pub_key.value);
750
751         skc->sc_p.value = NULL;
752         skc->sc_p.length = 0;
753         skc->sc_pub_key.value = NULL;
754         skc->sc_pub_key.length = 0;
755 }
756
757 /**
758  * Generates shared key Diffie Hellman parameters used for the DH key exchange
759  * between host and peer if privacy mode is enabled
760  *
761  * \param[in,out]       skc     Shared key credentials structure to populate
762  *                              with DH parameters
763  *
764  * \retval      GSS_S_COMPLETE  success
765  * \retval      GSS_S_FAILURE   failure
766  */
767 static uint32_t sk_gen_initiator_params(struct sk_cred *skc)
768 {
769         gss_buffer_desc *iv = &skc->sc_kctx.skc_iv;
770         int rc;
771
772         /* The credential could be used so free existing parameters */
773         sk_free_parameters(skc);
774
775         /* Pseudo random should be sufficient here because the IV will be used
776          * with a key that is used only once.  This also should ensure we have
777          * unqiue tokens that are sent to the remote server which is important
778          * because the token is hashed for the sunrpc cache lookups and a
779          * failure there would cause connection attempts to fail indefinitely
780          * due to the large timeout value on the server side sunrpc cache
781          * (INT_MAX) */
782         iv->length = SK_IV_SIZE;
783         iv->value = malloc(iv->length);
784         if (!iv->value) {
785                 printerr(0, "Failed to allocate memory for IV\n");
786                 return GSS_S_FAILURE;
787         }
788         memset(iv->value, 0, iv->length);
789         if (RAND_bytes(iv->value, iv->length) != 1) {
790                 printerr(0, "Failed to get data for IV\n");
791                 return GSS_S_FAILURE;
792         }
793
794         /* Only privacy mode needs the rest of the parameter generation
795          * but we use IV in other modes as well so tokens should be
796          * unique */
797         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
798                 return GSS_S_COMPLETE;
799
800         skc->sc_params = DH_generate_parameters(skc->sc_session_keylen * 8,
801                                                 SK_GENERATOR, NULL, NULL);
802         if (skc->sc_params == NULL) {
803                 printerr(0, "Failed to generate diffie-hellman parameters: %s",
804                          ERR_error_string(ERR_get_error(), NULL));
805                 return GSS_S_FAILURE;
806         }
807
808         if (DH_check(skc->sc_params, &rc) != 1) {
809                 printerr(0, "DH_check() failed: %d\n", rc);
810                 return GSS_S_FAILURE;
811         } else if (rc) {
812                 printerr(0, "DH_check() returned error codes: 0x%x\n", rc);
813                 return GSS_S_FAILURE;
814         }
815
816         if (DH_generate_key(skc->sc_params) != 1) {
817                 printerr(0, "Failed to generate public DH key: %s\n",
818                          ERR_error_string(ERR_get_error(), NULL));
819                 return GSS_S_FAILURE;
820         }
821
822         skc->sc_p.length = BN_num_bytes(skc->sc_params->p);
823         skc->sc_pub_key.length = BN_num_bytes(skc->sc_params->pub_key);
824         skc->sc_p.value = malloc(skc->sc_p.length);
825         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
826         if (!skc->sc_p.value || !skc->sc_pub_key.value) {
827                 printerr(0, "Failed to allocate memory for params\n");
828                 return GSS_S_FAILURE;
829         }
830
831         BN_bn2bin(skc->sc_params->pub_key, skc->sc_pub_key.value);
832         BN_bn2bin(skc->sc_params->p, skc->sc_p.value);
833
834         return GSS_S_COMPLETE;
835 }
836
837 /**
838  * Generates or populates the DH parameters depending on whether the system is
839  * the initiator or responder for the connection
840  *
841  * \param[in,out]       skc             Shared key credentials structure to
842  *                                      populate with DH parameters
843  * \param[in]           initiator       Boolean whether to initiate parameters
844  *
845  * \retval      GSS_S_COMPLETE  success
846  * \retval      GSS_S_FAILURE   failure
847  */
848 uint32_t sk_gen_params(struct sk_cred *skc, const bool initiator)
849 {
850         if (initiator)
851                 return sk_gen_initiator_params(skc);
852
853         return sk_gen_responder_params(skc);
854 }
855
856 /**
857  * Convert SK hash algorithm into openssl message digest
858  *
859  * \param[in,out]       alg             SK hash algorithm
860  *
861  * \retval              EVP_MD
862  */
863 static inline const EVP_MD *sk_hash_to_evp_md(enum sk_hmac_alg alg)
864 {
865         switch (alg) {
866         case SK_HMAC_SHA256:
867                 return EVP_sha256();
868         case SK_HMAC_SHA512:
869                 return EVP_sha512();
870         default:
871                 return EVP_md_null();
872         }
873 }
874
875 /**
876  * Signs (via HMAC) the parameters used only in the key initialization protocol.
877  *
878  * \param[in]           key             Key to use for HMAC
879  * \param[in]           bufs            Array of gss_buffer_desc to generate
880  *                                      HMAC for
881  * \param[in]           numbufs         Number of buffers in array
882  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
883  * \param[in,out]       hmac            HMAC of buffers is allocated and placed
884  *                                      in this gss_buffer_desc.  Caller must
885  *                                      free this.
886  *
887  * \retval      0       success
888  * \retval      -1      failure
889  */
890 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
891                  const EVP_MD *hash_alg, gss_buffer_desc *hmac)
892 {
893         HMAC_CTX hctx;
894         unsigned int hashlen = EVP_MD_size(hash_alg);
895         int i;
896         int rc = -1;
897
898         if (hash_alg == EVP_md_null()) {
899                 printerr(0, "Invalid hash algorithm\n");
900                 return -1;
901         }
902
903         HMAC_CTX_init(&hctx);
904
905         hmac->length = hashlen;
906         hmac->value = malloc(hashlen);
907         if (!hmac->value) {
908                 printerr(0, "Failed to allocate memory for HMAC\n");
909                 goto out;
910         }
911
912 #ifdef HAVE_VOID_OPENSSL_HMAC_FUNCS
913         HMAC_Init_ex(&hctx, key->value, key->length, hash_alg, NULL);
914         for (i = 0; i < numbufs; i++)
915                 HMAC_Update(&hctx, bufs[i].value, bufs[i].length);
916         HMAC_Final(&hctx, hmac->value, &hashlen);
917 #else
918         if (HMAC_Init_ex(&hctx, key->value, key->length, hash_alg, NULL) != 1) {
919                 printerr(0, "Failed to init HMAC\n");
920                 goto out;
921         }
922
923         for (i = 0; i < numbufs; i++) {
924                 if (HMAC_Update(&hctx, bufs[i].value, bufs[i].length) != 1) {
925                         printerr(0, "Failed to update HMAC\n");
926                         goto out;
927                 }
928         }
929
930         /* The result gets populated in hmac */
931         if (HMAC_Final(&hctx, hmac->value, &hashlen) != 1) {
932                 printerr(0, "Failed to finalize HMAC\n");
933                 goto out;
934         }
935 #endif
936
937         if (hmac->length != hashlen) {
938                 printerr(0, "HMAC size does not match expected\n");
939                 goto out;
940         }
941
942         rc = 0;
943 out:
944         HMAC_CTX_cleanup(&hctx);
945         return rc;
946 }
947
948 /**
949  * Generates an HMAC for gss_buffer_desc array in \a bufs of \a numbufs
950  * and verifies against \a hmac.
951  *
952  * \param[in]   skc             Shared key credentials
953  * \param[in]   bufs            Array of gss_buffer_desc to generate HMAC for
954  * \param[in]   numbufs         Number of buffers in array
955  * \param[in]   hash_alg        OpenSSL EVP_MD to use for hash
956  * \param[in]   hmac            HMAC to verify against
957  *
958  * \retval      GSS_S_COMPLETE  success (match)
959  * \retval      gss error       failure
960  */
961 uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
962                         const int numbufs, const EVP_MD *hash_alg,
963                         gss_buffer_desc *hmac)
964 {
965         gss_buffer_desc bufs_hmac;
966         int rc;
967
968         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, numbufs, hash_alg,
969                          &bufs_hmac)) {
970                 printerr(0, "Failed to sign buffers to verify HMAC\n");
971                 if (bufs_hmac.value)
972                         free(bufs_hmac.value);
973                 return GSS_S_FAILURE;
974         }
975
976         if (hmac->length != bufs_hmac.length) {
977                 printerr(0, "Invalid HMAC size\n");
978                 free(bufs_hmac.value);
979                 return GSS_S_BAD_SIG;
980         }
981
982         rc = memcmp(hmac->value, bufs_hmac.value, bufs_hmac.length);
983         free(bufs_hmac.value);
984
985         if (rc)
986                 return GSS_S_BAD_SIG;
987
988         return GSS_S_COMPLETE;
989 }
990
991 /**
992  * Cleanup an sk_cred freeing any resources
993  *
994  * \param[in,out]       skc     Shared key credentials to free
995  */
996 void sk_free_cred(struct sk_cred *skc)
997 {
998         if (skc->sc_p.value)
999                 free(skc->sc_p.value);
1000         if (skc->sc_pub_key.value)
1001                 free(skc->sc_pub_key.value);
1002         if (skc->sc_tgt.value)
1003                 free(skc->sc_tgt.value);
1004         if (skc->sc_nodemap_hash.value)
1005                 free(skc->sc_nodemap_hash.value);
1006         if (skc->sc_hmac.value)
1007                 free(skc->sc_hmac.value);
1008
1009         /* Overwrite keys and IV before freeing */
1010         if (skc->sc_dh_shared_key.value) {
1011                 memset(skc->sc_dh_shared_key.value, 0,
1012                        skc->sc_dh_shared_key.length);
1013                 free(skc->sc_dh_shared_key.value);
1014         }
1015         if (skc->sc_kctx.skc_shared_key.value) {
1016                 memset(skc->sc_kctx.skc_shared_key.value, 0,
1017                        skc->sc_kctx.skc_shared_key.length);
1018                 free(skc->sc_kctx.skc_shared_key.value);
1019         }
1020         if (skc->sc_kctx.skc_iv.value) {
1021                 memset(skc->sc_kctx.skc_iv.value, 0,
1022                        skc->sc_kctx.skc_iv.length);
1023                 free(skc->sc_kctx.skc_iv.value);
1024         }
1025         if (skc->sc_kctx.skc_session_key.value) {
1026                 memset(skc->sc_kctx.skc_session_key.value, 0,
1027                        skc->sc_kctx.skc_session_key.length);
1028                 free(skc->sc_kctx.skc_session_key.value);
1029         }
1030
1031         if (skc->sc_params)
1032                 DH_free(skc->sc_params);
1033
1034         free(skc);
1035 }
1036
1037 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
1038  * based on the recommendations in NIST Special Publication SP 800-56B Rev 1
1039  * (Sep 2014) Section 5.5.1
1040  *
1041  * \param[in,out]       skc             Shared key credentials structure with
1042  *
1043  * \return      -1              failure
1044  * \return      0               success
1045  */
1046 int sk_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
1047            gss_buffer_desc *key_binding_input)
1048 {
1049         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1050         gss_buffer_desc *session_key = &kctx->skc_session_key;
1051         gss_buffer_desc bufs[4];
1052         gss_buffer_desc tmp_hash;
1053         char *skp;
1054         size_t remain;
1055         size_t bytes;
1056         uint32_t counter;
1057         int i;
1058         int rc = -1;
1059
1060         /* No keys computed unless privacy mode is in use */
1061         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1062                 return 0;
1063
1064         session_key->length = sk_crypt_types[kctx->skc_crypt_alg].sct_bytes;
1065         session_key->value = malloc(session_key->length);
1066         if (!session_key->value) {
1067                 printerr(0, "Failed to allocate memory for session key\n");
1068                 return rc;
1069         }
1070
1071         /* Use the HMAC algorithm provided by in the shared key file to derive
1072          * a session key.  eg: HMAC(key, msg)
1073          * key: the shared key provided in the shared key file
1074          * msg is the bytes in the following order:
1075          * 1. big_endian(counter)
1076          * 2. DH shared key
1077          * 3. Clients NIDs
1078          * 4. key_binding_input */
1079         bufs[0].value = &counter;
1080         bufs[0].length = sizeof(counter);
1081         bufs[1] = skc->sc_dh_shared_key;
1082         bufs[2].value = &client_nid;
1083         bufs[2].length = sizeof(client_nid);
1084         bufs[3] = *key_binding_input;
1085
1086         remain = session_key->length;
1087         skp = session_key->value;
1088         i = 0;
1089         while (remain > 0) {
1090                 counter = be32toh(i++);
1091                 rc = sk_sign_bufs(&kctx->skc_shared_key, bufs, 4,
1092                              sk_hash_to_evp_md(kctx->skc_hmac_alg), &tmp_hash);
1093                 if (rc) {
1094                         free(tmp_hash.value);
1095                         return rc;
1096                 }
1097
1098                 LASSERT(sk_hmac_types[kctx->skc_hmac_alg].sht_bytes ==
1099                         tmp_hash.length);
1100
1101                 bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
1102                 memcpy(skp, tmp_hash.value, bytes);
1103                 free(tmp_hash.value);
1104                 remain -= bytes;
1105                 skp += bytes;
1106         }
1107
1108         return 0;
1109 }
1110
1111 /**
1112  * Computes a session key based on the DH parameters from the host and its peer
1113  *
1114  * \param[in,out]       skc             Shared key credentials structure with
1115  *                                      the session key populated with the
1116  *                                      compute key
1117  * \param[in]           pub_key         Public key returned from peer in
1118  *                                      gss_buffer_desc
1119  * \return      gss error               failure
1120  * \return      GSS_S_COMPLETE          success
1121  */
1122 uint32_t sk_compute_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
1123 {
1124         gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
1125         BIGNUM *remote_pub_key;
1126         int status;
1127         uint32_t rc = GSS_S_FAILURE;
1128
1129         /* No keys computed unless privacy mode is in use */
1130         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1131                 return GSS_S_COMPLETE;
1132
1133         remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
1134         if (!remote_pub_key) {
1135                 printerr(0, "Failed to convert binary to BIGNUM\n");
1136                 return rc;
1137         }
1138
1139         dh_shared->length = DH_size(skc->sc_params);
1140         dh_shared->value = malloc(dh_shared->length);
1141         if (!dh_shared->value) {
1142                 printerr(0, "Failed to allocate memory for computed shared "
1143                          "secret key\n");
1144                 goto out_err;
1145         }
1146
1147         /* This compute the shared key from the DHKE */
1148         status = DH_compute_key(dh_shared->value, remote_pub_key,
1149                                 skc->sc_params);
1150         if (status == -1) {
1151                 printerr(0, "DH_compute_key() failed: %s\n",
1152                          ERR_error_string(ERR_get_error(), NULL));
1153                 goto out_err;
1154         } else if (status < dh_shared->length) {
1155                 printerr(0, "DH_compute_key() returned a short key of %d "
1156                          "bytes, expected: %zu\n", status, dh_shared->length);
1157                 rc = GSS_S_DEFECTIVE_TOKEN;
1158                 goto out_err;
1159         }
1160
1161         rc = GSS_S_COMPLETE;
1162
1163 out_err:
1164         BN_free(remote_pub_key);
1165         return rc;
1166 }
1167
1168 /**
1169  * Creates a serialized buffer for the kernel in the order of struct
1170  * sk_kernel_ctx.
1171  *
1172  * \param[in,out]       skc             Shared key credentials structure
1173  * \param[in,out]       ctx_token       Serialized buffer for kernel.
1174  *                                      Caller must free this buffer.
1175  *
1176  * \return      0       success
1177  * \return      -1      failure
1178  */
1179 int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
1180 {
1181         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1182         char *p, *end;
1183         size_t bufsize;
1184
1185         bufsize = sizeof(*kctx) + kctx->skc_session_key.length +
1186                   kctx->skc_iv.length + kctx->skc_shared_key.length;
1187
1188         ctx_token->value = malloc(bufsize);
1189         if (!ctx_token->value)
1190                 return -1;
1191         ctx_token->length = bufsize;
1192
1193         p = ctx_token->value;
1194         end = p + ctx_token->length;
1195
1196         if (WRITE_BYTES(&p, end, kctx->skc_version))
1197                 return -1;
1198         if (WRITE_BYTES(&p, end, kctx->skc_hmac_alg))
1199                 return -1;
1200         if (WRITE_BYTES(&p, end, kctx->skc_crypt_alg))
1201                 return -1;
1202         if (WRITE_BYTES(&p, end, kctx->skc_expire))
1203                 return -1;
1204         if (write_buffer(&p, end, &kctx->skc_shared_key))
1205                 return -1;
1206         if (write_buffer(&p, end, &kctx->skc_iv))
1207                 return -1;
1208         if (write_buffer(&p, end, &kctx->skc_session_key))
1209                 return -1;
1210
1211         printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
1212
1213         return 0;
1214 }
1215
1216 /**
1217  * Decodes a netstring \a ns into array of gss_buffer_descs at \a bufs
1218  * up to \a numbufs.  Memory is allocated for each value and length
1219  * will be populated with the length
1220  *
1221  * \param[in,out]       bufs    Array of gss_buffer_descs
1222  * \param[in,out]       numbufs number of gss_buffer_desc in array
1223  * \param[in]           ns      netstring to decode
1224  *
1225  * \return      buffers populated       success
1226  * \return      -1                      failure
1227  */
1228 int sk_decode_netstring(gss_buffer_desc *bufs, int numbufs, gss_buffer_desc *ns)
1229 {
1230         char *ptr = ns->value;
1231         size_t remain = ns->length;
1232         unsigned int size;
1233         int digits;
1234         int sep;
1235         int rc;
1236         int i;
1237
1238         for (i = 0; i < numbufs; i++) {
1239                 /* read the size of first buffer */
1240                 rc = sscanf(ptr, "%9u", &size);
1241                 if (rc < 1)
1242                         goto out_err;
1243                 digits = (size) ? ceil(log10(size + 1)) : 1;
1244
1245                 /* sep of current string */
1246                 sep = size + digits + 2;
1247
1248                 /* check to make sure it's valid */
1249                 if (remain < sep || ptr[digits] != ':' ||
1250                     ptr[sep - 1] != ',')
1251                         goto out_err;
1252
1253                 bufs[i].length = size;
1254                 if (size == 0) {
1255                         bufs[i].value = NULL;
1256                 } else {
1257                         bufs[i].value = malloc(size);
1258                         if (!bufs[i].value)
1259                                 goto out_err;
1260                         memcpy(bufs[i].value, &ptr[digits + 1], size);
1261                 }
1262
1263                 remain -= sep;
1264                 ptr += sep;
1265         }
1266
1267         printerr(2, "Decoded netstring of %zu bytes\n", ns->length);
1268         return i;
1269
1270 out_err:
1271         while (i-- > 0) {
1272                 if (bufs[i].value)
1273                         free(bufs[i].value);
1274                 bufs[i].length = 0;
1275         }
1276         return -1;
1277 }
1278
1279 /**
1280  * Creates a netstring in a gss_buffer_desc that consists of all
1281  * the gss_buffer_desc found in \a bufs.  The netstring should be treated
1282  * as binary as it can contain null characters.
1283  *
1284  * \param[in]           bufs            Array of gss_buffer_desc to use as input
1285  * \param[in]           numbufs         Number of buffers in array
1286  * \param[in,out]       ns              Destination gss_buffer_desc to hold
1287  *                                      netstring
1288  *
1289  * \return      -1      failure
1290  * \return      0       success
1291  */
1292 int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
1293                         gss_buffer_desc *ns)
1294 {
1295         unsigned char *ptr;
1296         int size = 0;
1297         int rc;
1298         int i;
1299
1300         /* size of string in decimal, string size, colon, and comma */
1301         for (i = 0; i < numbufs; i++) {
1302
1303                 if (bufs[i].length == 0)
1304                         size += 3;
1305                 else
1306                         size += ceil(log10(bufs[i].length + 1)) +
1307                                 bufs[i].length + 2;
1308         }
1309
1310         ns->length = size;
1311         ns->value = malloc(ns->length);
1312         if (!ns->value) {
1313                 ns->length = 0;
1314                 return -1;
1315         }
1316
1317         ptr = ns->value;
1318         for (i = 0; i < numbufs; i++) {
1319                 /* size */
1320                 rc = snprintf((char *) ptr, size, "%zu:", bufs[i].length);
1321                 ptr += rc;
1322
1323                 /* contents */
1324                 memcpy(ptr, bufs[i].value, bufs[i].length);
1325                 ptr += bufs[i].length;
1326
1327                 /* delimeter */
1328                 *ptr++ = ',';
1329
1330                 size -= bufs[i].length + rc + 1;
1331
1332                 /* should not happen */
1333                 if (size < 0)
1334                         abort();
1335         }
1336
1337         printerr(2, "Encoded netstring of %zu bytes\n", ns->length);
1338         return 0;
1339 }