Whamcloud - gitweb
LU-12861 libcfs: provide an scnprintf and start using it
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Copyright (c) 2016, 2017, Intel Corporation.
26  *
27  * Author: Jeremy Filizetti <jfilizet@iu.edu>
28  */
29
30 #include <fcntl.h>
31 #include <limits.h>
32 #include <math.h>
33 #include <string.h>
34 #include <stdbool.h>
35 #include <unistd.h>
36 #include <openssl/dh.h>
37 #include <openssl/engine.h>
38 #include <openssl/err.h>
39 #include <openssl/hmac.h>
40 #include <sys/types.h>
41 #include <sys/stat.h>
42 #include <libcfs/util/string.h>
43
44 #include "sk_utils.h"
45 #include "write_bytes.h"
46
47 #define SK_PBKDF2_ITERATIONS 10000
48
49 #ifdef _NEW_BUILD_
50 # include "lgss_utils.h"
51 #else
52 # include "gss_util.h"
53 # include "gss_oids.h"
54 # include "err_util.h"
55 #endif
56
57 #ifdef _ERR_UTIL_H_
58 /**
59  * Initializes logging
60  * \param[in]   program         Program name to output
61  * \param[in]   verbose         Verbose flag
62  * \param[in]   fg              Whether or not to run in foreground
63  *
64  */
65 void sk_init_logging(char *program, int verbose, int fg)
66 {
67         initerr(program, verbose, fg);
68 }
69 #endif
70
71 /**
72  * Loads the key from \a filename and returns the struct sk_keyfile_config.
73  * It should be freed by the caller.
74  *
75  * \param[in]   filename                Disk or key payload data
76  *
77  * \return      sk_keyfile_config       sucess
78  * \return      NULL                    failure
79  */
80 struct sk_keyfile_config *sk_read_file(char *filename)
81 {
82         struct sk_keyfile_config *config;
83         char *ptr;
84         size_t rc;
85         size_t remain;
86         int fd;
87
88         config = malloc(sizeof(*config));
89         if (!config) {
90                 printerr(0, "Failed to allocate memory for config\n");
91                 return NULL;
92         }
93
94         /* allow standard input override */
95         if (strcmp(filename, "-") == 0)
96                 fd = STDIN_FILENO;
97         else
98                 fd = open(filename, O_RDONLY);
99
100         if (fd == -1) {
101                 printerr(0, "Error opening key file '%s': %s\n", filename,
102                          strerror(errno));
103                 goto out_free;
104         } else if (fd != STDIN_FILENO) {
105                 struct stat st;
106
107                 rc = fstat(fd, &st);
108                 if (rc == 0 && (st.st_mode & ~(S_IFREG | 0600)))
109                         fprintf(stderr, "warning: "
110                                 "secret key '%s' has insecure file mode %#o\n",
111                                 filename, st.st_mode);
112         }
113
114         ptr = (char *)config;
115         remain = sizeof(*config);
116         while (remain > 0) {
117                 rc = read(fd, ptr, remain);
118                 if (rc == -1) {
119                         if (errno == EINTR)
120                                 continue;
121                         printerr(0, "read() failed on %s: %s\n", filename,
122                                  strerror(errno));
123                         goto out_close;
124                 } else if (rc == 0) {
125                         printerr(0, "File %s does not have a complete key\n",
126                                  filename);
127                         goto out_close;
128                 }
129                 ptr += rc;
130                 remain -= rc;
131         }
132
133         if (fd != STDIN_FILENO)
134                 close(fd);
135         sk_config_disk_to_cpu(config);
136         return config;
137
138 out_close:
139         close(fd);
140 out_free:
141         free(config);
142         return NULL;
143 }
144
145 /**
146  * Checks if a key matching \a description is found in the keyring for
147  * logging purposes and then attempts to load \a payload of \a psize into a key
148  * with \a description.
149  *
150  * \param[in]   payload         Key payload
151  * \param[in]   psize           Payload size
152  * \param[in]   description     Description used for key in keyring
153  *
154  * \return      0       sucess
155  * \return      -1      failure
156  */
157 static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
158                                 const char *description)
159 {
160         struct sk_keyfile_config payload;
161         key_serial_t key;
162
163         memcpy(&payload, skc, sizeof(*skc));
164
165         /* In the keyring use the disk layout so keyctl pipe can be used */
166         sk_config_cpu_to_disk(&payload);
167
168         /* Check to see if a key is already loaded matching description */
169         key = keyctl_search(KEY_SPEC_USER_KEYRING, "user", description, 0);
170         if (key != -1)
171                 printerr(2, "Key %d found in session keyring, replacing\n",
172                          key);
173
174         key = add_key("user", description, &payload, sizeof(payload),
175                       KEY_SPEC_USER_KEYRING);
176         if (key != -1)
177                 printerr(2, "Added key %d with description %s\n", key,
178                          description);
179         else
180                 printerr(0, "Failed to add key with %s\n", description);
181
182         return key;
183 }
184
185 /**
186  * Reads the key from \a path, verifies it and loads into the session keyring
187  * using a description determined by the the \a type.  Existing keys with the
188  * same description are replaced.
189  *
190  * \param[in]   path    Path to key file
191  * \param[in]   type    Type of key to load which determines the description
192  *
193  * \return      0       sucess
194  * \return      -1      failure
195  */
196 int sk_load_keyfile(char *path)
197 {
198         struct sk_keyfile_config *config;
199         char description[SK_DESCRIPTION_SIZE + 1];
200         struct stat buf;
201         int i;
202         int rc;
203         int rc2 = -1;
204
205         rc = stat(path, &buf);
206         if (rc == -1) {
207                 printerr(0, "stat() failed for file %s: %s\n", path,
208                          strerror(errno));
209                 return rc2;
210         }
211
212         config = sk_read_file(path);
213         if (!config)
214                 return rc2;
215
216         /* Similar to ssh, require adequate care of key files */
217         if (buf.st_mode & (S_IRGRP | S_IWGRP | S_IWOTH | S_IXOTH)) {
218                 printerr(0, "Shared key files must be read/writeable only by "
219                          "owner\n");
220                 return -1;
221         }
222
223         if (sk_validate_config(config))
224                 goto out;
225
226         /* The server side can have multiple key files per file system so
227          * the nodemap name is appended to the key description to uniquely
228          * identify it */
229         if (config->skc_type & SK_TYPE_MGS) {
230                 /* Any key can be an MGS key as long as we are told to use it */
231                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
232                               config->skc_nodemap);
233                 if (rc >= SK_DESCRIPTION_SIZE)
234                         goto out;
235                 if (sk_load_key(config, description) == -1)
236                         goto out;
237         }
238         if (config->skc_type & SK_TYPE_SERVER) {
239                 /* Server keys need to have the file system name in the key */
240                 if (!config->skc_fsname) {
241                         printerr(0, "Key configuration has no file system "
242                                  "attribute.  Can't load as server type\n");
243                         goto out;
244                 }
245                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s:%s",
246                               config->skc_fsname, config->skc_nodemap);
247                 if (rc >= SK_DESCRIPTION_SIZE)
248                         goto out;
249                 if (sk_load_key(config, description) == -1)
250                         goto out;
251         }
252         if (config->skc_type & SK_TYPE_CLIENT) {
253                 /* Load client file system key */
254                 if (config->skc_fsname) {
255                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
256                                       "lustre:%s", config->skc_fsname);
257                         if (rc >= SK_DESCRIPTION_SIZE)
258                                 goto out;
259                         if (sk_load_key(config, description) == -1)
260                                 goto out;
261                 }
262
263                 /* Load client MGC keys */
264                 for (i = 0; i < MAX_MGSNIDS; i++) {
265                         if (config->skc_mgsnids[i] == LNET_NID_ANY)
266                                 continue;
267                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
268                                       "lustre:MGC%s",
269                                       libcfs_nid2str(config->skc_mgsnids[i]));
270                         if (rc >= SK_DESCRIPTION_SIZE)
271                                 goto out;
272                         if (sk_load_key(config, description) == -1)
273                                 goto out;
274                 }
275         }
276
277         rc2 = 0;
278
279 out:
280         free(config);
281         return rc2;
282 }
283
284 /**
285  * Byte swaps config from cpu format to disk
286  *
287  * \param[in,out]       config          sk_keyfile_config to swap
288  */
289 void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
290 {
291         int i;
292
293         if (!config)
294                 return;
295
296         config->skc_version = htobe32(config->skc_version);
297         config->skc_hmac_alg = htobe16(config->skc_hmac_alg);
298         config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
299         config->skc_expire = htobe32(config->skc_expire);
300         config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
301         config->skc_prime_bits = htobe32(config->skc_prime_bits);
302
303         for (i = 0; i < MAX_MGSNIDS; i++)
304                 config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
305 }
306
307 /**
308  * Byte swaps config from disk format to cpu
309  *
310  * \param[in,out]       config          sk_keyfile_config to swap
311  */
312 void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
313 {
314         int i;
315
316         if (!config)
317                 return;
318
319         config->skc_version = be32toh(config->skc_version);
320         config->skc_hmac_alg = be16toh(config->skc_hmac_alg);
321         config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
322         config->skc_expire = be32toh(config->skc_expire);
323         config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
324         config->skc_prime_bits = be32toh(config->skc_prime_bits);
325
326         for (i = 0; i < MAX_MGSNIDS; i++)
327                 config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
328 }
329
330 /**
331  * Verifies the on key payload format is valid
332  *
333  * \param[in]   config          sk_keyfile_config
334  *
335  * \return      -1      failure
336  * \return      0       success
337  */
338 int sk_validate_config(const struct sk_keyfile_config *config)
339 {
340         int i;
341
342         if (!config) {
343                 printerr(0, "Null configuration passed\n");
344                 return -1;
345         }
346
347         if (config->skc_version != SK_CONF_VERSION) {
348                 printerr(0, "Invalid version\n");
349                 return -1;
350         }
351
352         if (config->skc_hmac_alg == SK_HMAC_INVALID) {
353                 printerr(0, "Invalid HMAC algorithm\n");
354                 return -1;
355         }
356
357         if (config->skc_crypt_alg == SK_CRYPT_INVALID) {
358                 printerr(0, "Invalid crypt algorithm\n");
359                 return -1;
360         }
361
362         if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
363                 /* Try to limit key expiration to some reasonable minimum and
364                  * also prevent values over INT_MAX because there appears
365                  * to be a type conversion issue */
366                 printerr(0, "Invalid expiration time should be between %d "
367                          "and %d\n", 60, INT_MAX);
368                 return -1;
369         }
370         if (config->skc_prime_bits % 8 != 0 ||
371             config->skc_prime_bits > SK_MAX_P_BYTES * 8) {
372                 printerr(0, "Invalid session key length must be a multiple of 8"
373                          " and less then %d bits\n",
374                          SK_MAX_P_BYTES * 8);
375                 return -1;
376         }
377         if (config->skc_shared_keylen % 8 != 0 ||
378             config->skc_shared_keylen > SK_MAX_KEYLEN_BYTES * 8){
379                 printerr(0, "Invalid shared key max length must be a multiple "
380                          "of 8 and less then %d bits\n",
381                          SK_MAX_KEYLEN_BYTES * 8);
382                 return -1;
383         }
384
385         /* Check for terminating nulls on strings */
386         for (i = 0; i < sizeof(config->skc_fsname) &&
387              config->skc_fsname[i] != '\0';  i++)
388                 ; /* empty loop */
389         if (i == sizeof(config->skc_fsname)) {
390                 printerr(0, "File system name not null terminated\n");
391                 return -1;
392         }
393
394         for (i = 0; i < sizeof(config->skc_nodemap) &&
395              config->skc_nodemap[i] != '\0';  i++)
396                 ; /* empty loop */
397         if (i == sizeof(config->skc_nodemap)) {
398                 printerr(0, "Nodemap name not null terminated\n");
399                 return -1;
400         }
401
402         if (config->skc_type == SK_TYPE_INVALID) {
403                 printerr(0, "Invalid key type\n");
404                 return -1;
405         }
406
407         return 0;
408 }
409
410 /**
411  * Hashes \a string and places the hash in \a hash
412  * at \a hash
413  *
414  * \param[in]           string          Null terminated string to hash
415  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
416  * \param[in,out]       hash            gss_buffer_desc to hold the result
417  *
418  * \return      -1      failure
419  * \return      0       success
420  */
421 static int sk_hash_string(const char *string, const EVP_MD *hash_alg,
422                           gss_buffer_desc *hash)
423 {
424         EVP_MD_CTX *ctx = EVP_MD_CTX_create();
425         size_t len = strlen(string);
426         unsigned int hashlen;
427
428         if (!hash->value || hash->length < EVP_MD_size(hash_alg))
429                 goto out_err;
430         if (!EVP_DigestInit_ex(ctx, hash_alg, NULL))
431                 goto out_err;
432         if (!EVP_DigestUpdate(ctx, string, len))
433                 goto out_err;
434         if (!EVP_DigestFinal_ex(ctx, hash->value, &hashlen))
435                 goto out_err;
436
437         EVP_MD_CTX_destroy(ctx);
438         hash->length = hashlen;
439         return 0;
440
441 out_err:
442         EVP_MD_CTX_destroy(ctx);
443         return -1;
444 }
445
446 /**
447  * Hashes \a string and verifies the resulting hash matches the value
448  * in \a current_hash
449  *
450  * \param[in]           string          Null terminated string to hash
451  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
452  * \param[in,out]       current_hash    gss_buffer_desc to compare to
453  *
454  * \return      gss error       failure
455  * \return      GSS_S_COMPLETE  success
456  */
457 uint32_t sk_verify_hash(const char *string, const EVP_MD *hash_alg,
458                         const gss_buffer_desc *current_hash)
459 {
460         gss_buffer_desc hash;
461         unsigned char hashbuf[EVP_MAX_MD_SIZE];
462
463         hash.value = hashbuf;
464         hash.length = sizeof(hashbuf);
465
466         if (sk_hash_string(string, hash_alg, &hash))
467                 return GSS_S_FAILURE;
468         if (current_hash->length != hash.length)
469                 return GSS_S_DEFECTIVE_TOKEN;
470         if (memcmp(current_hash->value, hash.value, hash.length))
471                 return GSS_S_BAD_SIG;
472
473         return GSS_S_COMPLETE;
474 }
475
476 static inline int sk_config_has_mgsnid(struct sk_keyfile_config *config,
477                                        const char *mgsnid)
478 {
479         lnet_nid_t nid;
480         int i;
481
482         nid = libcfs_str2nid(mgsnid);
483         if (nid == LNET_NID_ANY)
484                 return 0;
485
486         for (i = 0; i < MAX_MGSNIDS; i++)
487                 if  (config->skc_mgsnids[i] == nid)
488                         return 1;
489         return 0;
490 }
491
492 /**
493  * Create an sk_cred structure populated with initial configuration info and the
494  * key.  \a tgt and \a nodemap are used in determining the expected key
495  * description so the key can be found by searching the keyring.
496  * This is done because there is no easy way to pass keys from the mount command
497  * all the way to the request_key call.  In addition any keys can be dynamically
498  * added to the keyrings and still found.  The keyring that needs to be used
499  * must be the session keyring.
500  *
501  * \param[in]   tgt             Target file system
502  * \param[in]   nodemap         Cluster name for the key.  This correlates to
503  *                              the nodemap name and is used by the server side.
504  *                              For the client this will be NULL.
505  * \param[in]   flags           Flags for the credentials
506  *
507  * \return      sk_cred Allocated struct sk_cred on success
508  * \return      NULL    failure
509  */
510 struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
511                                const uint32_t flags)
512 {
513         struct sk_keyfile_config *config;
514         struct sk_kernel_ctx *kctx;
515         struct sk_cred *skc = NULL;
516         char description[SK_DESCRIPTION_SIZE + 1];
517         char fsname[MTI_NAME_MAXLEN + 1];
518         const char *mgsnid = NULL;
519         char *ptr;
520         long sk_key;
521         int keylen;
522         int len;
523         int rc;
524
525         printerr(2, "Creating credentials for target: %s with nodemap: %s\n",
526                  tgt, nodemap);
527
528         memset(description, 0, sizeof(description));
529         memset(fsname, 0, sizeof(fsname));
530
531         /* extract the file system name from target */
532         ptr = index(tgt, '-');
533         if (!ptr) {
534                 len = strlen(tgt);
535
536                 /* This must be an MGC target */
537                 if (strncmp(tgt, "MGC", 3) || len <= 3) {
538                         printerr(0, "Invalid target name\n");
539                         return NULL;
540                 }
541                 mgsnid = tgt + 3;
542         } else {
543                 len = ptr - tgt;
544         }
545
546         if (len > MTI_NAME_MAXLEN) {
547                 printerr(0, "Invalid target name\n");
548                 return NULL;
549         }
550         memcpy(fsname, tgt, len);
551
552         if (nodemap) {
553                 if (mgsnid)
554                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
555                                       "lustre:MGS:%s", nodemap);
556                 else
557                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
558                                       "lustre:%s:%s", fsname, nodemap);
559         } else {
560                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s",
561                               fsname);
562         }
563
564         if (rc >= SK_DESCRIPTION_SIZE) {
565                 printerr(0, "Invalid key description\n");
566                 return NULL;
567         }
568
569         /* It may be a good idea to move Lustre keys to the gss_keyring
570          * (lgssc) type so that they expire when Lustre modules are removed.
571          * Unfortunately it can't be done at mount time because the mount
572          * syscall could trigger the Lustre modules to load and until that
573          * point we don't have a lgssc key type.
574          *
575          * TODO: Query the community for a consensus here  */
576         printerr(2, "Searching for key with description: %s\n", description);
577         sk_key = keyctl_search(KEY_SPEC_USER_KEYRING, "user",
578                                description, 0);
579         if (sk_key == -1) {
580                 printerr(1, "No key found for %s\n", description);
581                 return NULL;
582         }
583
584         keylen = keyctl_read_alloc(sk_key, (void **)&config);
585         if (keylen == -1) {
586                 printerr(0, "keyctl_read() failed for key %ld: %s\n", sk_key,
587                          strerror(errno));
588                 return NULL;
589         } else if (keylen != sizeof(*config)) {
590                 printerr(0, "Unexpected key size: %d returned for key %ld, "
591                          "expected %zu bytes\n",
592                          keylen, sk_key, sizeof(*config));
593                 goto out_err;
594         }
595
596         sk_config_disk_to_cpu(config);
597
598         if (sk_validate_config(config)) {
599                 printerr(0, "Invalid key configuration for key: %ld\n", sk_key);
600                 goto out_err;
601         }
602
603         if (mgsnid && !sk_config_has_mgsnid(config, mgsnid)) {
604                 printerr(0, "Target name does not match key's MGS NIDs\n");
605                 goto out_err;
606         }
607
608         if (!mgsnid && strcmp(fsname, config->skc_fsname)) {
609                 printerr(0, "Target name does not match key's file system\n");
610                 goto out_err;
611         }
612
613         skc = malloc(sizeof(*skc));
614         if (!skc) {
615                 printerr(0, "Failed to allocate memory for sk_cred\n");
616                 goto out_err;
617         }
618
619         /* this initializes all gss_buffer_desc to empty as well */
620         memset(skc, 0, sizeof(*skc));
621
622         skc->sc_flags = flags;
623         skc->sc_tgt.length = strlen(tgt) + 1;
624         skc->sc_tgt.value = malloc(skc->sc_tgt.length);
625         if (!skc->sc_tgt.value) {
626                 printerr(0, "Failed to allocate memory for target\n");
627                 goto out_err;
628         }
629         memcpy(skc->sc_tgt.value, tgt, skc->sc_tgt.length);
630
631         skc->sc_nodemap_hash.length = EVP_MD_size(EVP_sha256());
632         skc->sc_nodemap_hash.value = malloc(skc->sc_nodemap_hash.length);
633         if (!skc->sc_nodemap_hash.value) {
634                 printerr(0, "Failed to allocate memory for nodemap hash\n");
635                 goto out_err;
636         }
637
638         if (sk_hash_string(config->skc_nodemap, EVP_sha256(),
639                            &skc->sc_nodemap_hash)) {
640                 printerr(0, "Failed to generate hash for nodemap name\n");
641                 goto out_err;
642         }
643
644         kctx = &skc->sc_kctx;
645         kctx->skc_version = config->skc_version;
646         strcpy(kctx->skc_hmac_alg, sk_hmac2name(config->skc_hmac_alg));
647         strcpy(kctx->skc_crypt_alg, sk_crypt2name(config->skc_crypt_alg));
648         kctx->skc_expire = config->skc_expire;
649
650         /* key payload format is in bits, convert to bytes */
651         kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
652         kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
653         if (!kctx->skc_shared_key.value) {
654                 printerr(0, "Failed to allocate memory for shared key\n");
655                 goto out_err;
656         }
657         memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
658                kctx->skc_shared_key.length);
659
660         skc->sc_p.length = config->skc_prime_bits / 8;
661         skc->sc_p.value = malloc(skc->sc_p.length);
662         if (!skc->sc_p.value) {
663                 printerr(0, "Failed to allocate p\n");
664                 goto out_err;
665         }
666         memcpy(skc->sc_p.value, config->skc_p, skc->sc_p.length);
667
668         free(config);
669
670         return skc;
671
672 out_err:
673         sk_free_cred(skc);
674
675         free(config);
676         return NULL;
677 }
678
679 /**
680  * Populates the DH parameters for the DHKE
681  *
682  * \param[in,out]       skc             Shared key credentials structure to
683  *                                      populate with DH parameters
684  *
685  * \retval      GSS_S_COMPLETE  success
686  * \retval      GSS_S_FAILURE   failure
687  */
688 uint32_t sk_gen_params(struct sk_cred *skc)
689 {
690         uint32_t random;
691         BIGNUM *p, *g;
692         const BIGNUM *pub_key;
693         int rc;
694
695         /* Random value used by both the request and response as part of the
696          * key binding material.  This also should ensure we have unqiue
697          * tokens that are sent to the remote server which is important because
698          * the token is hashed for the sunrpc cache lookups and a failure there
699          * would cause connection attempts to fail indefinitely due to the large
700          * timeout value on the server side */
701         if (RAND_bytes((unsigned char *)&random, sizeof(random)) != 1) {
702                 printerr(0, "Failed to get data for random parameter: %s\n",
703                          ERR_error_string(ERR_get_error(), NULL));
704                 return GSS_S_FAILURE;
705         }
706
707         /* The random value will always be used in byte range operations
708          * so we keep it as big endian from this point on */
709         skc->sc_kctx.skc_host_random = random;
710
711         /* Populate DH parameters */
712         skc->sc_params = DH_new();
713         if (!skc->sc_params) {
714                 printerr(0, "Failed to allocate DH\n");
715                 return GSS_S_FAILURE;
716         }
717
718         p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
719         if (!p) {
720                 printerr(0, "Failed to convert binary to BIGNUM\n");
721                 return GSS_S_FAILURE;
722         }
723
724         /* We use a static generator for shared key */
725         g = BN_new();
726         if (!g) {
727                 printerr(0, "Failed to allocate new BIGNUM\n");
728                 return GSS_S_FAILURE;
729         }
730         if (BN_set_word(g, SK_GENERATOR) != 1) {
731                 printerr(0, "Failed to set g value for DH params\n");
732                 return GSS_S_FAILURE;
733         }
734
735         if (!DH_set0_pqg(skc->sc_params, p, NULL, g)) {
736                 printerr(0, "Failed to set pqg\n");
737                 return GSS_S_FAILURE;
738         }
739
740         /* Verify that we have a safe prime and valid generator */
741         if (DH_check(skc->sc_params, &rc) != 1) {
742                 printerr(0, "DH_check() failed: %d\n", rc);
743                 return GSS_S_FAILURE;
744         } else if (rc) {
745                 printerr(0, "DH_check() returned error codes: 0x%x\n", rc);
746                 return GSS_S_FAILURE;
747         }
748
749         if (DH_generate_key(skc->sc_params) != 1) {
750                 printerr(0, "Failed to generate public DH key: %s\n",
751                          ERR_error_string(ERR_get_error(), NULL));
752                 return GSS_S_FAILURE;
753         }
754
755         DH_get0_key(skc->sc_params, &pub_key, NULL);
756         skc->sc_pub_key.length = BN_num_bytes(pub_key);
757         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
758         if (!skc->sc_pub_key.value) {
759                 printerr(0, "Failed to allocate memory for public key\n");
760                 return GSS_S_FAILURE;
761         }
762
763         BN_bn2bin(pub_key, skc->sc_pub_key.value);
764
765         return GSS_S_COMPLETE;
766 }
767
768 /**
769  * Convert SK hash algorithm into openssl message digest
770  *
771  * \param[in,out]       alg             SK hash algorithm
772  *
773  * \retval              EVP_MD
774  */
775 static inline const EVP_MD *sk_hash_to_evp_md(enum cfs_crypto_hash_alg alg)
776 {
777         switch (alg) {
778         case CFS_HASH_ALG_SHA256:
779                 return EVP_sha256();
780         case CFS_HASH_ALG_SHA512:
781                 return EVP_sha512();
782         default:
783                 return EVP_md_null();
784         }
785 }
786
787 /**
788  * Signs (via HMAC) the parameters used only in the key initialization protocol.
789  *
790  * \param[in]           key             Key to use for HMAC
791  * \param[in]           bufs            Array of gss_buffer_desc to generate
792  *                                      HMAC for
793  * \param[in]           numbufs         Number of buffers in array
794  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
795  * \param[in,out]       hmac            HMAC of buffers is allocated and placed
796  *                                      in this gss_buffer_desc.  Caller must
797  *                                      free this.
798  *
799  * \retval      0       success
800  * \retval      -1      failure
801  */
802 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
803                  const EVP_MD *hash_alg, gss_buffer_desc *hmac)
804 {
805         HMAC_CTX *hctx;
806         unsigned int hashlen = EVP_MD_size(hash_alg);
807         int i;
808         int rc = -1;
809
810         if (hash_alg == EVP_md_null()) {
811                 printerr(0, "Invalid hash algorithm\n");
812                 return -1;
813         }
814
815         hctx = HMAC_CTX_new();
816
817         hmac->length = hashlen;
818         hmac->value = malloc(hashlen);
819         if (!hmac->value) {
820                 printerr(0, "Failed to allocate memory for HMAC\n");
821                 goto out;
822         }
823
824         if (HMAC_Init_ex(hctx, key->value, key->length, hash_alg, NULL) != 1) {
825                 printerr(0, "Failed to init HMAC\n");
826                 goto out;
827         }
828
829         for (i = 0; i < numbufs; i++) {
830                 if (HMAC_Update(hctx, bufs[i].value, bufs[i].length) != 1) {
831                         printerr(0, "Failed to update HMAC\n");
832                         goto out;
833                 }
834         }
835
836         /* The result gets populated in hmac */
837         if (HMAC_Final(hctx, hmac->value, &hashlen) != 1) {
838                 printerr(0, "Failed to finalize HMAC\n");
839                 goto out;
840         }
841
842         if (hmac->length != hashlen) {
843                 printerr(0, "HMAC size does not match expected\n");
844                 goto out;
845         }
846
847         rc = 0;
848 out:
849         HMAC_CTX_free(hctx);
850         return rc;
851 }
852
853 /**
854  * Generates an HMAC for gss_buffer_desc array in \a bufs of \a numbufs
855  * and verifies against \a hmac.
856  *
857  * \param[in]   skc             Shared key credentials
858  * \param[in]   bufs            Array of gss_buffer_desc to generate HMAC for
859  * \param[in]   numbufs         Number of buffers in array
860  * \param[in]   hash_alg        OpenSSL EVP_MD to use for hash
861  * \param[in]   hmac            HMAC to verify against
862  *
863  * \retval      GSS_S_COMPLETE  success (match)
864  * \retval      gss error       failure
865  */
866 uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
867                         const int numbufs, const EVP_MD *hash_alg,
868                         gss_buffer_desc *hmac)
869 {
870         gss_buffer_desc bufs_hmac;
871         int rc;
872
873         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, numbufs, hash_alg,
874                          &bufs_hmac)) {
875                 printerr(0, "Failed to sign buffers to verify HMAC\n");
876                 if (bufs_hmac.value)
877                         free(bufs_hmac.value);
878                 return GSS_S_FAILURE;
879         }
880
881         if (hmac->length != bufs_hmac.length) {
882                 printerr(0, "Invalid HMAC size\n");
883                 free(bufs_hmac.value);
884                 return GSS_S_BAD_SIG;
885         }
886
887         rc = memcmp(hmac->value, bufs_hmac.value, bufs_hmac.length);
888         free(bufs_hmac.value);
889
890         if (rc)
891                 return GSS_S_BAD_SIG;
892
893         return GSS_S_COMPLETE;
894 }
895
896 /**
897  * Cleanup an sk_cred freeing any resources
898  *
899  * \param[in,out]       skc     Shared key credentials to free
900  */
901 void sk_free_cred(struct sk_cred *skc)
902 {
903         if (!skc)
904                 return;
905
906         if (skc->sc_p.value)
907                 free(skc->sc_p.value);
908         if (skc->sc_pub_key.value)
909                 free(skc->sc_pub_key.value);
910         if (skc->sc_tgt.value)
911                 free(skc->sc_tgt.value);
912         if (skc->sc_nodemap_hash.value)
913                 free(skc->sc_nodemap_hash.value);
914         if (skc->sc_hmac.value)
915                 free(skc->sc_hmac.value);
916
917         /* Overwrite keys and IV before freeing */
918         if (skc->sc_dh_shared_key.value) {
919                 memset(skc->sc_dh_shared_key.value, 0,
920                        skc->sc_dh_shared_key.length);
921                 free(skc->sc_dh_shared_key.value);
922         }
923         if (skc->sc_kctx.skc_hmac_key.value) {
924                 memset(skc->sc_kctx.skc_hmac_key.value, 0,
925                        skc->sc_kctx.skc_hmac_key.length);
926                 free(skc->sc_kctx.skc_hmac_key.value);
927         }
928         if (skc->sc_kctx.skc_encrypt_key.value) {
929                 memset(skc->sc_kctx.skc_encrypt_key.value, 0,
930                        skc->sc_kctx.skc_encrypt_key.length);
931                 free(skc->sc_kctx.skc_encrypt_key.value);
932         }
933         if (skc->sc_kctx.skc_shared_key.value) {
934                 memset(skc->sc_kctx.skc_shared_key.value, 0,
935                        skc->sc_kctx.skc_shared_key.length);
936                 free(skc->sc_kctx.skc_shared_key.value);
937         }
938         if (skc->sc_kctx.skc_session_key.value) {
939                 memset(skc->sc_kctx.skc_session_key.value, 0,
940                        skc->sc_kctx.skc_session_key.length);
941                 free(skc->sc_kctx.skc_session_key.value);
942         }
943
944         if (skc->sc_params)
945                 DH_free(skc->sc_params);
946
947         free(skc);
948         skc = NULL;
949 }
950
951 /* This function handles key derivation using the hash algorithm specified in
952  * \a hash_alg, buffers in \a key_binding_bufs, and original key in
953  * \a origin_key to produce a \a derived_key.  The first element of the
954  * key_binding_bufs array is reserved for the counter used in the KDF.  The
955  * derived key in \a derived_key could differ in size from \a origin_key and
956  * must be populated with the expected size and a valid buffer to hold the
957  * contents.
958  *
959  * If the derived key size is greater than the HMAC algorithm size it will be
960  * a done using several iterations of a counter and the key binding bufs.
961  *
962  * If the size is smaller it will take copy the first N bytes necessary to
963  * fill the derived key. */
964 int sk_kdf(gss_buffer_desc *derived_key , gss_buffer_desc *origin_key,
965            gss_buffer_desc *key_binding_bufs, int numbufs,
966            enum cfs_crypto_hash_alg hmac_alg)
967 {
968         size_t remain;
969         size_t bytes;
970         uint32_t counter;
971         char *keydata;
972         gss_buffer_desc tmp_hash;
973         int i;
974         int rc;
975
976         if (numbufs < 1)
977                 return -EINVAL;
978
979         /* Use a counter as the first buffer followed by the key binding
980          * buffers in the event we need more than one a single cycle to
981          * produced a symmetric key large enough in size */
982         key_binding_bufs[0].value = &counter;
983         key_binding_bufs[0].length = sizeof(counter);
984
985         remain = derived_key->length;
986         keydata = derived_key->value;
987         i = 0;
988         while (remain > 0) {
989                 counter = htobe32(i++);
990                 rc = sk_sign_bufs(origin_key, key_binding_bufs, numbufs,
991                                   sk_hash_to_evp_md(hmac_alg), &tmp_hash);
992                 if (rc) {
993                         if (tmp_hash.value)
994                                 free(tmp_hash.value);
995                         return rc;
996                 }
997
998                 if (cfs_crypto_hash_digestsize(hmac_alg) != tmp_hash.length) {
999                         free(tmp_hash.value);
1000                         return -EINVAL;
1001                 }
1002
1003                 bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
1004                 memcpy(keydata, tmp_hash.value, bytes);
1005                 free(tmp_hash.value);
1006                 remain -= bytes;
1007                 keydata += bytes;
1008         }
1009
1010         return 0;
1011 }
1012
1013 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
1014  * based on the recommendations in NIST Special Publication SP 800-56B Rev 1
1015  * (Sep 2014) Section 5.5.1
1016  *
1017  * \param[in,out]       skc             Shared key credentials structure with
1018  *
1019  * \return      -1              failure
1020  * \return      0               success
1021  */
1022 int sk_session_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
1023                    gss_buffer_desc *client_token, gss_buffer_desc *server_token)
1024 {
1025         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1026         gss_buffer_desc *session_key = &kctx->skc_session_key;
1027         gss_buffer_desc bufs[5];
1028         enum cfs_crypto_crypt_alg crypt_alg;
1029         int rc = -1;
1030
1031         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1032         session_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1033         session_key->value = malloc(session_key->length);
1034         if (!session_key->value) {
1035                 printerr(0, "Failed to allocate memory for session key\n");
1036                 return rc;
1037         }
1038
1039         /* Key binding info ordering
1040          * 1. Reserved for counter
1041          * 1. DH shared key
1042          * 2. Client's NIDs
1043          * 3. Client's token
1044          * 4. Server's token */
1045         bufs[0].value = NULL;
1046         bufs[0].length = 0;
1047         bufs[1] = skc->sc_dh_shared_key;
1048         bufs[2].value = &client_nid;
1049         bufs[2].length = sizeof(client_nid);
1050         bufs[3] = *client_token;
1051         bufs[4] = *server_token;
1052
1053         return sk_kdf(&kctx->skc_session_key, &kctx->skc_shared_key, bufs,
1054                       5, cfs_crypto_hash_alg(kctx->skc_hmac_alg));
1055 }
1056
1057 /* Uses the session key to create an HMAC key and encryption key.  In
1058  * integrity mode the session key used to generate the HMAC key uses
1059  * session information which is available on the wire but by creating
1060  * a session based HMAC key we can prevent potential replay as both the
1061  * client and server have random numbers used as part of the key creation.
1062  *
1063  * The keys used for integrity and privacy are formulated as below using
1064  * the session key that is the output of the key derivation function.  The
1065  * HMAC algorithm is determined by the shared key algorithm selected in the
1066  * key file.
1067  *
1068  * For ski mode:
1069  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1070  *
1071  * For skpi mode:
1072  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1073  * Session Encryption Key = PBKDF2("Encrypt", KDF derived Session Key)
1074  *
1075  * \param[in,out]       skc             Shared key credentials structure with
1076  *
1077  * \return      -1              failure
1078  * \return      0               success
1079  */
1080 int sk_compute_keys(struct sk_cred *skc)
1081 {
1082         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1083         gss_buffer_desc *session_key = &kctx->skc_session_key;
1084         gss_buffer_desc *hmac_key = &kctx->skc_hmac_key;
1085         gss_buffer_desc *encrypt_key = &kctx->skc_encrypt_key;
1086         enum cfs_crypto_hash_alg hmac_alg;
1087         enum cfs_crypto_crypt_alg crypt_alg;
1088         char *encrypt = "Encrypt";
1089         char *integrity = "Integrity";
1090         int rc;
1091
1092         hmac_alg = cfs_crypto_hash_alg(kctx->skc_hmac_alg);
1093         hmac_key->length = cfs_crypto_hash_digestsize(hmac_alg);
1094         hmac_key->value = malloc(hmac_key->length);
1095         if (!hmac_key->value)
1096                 return -ENOMEM;
1097
1098         rc = PKCS5_PBKDF2_HMAC(integrity, -1, session_key->value,
1099                                session_key->length, SK_PBKDF2_ITERATIONS,
1100                                sk_hash_to_evp_md(hmac_alg),
1101                                hmac_key->length, hmac_key->value);
1102         if (rc == 0)
1103                 return -EINVAL;
1104
1105         /* Encryption key is only populated in privacy mode */
1106         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1107                 return 0;
1108
1109         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1110         encrypt_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1111         encrypt_key->value = malloc(encrypt_key->length);
1112         if (!encrypt_key->value)
1113                 return -ENOMEM;
1114
1115         rc = PKCS5_PBKDF2_HMAC(encrypt, -1, session_key->value,
1116                                session_key->length, SK_PBKDF2_ITERATIONS,
1117                                sk_hash_to_evp_md(hmac_alg),
1118                                encrypt_key->length, encrypt_key->value);
1119         if (rc == 0)
1120                 return -EINVAL;
1121
1122         return 0;
1123 }
1124
1125 /**
1126  * Computes a session key based on the DH parameters from the host and its peer
1127  *
1128  * \param[in,out]       skc             Shared key credentials structure with
1129  *                                      the session key populated with the
1130  *                                      compute key
1131  * \param[in]           pub_key         Public key returned from peer in
1132  *                                      gss_buffer_desc
1133  * \return      gss error               failure
1134  * \return      GSS_S_COMPLETE          success
1135  */
1136 uint32_t sk_compute_dh_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
1137 {
1138         gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
1139         BIGNUM *remote_pub_key;
1140         int status;
1141         uint32_t rc = GSS_S_FAILURE;
1142
1143         remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
1144         if (!remote_pub_key) {
1145                 printerr(0, "Failed to convert binary to BIGNUM\n");
1146                 return rc;
1147         }
1148
1149         dh_shared->length = DH_size(skc->sc_params);
1150         dh_shared->value = malloc(dh_shared->length);
1151         if (!dh_shared->value) {
1152                 printerr(0, "Failed to allocate memory for computed shared "
1153                          "secret key\n");
1154                 goto out_err;
1155         }
1156
1157         /* This compute the shared key from the DHKE */
1158         status = DH_compute_key(dh_shared->value, remote_pub_key,
1159                                 skc->sc_params);
1160         if (status == -1) {
1161                 printerr(0, "DH_compute_key() failed: %s\n",
1162                          ERR_error_string(ERR_get_error(), NULL));
1163                 goto out_err;
1164         } else if (status < dh_shared->length) {
1165                 printerr(0, "DH_compute_key() returned a short key of %d "
1166                          "bytes, expected: %zu\n", status, dh_shared->length);
1167                 rc = GSS_S_DEFECTIVE_TOKEN;
1168                 goto out_err;
1169         }
1170
1171         rc = GSS_S_COMPLETE;
1172
1173 out_err:
1174         BN_free(remote_pub_key);
1175         return rc;
1176 }
1177
1178 /**
1179  * Creates a serialized buffer for the kernel in the order of struct
1180  * sk_kernel_ctx.
1181  *
1182  * \param[in,out]       skc             Shared key credentials structure
1183  * \param[in,out]       ctx_token       Serialized buffer for kernel.
1184  *                                      Caller must free this buffer.
1185  *
1186  * \return      0       success
1187  * \return      -1      failure
1188  */
1189 int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
1190 {
1191         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1192         char *p, *end;
1193         size_t bufsize;
1194
1195         bufsize = sizeof(*kctx) + kctx->skc_hmac_key.length +
1196                   kctx->skc_encrypt_key.length;
1197
1198         ctx_token->value = malloc(bufsize);
1199         if (!ctx_token->value)
1200                 return -1;
1201         ctx_token->length = bufsize;
1202
1203         p = ctx_token->value;
1204         end = p + ctx_token->length;
1205
1206         if (WRITE_BYTES(&p, end, kctx->skc_version))
1207                 return -1;
1208         if (WRITE_BYTES(&p, end, kctx->skc_hmac_alg))
1209                 return -1;
1210         if (WRITE_BYTES(&p, end, kctx->skc_crypt_alg))
1211                 return -1;
1212         if (WRITE_BYTES(&p, end, kctx->skc_expire))
1213                 return -1;
1214         if (WRITE_BYTES(&p, end, kctx->skc_host_random))
1215                 return -1;
1216         if (WRITE_BYTES(&p, end, kctx->skc_peer_random))
1217                 return -1;
1218         if (write_buffer(&p, end, &kctx->skc_hmac_key))
1219                 return -1;
1220         if (write_buffer(&p, end, &kctx->skc_encrypt_key))
1221                 return -1;
1222
1223         printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
1224
1225         return 0;
1226 }
1227
1228 /**
1229  * Decodes a netstring \a ns into array of gss_buffer_descs at \a bufs
1230  * up to \a numbufs.  Memory is allocated for each value and length
1231  * will be populated with the length
1232  *
1233  * \param[in,out]       bufs    Array of gss_buffer_descs
1234  * \param[in,out]       numbufs number of gss_buffer_desc in array
1235  * \param[in]           ns      netstring to decode
1236  *
1237  * \return      buffers populated       success
1238  * \return      -1                      failure
1239  */
1240 int sk_decode_netstring(gss_buffer_desc *bufs, int numbufs, gss_buffer_desc *ns)
1241 {
1242         char *ptr = ns->value;
1243         size_t remain = ns->length;
1244         unsigned int size;
1245         int digits;
1246         int sep;
1247         int rc;
1248         int i;
1249
1250         for (i = 0; i < numbufs; i++) {
1251                 /* read the size of first buffer */
1252                 rc = sscanf(ptr, "%9u", &size);
1253                 if (rc < 1)
1254                         goto out_err;
1255                 digits = (size) ? ceil(log10(size + 1)) : 1;
1256
1257                 /* sep of current string */
1258                 sep = size + digits + 2;
1259
1260                 /* check to make sure it's valid */
1261                 if (remain < sep || ptr[digits] != ':' ||
1262                     ptr[sep - 1] != ',')
1263                         goto out_err;
1264
1265                 bufs[i].length = size;
1266                 if (size == 0) {
1267                         bufs[i].value = NULL;
1268                 } else {
1269                         bufs[i].value = malloc(size);
1270                         if (!bufs[i].value)
1271                                 goto out_err;
1272                         memcpy(bufs[i].value, &ptr[digits + 1], size);
1273                 }
1274
1275                 remain -= sep;
1276                 ptr += sep;
1277         }
1278
1279         printerr(2, "Decoded netstring of %zu bytes\n", ns->length);
1280         return i;
1281
1282 out_err:
1283         while (i-- > 0) {
1284                 if (bufs[i].value)
1285                         free(bufs[i].value);
1286                 bufs[i].length = 0;
1287         }
1288         return -1;
1289 }
1290
1291 /**
1292  * Creates a netstring in a gss_buffer_desc that consists of all
1293  * the gss_buffer_desc found in \a bufs.  The netstring should be treated
1294  * as binary as it can contain null characters.
1295  *
1296  * \param[in]           bufs            Array of gss_buffer_desc to use as input
1297  * \param[in]           numbufs         Number of buffers in array
1298  * \param[in,out]       ns              Destination gss_buffer_desc to hold
1299  *                                      netstring
1300  *
1301  * \return      -1      failure
1302  * \return      0       success
1303  */
1304 int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
1305                         gss_buffer_desc *ns)
1306 {
1307         unsigned char *ptr;
1308         int size = 0;
1309         int rc;
1310         int i;
1311
1312         /* size of string in decimal, string size, colon, and comma */
1313         for (i = 0; i < numbufs; i++) {
1314
1315                 if (bufs[i].length == 0)
1316                         size += 3;
1317                 else
1318                         size += ceil(log10(bufs[i].length + 1)) +
1319                                 bufs[i].length + 2;
1320         }
1321
1322         ns->length = size;
1323         ns->value = malloc(ns->length);
1324         if (!ns->value) {
1325                 ns->length = 0;
1326                 return -1;
1327         }
1328
1329         ptr = ns->value;
1330         for (i = 0; i < numbufs; i++) {
1331                 /* size */
1332                 rc = scnprintf((char *) ptr, size, "%zu:", bufs[i].length);
1333                 ptr += rc;
1334
1335                 /* contents */
1336                 memcpy(ptr, bufs[i].value, bufs[i].length);
1337                 ptr += bufs[i].length;
1338
1339                 /* delimeter */
1340                 *ptr++ = ',';
1341
1342                 size -= bufs[i].length + rc + 1;
1343
1344                 /* should not happen */
1345                 if (size < 0)
1346                         abort();
1347         }
1348
1349         printerr(2, "Encoded netstring of %zu bytes\n", ns->length);
1350         return 0;
1351 }