Whamcloud - gitweb
4b4a911d8f62a8b7ddaa6c11b695a36f78242b09
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Copyright (c) 2016, 2017, Intel Corporation.
26  *
27  * Author: Jeremy Filizetti <jfilizet@iu.edu>
28  */
29
30 #include <fcntl.h>
31 #include <limits.h>
32 #include <math.h>
33 #include <string.h>
34 #include <stdbool.h>
35 #include <unistd.h>
36 #include <openssl/dh.h>
37 #include <openssl/engine.h>
38 #include <openssl/err.h>
39 #include <openssl/hmac.h>
40 #include <sys/types.h>
41 #include <sys/stat.h>
42
43 #include "sk_utils.h"
44 #include "write_bytes.h"
45
46 #define SK_PBKDF2_ITERATIONS 10000
47
48 #ifdef _NEW_BUILD_
49 # include "lgss_utils.h"
50 #else
51 # include "gss_util.h"
52 # include "gss_oids.h"
53 # include "err_util.h"
54 #endif
55
56 #ifdef _ERR_UTIL_H_
57 /**
58  * Initializes logging
59  * \param[in]   program         Program name to output
60  * \param[in]   verbose         Verbose flag
61  * \param[in]   fg              Whether or not to run in foreground
62  *
63  */
64 void sk_init_logging(char *program, int verbose, int fg)
65 {
66         initerr(program, verbose, fg);
67 }
68 #endif
69
70 /**
71  * Loads the key from \a filename and returns the struct sk_keyfile_config.
72  * It should be freed by the caller.
73  *
74  * \param[in]   filename                Disk or key payload data
75  *
76  * \return      sk_keyfile_config       sucess
77  * \return      NULL                    failure
78  */
79 struct sk_keyfile_config *sk_read_file(char *filename)
80 {
81         struct sk_keyfile_config *config;
82         char *ptr;
83         size_t rc;
84         size_t remain;
85         int fd;
86
87         config = malloc(sizeof(*config));
88         if (!config) {
89                 printerr(0, "Failed to allocate memory for config\n");
90                 return NULL;
91         }
92
93         /* allow standard input override */
94         if (strcmp(filename, "-") == 0)
95                 fd = STDIN_FILENO;
96         else
97                 fd = open(filename, O_RDONLY);
98
99         if (fd == -1) {
100                 printerr(0, "Error opening key file '%s': %s\n", filename,
101                          strerror(errno));
102                 goto out_free;
103         } else if (fd != STDIN_FILENO) {
104                 struct stat st;
105
106                 rc = fstat(fd, &st);
107                 if (rc == 0 && (st.st_mode & ~(S_IFREG | 0600)))
108                         fprintf(stderr, "warning: "
109                                 "secret key '%s' has insecure file mode %#o\n",
110                                 filename, st.st_mode);
111         }
112
113         ptr = (char *)config;
114         remain = sizeof(*config);
115         while (remain > 0) {
116                 rc = read(fd, ptr, remain);
117                 if (rc == -1) {
118                         if (errno == EINTR)
119                                 continue;
120                         printerr(0, "read() failed on %s: %s\n", filename,
121                                  strerror(errno));
122                         goto out_close;
123                 } else if (rc == 0) {
124                         printerr(0, "File %s does not have a complete key\n",
125                                  filename);
126                         goto out_close;
127                 }
128                 ptr += rc;
129                 remain -= rc;
130         }
131
132         if (fd != STDIN_FILENO)
133                 close(fd);
134         sk_config_disk_to_cpu(config);
135         return config;
136
137 out_close:
138         close(fd);
139 out_free:
140         free(config);
141         return NULL;
142 }
143
144 /**
145  * Checks if a key matching \a description is found in the keyring for
146  * logging purposes and then attempts to load \a payload of \a psize into a key
147  * with \a description.
148  *
149  * \param[in]   payload         Key payload
150  * \param[in]   psize           Payload size
151  * \param[in]   description     Description used for key in keyring
152  *
153  * \return      0       sucess
154  * \return      -1      failure
155  */
156 static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
157                                 const char *description)
158 {
159         struct sk_keyfile_config payload;
160         key_serial_t key;
161
162         memcpy(&payload, skc, sizeof(*skc));
163
164         /* In the keyring use the disk layout so keyctl pipe can be used */
165         sk_config_cpu_to_disk(&payload);
166
167         /* Check to see if a key is already loaded matching description */
168         key = keyctl_search(KEY_SPEC_USER_KEYRING, "user", description, 0);
169         if (key != -1)
170                 printerr(2, "Key %d found in session keyring, replacing\n",
171                          key);
172
173         key = add_key("user", description, &payload, sizeof(payload),
174                       KEY_SPEC_USER_KEYRING);
175         if (key != -1)
176                 printerr(2, "Added key %d with description %s\n", key,
177                          description);
178         else
179                 printerr(0, "Failed to add key with %s\n", description);
180
181         return key;
182 }
183
184 /**
185  * Reads the key from \a path, verifies it and loads into the session keyring
186  * using a description determined by the the \a type.  Existing keys with the
187  * same description are replaced.
188  *
189  * \param[in]   path    Path to key file
190  * \param[in]   type    Type of key to load which determines the description
191  *
192  * \return      0       sucess
193  * \return      -1      failure
194  */
195 int sk_load_keyfile(char *path)
196 {
197         struct sk_keyfile_config *config;
198         char description[SK_DESCRIPTION_SIZE + 1];
199         struct stat buf;
200         int i;
201         int rc;
202         int rc2 = -1;
203
204         rc = stat(path, &buf);
205         if (rc == -1) {
206                 printerr(0, "stat() failed for file %s: %s\n", path,
207                          strerror(errno));
208                 return rc2;
209         }
210
211         config = sk_read_file(path);
212         if (!config)
213                 return rc2;
214
215         /* Similar to ssh, require adequate care of key files */
216         if (buf.st_mode & (S_IRGRP | S_IWGRP | S_IWOTH | S_IXOTH)) {
217                 printerr(0, "Shared key files must be read/writeable only by "
218                          "owner\n");
219                 return -1;
220         }
221
222         if (sk_validate_config(config))
223                 goto out;
224
225         /* The server side can have multiple key files per file system so
226          * the nodemap name is appended to the key description to uniquely
227          * identify it */
228         if (config->skc_type & SK_TYPE_MGS) {
229                 /* Any key can be an MGS key as long as we are told to use it */
230                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
231                               config->skc_nodemap);
232                 if (rc >= SK_DESCRIPTION_SIZE)
233                         goto out;
234                 if (sk_load_key(config, description) == -1)
235                         goto out;
236         }
237         if (config->skc_type & SK_TYPE_SERVER) {
238                 /* Server keys need to have the file system name in the key */
239                 if (!config->skc_fsname) {
240                         printerr(0, "Key configuration has no file system "
241                                  "attribute.  Can't load as server type\n");
242                         goto out;
243                 }
244                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s:%s",
245                               config->skc_fsname, config->skc_nodemap);
246                 if (rc >= SK_DESCRIPTION_SIZE)
247                         goto out;
248                 if (sk_load_key(config, description) == -1)
249                         goto out;
250         }
251         if (config->skc_type & SK_TYPE_CLIENT) {
252                 /* Load client file system key */
253                 if (config->skc_fsname) {
254                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
255                                       "lustre:%s", config->skc_fsname);
256                         if (rc >= SK_DESCRIPTION_SIZE)
257                                 goto out;
258                         if (sk_load_key(config, description) == -1)
259                                 goto out;
260                 }
261
262                 /* Load client MGC keys */
263                 for (i = 0; i < MAX_MGSNIDS; i++) {
264                         if (config->skc_mgsnids[i] == LNET_NID_ANY)
265                                 continue;
266                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
267                                       "lustre:MGC%s",
268                                       libcfs_nid2str(config->skc_mgsnids[i]));
269                         if (rc >= SK_DESCRIPTION_SIZE)
270                                 goto out;
271                         if (sk_load_key(config, description) == -1)
272                                 goto out;
273                 }
274         }
275
276         rc2 = 0;
277
278 out:
279         free(config);
280         return rc2;
281 }
282
283 /**
284  * Byte swaps config from cpu format to disk
285  *
286  * \param[in,out]       config          sk_keyfile_config to swap
287  */
288 void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
289 {
290         int i;
291
292         if (!config)
293                 return;
294
295         config->skc_version = htobe32(config->skc_version);
296         config->skc_hmac_alg = htobe16(config->skc_hmac_alg);
297         config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
298         config->skc_expire = htobe32(config->skc_expire);
299         config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
300         config->skc_prime_bits = htobe32(config->skc_prime_bits);
301
302         for (i = 0; i < MAX_MGSNIDS; i++)
303                 config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
304
305         return;
306 }
307
308 /**
309  * Byte swaps config from disk format to cpu
310  *
311  * \param[in,out]       config          sk_keyfile_config to swap
312  */
313 void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
314 {
315         int i;
316
317         if (!config)
318                 return;
319
320         config->skc_version = be32toh(config->skc_version);
321         config->skc_hmac_alg = be16toh(config->skc_hmac_alg);
322         config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
323         config->skc_expire = be32toh(config->skc_expire);
324         config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
325         config->skc_prime_bits = be32toh(config->skc_prime_bits);
326
327         for (i = 0; i < MAX_MGSNIDS; i++)
328                 config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
329
330         return;
331 }
332
333 /**
334  * Verifies the on key payload format is valid
335  *
336  * \param[in]   config          sk_keyfile_config
337  *
338  * \return      -1      failure
339  * \return      0       success
340  */
341 int sk_validate_config(const struct sk_keyfile_config *config)
342 {
343         int i;
344
345         if (!config) {
346                 printerr(0, "Null configuration passed\n");
347                 return -1;
348         }
349
350         if (config->skc_version != SK_CONF_VERSION) {
351                 printerr(0, "Invalid version\n");
352                 return -1;
353         }
354
355         if (config->skc_hmac_alg == SK_HMAC_INVALID) {
356                 printerr(0, "Invalid HMAC algorithm\n");
357                 return -1;
358         }
359
360         if (config->skc_crypt_alg == SK_CRYPT_INVALID) {
361                 printerr(0, "Invalid crypt algorithm\n");
362                 return -1;
363         }
364
365         if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
366                 /* Try to limit key expiration to some reasonable minimum and
367                  * also prevent values over INT_MAX because there appears
368                  * to be a type conversion issue */
369                 printerr(0, "Invalid expiration time should be between %d "
370                          "and %d\n", 60, INT_MAX);
371                 return -1;
372         }
373         if (config->skc_prime_bits % 8 != 0 ||
374             config->skc_prime_bits > SK_MAX_P_BYTES * 8) {
375                 printerr(0, "Invalid session key length must be a multiple of 8"
376                          " and less then %d bits\n",
377                          SK_MAX_P_BYTES * 8);
378                 return -1;
379         }
380         if (config->skc_shared_keylen % 8 != 0 ||
381             config->skc_shared_keylen > SK_MAX_KEYLEN_BYTES * 8){
382                 printerr(0, "Invalid shared key max length must be a multiple "
383                          "of 8 and less then %d bits\n",
384                          SK_MAX_KEYLEN_BYTES * 8);
385                 return -1;
386         }
387
388         /* Check for terminating nulls on strings */
389         for (i = 0; i < sizeof(config->skc_fsname) &&
390              config->skc_fsname[i] != '\0';  i++)
391                 ; /* empty loop */
392         if (i == sizeof(config->skc_fsname)) {
393                 printerr(0, "File system name not null terminated\n");
394                 return -1;
395         }
396
397         for (i = 0; i < sizeof(config->skc_nodemap) &&
398              config->skc_nodemap[i] != '\0';  i++)
399                 ; /* empty loop */
400         if (i == sizeof(config->skc_nodemap)) {
401                 printerr(0, "Nodemap name not null terminated\n");
402                 return -1;
403         }
404
405         if (config->skc_type == SK_TYPE_INVALID) {
406                 printerr(0, "Invalid key type\n");
407                 return -1;
408         }
409
410         return 0;
411 }
412
413 /**
414  * Hashes \a string and places the hash in \a hash
415  * at \a hash
416  *
417  * \param[in]           string          Null terminated string to hash
418  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
419  * \param[in,out]       hash            gss_buffer_desc to hold the result
420  *
421  * \return      -1      failure
422  * \return      0       success
423  */
424 static int sk_hash_string(const char *string, const EVP_MD *hash_alg,
425                           gss_buffer_desc *hash)
426 {
427         EVP_MD_CTX *ctx = EVP_MD_CTX_create();
428         size_t len = strlen(string);
429         unsigned int hashlen;
430
431         if (!hash->value || hash->length < EVP_MD_size(hash_alg))
432                 goto out_err;
433         if (!EVP_DigestInit_ex(ctx, hash_alg, NULL))
434                 goto out_err;
435         if (!EVP_DigestUpdate(ctx, string, len))
436                 goto out_err;
437         if (!EVP_DigestFinal_ex(ctx, hash->value, &hashlen))
438                 goto out_err;
439
440         EVP_MD_CTX_destroy(ctx);
441         hash->length = hashlen;
442         return 0;
443
444 out_err:
445         EVP_MD_CTX_destroy(ctx);
446         return -1;
447 }
448
449 /**
450  * Hashes \a string and verifies the resulting hash matches the value
451  * in \a current_hash
452  *
453  * \param[in]           string          Null terminated string to hash
454  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
455  * \param[in,out]       current_hash    gss_buffer_desc to compare to
456  *
457  * \return      gss error       failure
458  * \return      GSS_S_COMPLETE  success
459  */
460 uint32_t sk_verify_hash(const char *string, const EVP_MD *hash_alg,
461                         const gss_buffer_desc *current_hash)
462 {
463         gss_buffer_desc hash;
464         unsigned char hashbuf[EVP_MAX_MD_SIZE];
465
466         hash.value = hashbuf;
467         hash.length = sizeof(hashbuf);
468
469         if (sk_hash_string(string, hash_alg, &hash))
470                 return GSS_S_FAILURE;
471         if (current_hash->length != hash.length)
472                 return GSS_S_DEFECTIVE_TOKEN;
473         if (memcmp(current_hash->value, hash.value, hash.length))
474                 return GSS_S_BAD_SIG;
475
476         return GSS_S_COMPLETE;
477 }
478
479 static inline int sk_config_has_mgsnid(struct sk_keyfile_config *config,
480                                        const char *mgsnid)
481 {
482         lnet_nid_t nid;
483         int i;
484
485         nid = libcfs_str2nid(mgsnid);
486         if (nid == LNET_NID_ANY)
487                 return 0;
488
489         for (i = 0; i < MAX_MGSNIDS; i++)
490                 if  (config->skc_mgsnids[i] == nid)
491                         return 1;
492         return 0;
493 }
494
495 /**
496  * Create an sk_cred structure populated with initial configuration info and the
497  * key.  \a tgt and \a nodemap are used in determining the expected key
498  * description so the key can be found by searching the keyring.
499  * This is done because there is no easy way to pass keys from the mount command
500  * all the way to the request_key call.  In addition any keys can be dynamically
501  * added to the keyrings and still found.  The keyring that needs to be used
502  * must be the session keyring.
503  *
504  * \param[in]   tgt             Target file system
505  * \param[in]   nodemap         Cluster name for the key.  This correlates to
506  *                              the nodemap name and is used by the server side.
507  *                              For the client this will be NULL.
508  * \param[in]   flags           Flags for the credentials
509  *
510  * \return      sk_cred Allocated struct sk_cred on success
511  * \return      NULL    failure
512  */
513 struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
514                                const uint32_t flags)
515 {
516         struct sk_keyfile_config *config;
517         struct sk_kernel_ctx *kctx;
518         struct sk_cred *skc = NULL;
519         char description[SK_DESCRIPTION_SIZE + 1];
520         char fsname[MTI_NAME_MAXLEN + 1];
521         const char *mgsnid = NULL;
522         char *ptr;
523         long sk_key;
524         int keylen;
525         int len;
526         int rc;
527
528         printerr(2, "Creating credentials for target: %s with nodemap: %s\n",
529                  tgt, nodemap);
530
531         memset(description, 0, sizeof(description));
532         memset(fsname, 0, sizeof(fsname));
533
534         /* extract the file system name from target */
535         ptr = index(tgt, '-');
536         if (!ptr) {
537                 len = strlen(tgt);
538
539                 /* This must be an MGC target */
540                 if (strncmp(tgt, "MGC", 3) || len <= 3) {
541                         printerr(0, "Invalid target name\n");
542                         return NULL;
543                 }
544                 mgsnid = tgt + 3;
545         } else {
546                 len = ptr - tgt;
547         }
548
549         if (len > MTI_NAME_MAXLEN) {
550                 printerr(0, "Invalid target name\n");
551                 return NULL;
552         }
553         memcpy(fsname, tgt, len);
554
555         if (nodemap) {
556                 if (mgsnid)
557                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
558                                       "lustre:MGS:%s", nodemap);
559                 else
560                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
561                                       "lustre:%s:%s", fsname, nodemap);
562         } else {
563                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s",
564                               fsname);
565         }
566
567         if (rc >= SK_DESCRIPTION_SIZE) {
568                 printerr(0, "Invalid key description\n");
569                 return NULL;
570         }
571
572         /* It may be a good idea to move Lustre keys to the gss_keyring
573          * (lgssc) type so that they expire when Lustre modules are removed.
574          * Unfortunately it can't be done at mount time because the mount
575          * syscall could trigger the Lustre modules to load and until that
576          * point we don't have a lgssc key type.
577          *
578          * TODO: Query the community for a consensus here  */
579         printerr(2, "Searching for key with description: %s\n", description);
580         sk_key = keyctl_search(KEY_SPEC_USER_KEYRING, "user",
581                                description, 0);
582         if (sk_key == -1) {
583                 printerr(1, "No key found for %s\n", description);
584                 return NULL;
585         }
586
587         keylen = keyctl_read_alloc(sk_key, (void **)&config);
588         if (keylen == -1) {
589                 printerr(0, "keyctl_read() failed for key %ld: %s\n", sk_key,
590                          strerror(errno));
591                 return NULL;
592         } else if (keylen != sizeof(*config)) {
593                 printerr(0, "Unexpected key size: %d returned for key %ld, "
594                          "expected %zu bytes\n",
595                          keylen, sk_key, sizeof(*config));
596                 goto out_err;
597         }
598
599         sk_config_disk_to_cpu(config);
600
601         if (sk_validate_config(config)) {
602                 printerr(0, "Invalid key configuration for key: %ld\n", sk_key);
603                 goto out_err;
604         }
605
606         if (mgsnid && !sk_config_has_mgsnid(config, mgsnid)) {
607                 printerr(0, "Target name does not match key's MGS NIDs\n");
608                 goto out_err;
609         }
610
611         if (!mgsnid && strcmp(fsname, config->skc_fsname)) {
612                 printerr(0, "Target name does not match key's file system\n");
613                 goto out_err;
614         }
615
616         skc = malloc(sizeof(*skc));
617         if (!skc) {
618                 printerr(0, "Failed to allocate memory for sk_cred\n");
619                 goto out_err;
620         }
621
622         /* this initializes all gss_buffer_desc to empty as well */
623         memset(skc, 0, sizeof(*skc));
624
625         skc->sc_flags = flags;
626         skc->sc_tgt.length = strlen(tgt) + 1;
627         skc->sc_tgt.value = malloc(skc->sc_tgt.length);
628         if (!skc->sc_tgt.value) {
629                 printerr(0, "Failed to allocate memory for target\n");
630                 goto out_err;
631         }
632         memcpy(skc->sc_tgt.value, tgt, skc->sc_tgt.length);
633
634         skc->sc_nodemap_hash.length = EVP_MD_size(EVP_sha256());
635         skc->sc_nodemap_hash.value = malloc(skc->sc_nodemap_hash.length);
636         if (!skc->sc_nodemap_hash.value) {
637                 printerr(0, "Failed to allocate memory for nodemap hash\n");
638                 goto out_err;
639         }
640
641         if (sk_hash_string(config->skc_nodemap, EVP_sha256(),
642                            &skc->sc_nodemap_hash)) {
643                 printerr(0, "Failed to generate hash for nodemap name\n");
644                 goto out_err;
645         }
646
647         kctx = &skc->sc_kctx;
648         kctx->skc_version = config->skc_version;
649         strcpy(kctx->skc_hmac_alg, sk_hmac2name(config->skc_hmac_alg));
650         strcpy(kctx->skc_crypt_alg, sk_crypt2name(config->skc_crypt_alg));
651         kctx->skc_expire = config->skc_expire;
652
653         /* key payload format is in bits, convert to bytes */
654         kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
655         kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
656         if (!kctx->skc_shared_key.value) {
657                 printerr(0, "Failed to allocate memory for shared key\n");
658                 goto out_err;
659         }
660         memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
661                kctx->skc_shared_key.length);
662
663         skc->sc_p.length = config->skc_prime_bits / 8;
664         skc->sc_p.value = malloc(skc->sc_p.length);
665         if (!skc->sc_p.value) {
666                 printerr(0, "Failed to allocate p\n");
667                 goto out_err;
668         }
669         memcpy(skc->sc_p.value, config->skc_p, skc->sc_p.length);
670
671         free(config);
672
673         return skc;
674
675 out_err:
676         sk_free_cred(skc);
677
678         free(config);
679         return NULL;
680 }
681
682 /**
683  * Populates the DH parameters for the DHKE
684  *
685  * \param[in,out]       skc             Shared key credentials structure to
686  *                                      populate with DH parameters
687  *
688  * \retval      GSS_S_COMPLETE  success
689  * \retval      GSS_S_FAILURE   failure
690  */
691 uint32_t sk_gen_params(struct sk_cred *skc)
692 {
693         uint32_t random;
694         int rc;
695
696         /* Random value used by both the request and response as part of the
697          * key binding material.  This also should ensure we have unqiue
698          * tokens that are sent to the remote server which is important because
699          * the token is hashed for the sunrpc cache lookups and a failure there
700          * would cause connection attempts to fail indefinitely due to the large
701          * timeout value on the server side */
702         if (RAND_bytes((unsigned char *)&random, sizeof(random)) != 1) {
703                 printerr(0, "Failed to get data for random parameter: %s\n",
704                          ERR_error_string(ERR_get_error(), NULL));
705                 return GSS_S_FAILURE;
706         }
707
708         /* The random value will always be used in byte range operations
709          * so we keep it as big endian from this point on */
710         skc->sc_kctx.skc_host_random = random;
711
712         /* Populate DH parameters */
713         skc->sc_params = DH_new();
714         if (!skc->sc_params) {
715                 printerr(0, "Failed to allocate DH\n");
716                 return GSS_S_FAILURE;
717         }
718
719         skc->sc_params->p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
720         if (!skc->sc_params->p) {
721                 printerr(0, "Failed to convert binary to BIGNUM\n");
722                 return GSS_S_FAILURE;
723         }
724
725         /* We use a static generator for shared key */
726         skc->sc_params->g = BN_new();
727         if (!skc->sc_params->g) {
728                 printerr(0, "Failed to allocate new BIGNUM\n");
729                 return GSS_S_FAILURE;
730         }
731         if (BN_set_word(skc->sc_params->g, SK_GENERATOR) != 1) {
732                 printerr(0, "Failed to set g value for DH params\n");
733                 return GSS_S_FAILURE;
734         }
735
736         /* Verify that we have a safe prime and valid generator */
737         if (DH_check(skc->sc_params, &rc) != 1) {
738                 printerr(0, "DH_check() failed: %d\n", rc);
739                 return GSS_S_FAILURE;
740         } else if (rc) {
741                 printerr(0, "DH_check() returned error codes: 0x%x\n", rc);
742                 return GSS_S_FAILURE;
743         }
744
745         if (DH_generate_key(skc->sc_params) != 1) {
746                 printerr(0, "Failed to generate public DH key: %s\n",
747                          ERR_error_string(ERR_get_error(), NULL));
748                 return GSS_S_FAILURE;
749         }
750
751         skc->sc_pub_key.length = BN_num_bytes(skc->sc_params->pub_key);
752         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
753         if (!skc->sc_pub_key.value) {
754                 printerr(0, "Failed to allocate memory for public key\n");
755                 return GSS_S_FAILURE;
756         }
757
758         BN_bn2bin(skc->sc_params->pub_key, skc->sc_pub_key.value);
759
760         return GSS_S_COMPLETE;
761 }
762
763 /**
764  * Convert SK hash algorithm into openssl message digest
765  *
766  * \param[in,out]       alg             SK hash algorithm
767  *
768  * \retval              EVP_MD
769  */
770 static inline const EVP_MD *sk_hash_to_evp_md(enum cfs_crypto_hash_alg alg)
771 {
772         switch (alg) {
773         case CFS_HASH_ALG_SHA256:
774                 return EVP_sha256();
775         case CFS_HASH_ALG_SHA512:
776                 return EVP_sha512();
777         default:
778                 return EVP_md_null();
779         }
780 }
781
782 /**
783  * Signs (via HMAC) the parameters used only in the key initialization protocol.
784  *
785  * \param[in]           key             Key to use for HMAC
786  * \param[in]           bufs            Array of gss_buffer_desc to generate
787  *                                      HMAC for
788  * \param[in]           numbufs         Number of buffers in array
789  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
790  * \param[in,out]       hmac            HMAC of buffers is allocated and placed
791  *                                      in this gss_buffer_desc.  Caller must
792  *                                      free this.
793  *
794  * \retval      0       success
795  * \retval      -1      failure
796  */
797 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
798                  const EVP_MD *hash_alg, gss_buffer_desc *hmac)
799 {
800         HMAC_CTX hctx;
801         unsigned int hashlen = EVP_MD_size(hash_alg);
802         int i;
803         int rc = -1;
804
805         if (hash_alg == EVP_md_null()) {
806                 printerr(0, "Invalid hash algorithm\n");
807                 return -1;
808         }
809
810         HMAC_CTX_init(&hctx);
811
812         hmac->length = hashlen;
813         hmac->value = malloc(hashlen);
814         if (!hmac->value) {
815                 printerr(0, "Failed to allocate memory for HMAC\n");
816                 goto out;
817         }
818
819         if (HMAC_Init_ex(&hctx, key->value, key->length, hash_alg, NULL) != 1) {
820                 printerr(0, "Failed to init HMAC\n");
821                 goto out;
822         }
823
824         for (i = 0; i < numbufs; i++) {
825                 if (HMAC_Update(&hctx, bufs[i].value, bufs[i].length) != 1) {
826                         printerr(0, "Failed to update HMAC\n");
827                         goto out;
828                 }
829         }
830
831         /* The result gets populated in hmac */
832         if (HMAC_Final(&hctx, hmac->value, &hashlen) != 1) {
833                 printerr(0, "Failed to finalize HMAC\n");
834                 goto out;
835         }
836
837         if (hmac->length != hashlen) {
838                 printerr(0, "HMAC size does not match expected\n");
839                 goto out;
840         }
841
842         rc = 0;
843 out:
844         HMAC_CTX_cleanup(&hctx);
845         return rc;
846 }
847
848 /**
849  * Generates an HMAC for gss_buffer_desc array in \a bufs of \a numbufs
850  * and verifies against \a hmac.
851  *
852  * \param[in]   skc             Shared key credentials
853  * \param[in]   bufs            Array of gss_buffer_desc to generate HMAC for
854  * \param[in]   numbufs         Number of buffers in array
855  * \param[in]   hash_alg        OpenSSL EVP_MD to use for hash
856  * \param[in]   hmac            HMAC to verify against
857  *
858  * \retval      GSS_S_COMPLETE  success (match)
859  * \retval      gss error       failure
860  */
861 uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
862                         const int numbufs, const EVP_MD *hash_alg,
863                         gss_buffer_desc *hmac)
864 {
865         gss_buffer_desc bufs_hmac;
866         int rc;
867
868         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, numbufs, hash_alg,
869                          &bufs_hmac)) {
870                 printerr(0, "Failed to sign buffers to verify HMAC\n");
871                 if (bufs_hmac.value)
872                         free(bufs_hmac.value);
873                 return GSS_S_FAILURE;
874         }
875
876         if (hmac->length != bufs_hmac.length) {
877                 printerr(0, "Invalid HMAC size\n");
878                 free(bufs_hmac.value);
879                 return GSS_S_BAD_SIG;
880         }
881
882         rc = memcmp(hmac->value, bufs_hmac.value, bufs_hmac.length);
883         free(bufs_hmac.value);
884
885         if (rc)
886                 return GSS_S_BAD_SIG;
887
888         return GSS_S_COMPLETE;
889 }
890
891 /**
892  * Cleanup an sk_cred freeing any resources
893  *
894  * \param[in,out]       skc     Shared key credentials to free
895  */
896 void sk_free_cred(struct sk_cred *skc)
897 {
898         if (!skc)
899                 return;
900
901         if (skc->sc_p.value)
902                 free(skc->sc_p.value);
903         if (skc->sc_pub_key.value)
904                 free(skc->sc_pub_key.value);
905         if (skc->sc_tgt.value)
906                 free(skc->sc_tgt.value);
907         if (skc->sc_nodemap_hash.value)
908                 free(skc->sc_nodemap_hash.value);
909         if (skc->sc_hmac.value)
910                 free(skc->sc_hmac.value);
911
912         /* Overwrite keys and IV before freeing */
913         if (skc->sc_dh_shared_key.value) {
914                 memset(skc->sc_dh_shared_key.value, 0,
915                        skc->sc_dh_shared_key.length);
916                 free(skc->sc_dh_shared_key.value);
917         }
918         if (skc->sc_kctx.skc_hmac_key.value) {
919                 memset(skc->sc_kctx.skc_hmac_key.value, 0,
920                        skc->sc_kctx.skc_hmac_key.length);
921                 free(skc->sc_kctx.skc_hmac_key.value);
922         }
923         if (skc->sc_kctx.skc_encrypt_key.value) {
924                 memset(skc->sc_kctx.skc_encrypt_key.value, 0,
925                        skc->sc_kctx.skc_encrypt_key.length);
926                 free(skc->sc_kctx.skc_encrypt_key.value);
927         }
928         if (skc->sc_kctx.skc_shared_key.value) {
929                 memset(skc->sc_kctx.skc_shared_key.value, 0,
930                        skc->sc_kctx.skc_shared_key.length);
931                 free(skc->sc_kctx.skc_shared_key.value);
932         }
933         if (skc->sc_kctx.skc_session_key.value) {
934                 memset(skc->sc_kctx.skc_session_key.value, 0,
935                        skc->sc_kctx.skc_session_key.length);
936                 free(skc->sc_kctx.skc_session_key.value);
937         }
938
939         if (skc->sc_params)
940                 DH_free(skc->sc_params);
941
942         free(skc);
943         skc = NULL;
944 }
945
946 /* This function handles key derivation using the hash algorithm specified in
947  * \a hash_alg, buffers in \a key_binding_bufs, and original key in
948  * \a origin_key to produce a \a derived_key.  The first element of the
949  * key_binding_bufs array is reserved for the counter used in the KDF.  The
950  * derived key in \a derived_key could differ in size from \a origin_key and
951  * must be populated with the expected size and a valid buffer to hold the
952  * contents.
953  *
954  * If the derived key size is greater than the HMAC algorithm size it will be
955  * a done using several iterations of a counter and the key binding bufs.
956  *
957  * If the size is smaller it will take copy the first N bytes necessary to
958  * fill the derived key. */
959 int sk_kdf(gss_buffer_desc *derived_key , gss_buffer_desc *origin_key,
960            gss_buffer_desc *key_binding_bufs, int numbufs,
961            enum cfs_crypto_hash_alg hmac_alg)
962 {
963         size_t remain;
964         size_t bytes;
965         uint32_t counter;
966         char *keydata;
967         gss_buffer_desc tmp_hash;
968         int i;
969         int rc;
970
971         if (numbufs < 1)
972                 return -EINVAL;
973
974         /* Use a counter as the first buffer followed by the key binding
975          * buffers in the event we need more than one a single cycle to
976          * produced a symmetric key large enough in size */
977         key_binding_bufs[0].value = &counter;
978         key_binding_bufs[0].length = sizeof(counter);
979
980         remain = derived_key->length;
981         keydata = derived_key->value;
982         i = 0;
983         while (remain > 0) {
984                 counter = htobe32(i++);
985                 rc = sk_sign_bufs(origin_key, key_binding_bufs, numbufs,
986                                   sk_hash_to_evp_md(hmac_alg), &tmp_hash);
987                 if (rc) {
988                         if (tmp_hash.value)
989                                 free(tmp_hash.value);
990                         return rc;
991                 }
992
993                 if (cfs_crypto_hash_digestsize(hmac_alg) != tmp_hash.length) {
994                         free(tmp_hash.value);
995                         return -EINVAL;
996                 }
997
998                 bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
999                 memcpy(keydata, tmp_hash.value, bytes);
1000                 free(tmp_hash.value);
1001                 remain -= bytes;
1002                 keydata += bytes;
1003         }
1004
1005         return 0;
1006 }
1007
1008 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
1009  * based on the recommendations in NIST Special Publication SP 800-56B Rev 1
1010  * (Sep 2014) Section 5.5.1
1011  *
1012  * \param[in,out]       skc             Shared key credentials structure with
1013  *
1014  * \return      -1              failure
1015  * \return      0               success
1016  */
1017 int sk_session_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
1018                    gss_buffer_desc *client_token, gss_buffer_desc *server_token)
1019 {
1020         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1021         gss_buffer_desc *session_key = &kctx->skc_session_key;
1022         gss_buffer_desc bufs[5];
1023         enum cfs_crypto_crypt_alg crypt_alg;
1024         int rc = -1;
1025
1026         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1027         session_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1028         session_key->value = malloc(session_key->length);
1029         if (!session_key->value) {
1030                 printerr(0, "Failed to allocate memory for session key\n");
1031                 return rc;
1032         }
1033
1034         /* Key binding info ordering
1035          * 1. Reserved for counter
1036          * 1. DH shared key
1037          * 2. Client's NIDs
1038          * 3. Client's token
1039          * 4. Server's token */
1040         bufs[0].value = NULL;
1041         bufs[0].length = 0;
1042         bufs[1] = skc->sc_dh_shared_key;
1043         bufs[2].value = &client_nid;
1044         bufs[2].length = sizeof(client_nid);
1045         bufs[3] = *client_token;
1046         bufs[4] = *server_token;
1047
1048         return sk_kdf(&kctx->skc_session_key, &kctx->skc_shared_key, bufs,
1049                       5, cfs_crypto_hash_alg(kctx->skc_hmac_alg));
1050 }
1051
1052 /* Uses the session key to create an HMAC key and encryption key.  In
1053  * integrity mode the session key used to generate the HMAC key uses
1054  * session information which is available on the wire but by creating
1055  * a session based HMAC key we can prevent potential replay as both the
1056  * client and server have random numbers used as part of the key creation.
1057  *
1058  * The keys used for integrity and privacy are formulated as below using
1059  * the session key that is the output of the key derivation function.  The
1060  * HMAC algorithm is determined by the shared key algorithm selected in the
1061  * key file.
1062  *
1063  * For ski mode:
1064  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1065  *
1066  * For skpi mode:
1067  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1068  * Session Encryption Key = PBKDF2("Encrypt", KDF derived Session Key)
1069  *
1070  * \param[in,out]       skc             Shared key credentials structure with
1071  *
1072  * \return      -1              failure
1073  * \return      0               success
1074  */
1075 int sk_compute_keys(struct sk_cred *skc)
1076 {
1077         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1078         gss_buffer_desc *session_key = &kctx->skc_session_key;
1079         gss_buffer_desc *hmac_key = &kctx->skc_hmac_key;
1080         gss_buffer_desc *encrypt_key = &kctx->skc_encrypt_key;
1081         enum cfs_crypto_hash_alg hmac_alg;
1082         enum cfs_crypto_crypt_alg crypt_alg;
1083         char *encrypt = "Encrypt";
1084         char *integrity = "Integrity";
1085         int rc;
1086
1087         hmac_alg = cfs_crypto_hash_alg(kctx->skc_hmac_alg);
1088         hmac_key->length = cfs_crypto_hash_digestsize(hmac_alg);
1089         hmac_key->value = malloc(hmac_key->length);
1090         if (!hmac_key->value)
1091                 return -ENOMEM;
1092
1093         rc = PKCS5_PBKDF2_HMAC(integrity, -1, session_key->value,
1094                                session_key->length, SK_PBKDF2_ITERATIONS,
1095                                sk_hash_to_evp_md(hmac_alg),
1096                                hmac_key->length, hmac_key->value);
1097         if (rc == 0)
1098                 return -EINVAL;
1099
1100         /* Encryption key is only populated in privacy mode */
1101         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1102                 return 0;
1103
1104         crypt_alg = cfs_crypto_crypt_alg(kctx->skc_crypt_alg);
1105         encrypt_key->length = cfs_crypto_crypt_keysize(crypt_alg);
1106         encrypt_key->value = malloc(encrypt_key->length);
1107         if (!encrypt_key->value)
1108                 return -ENOMEM;
1109
1110         rc = PKCS5_PBKDF2_HMAC(encrypt, -1, session_key->value,
1111                                session_key->length, SK_PBKDF2_ITERATIONS,
1112                                sk_hash_to_evp_md(hmac_alg),
1113                                encrypt_key->length, encrypt_key->value);
1114         if (rc == 0)
1115                 return -EINVAL;
1116
1117         return 0;
1118 }
1119
1120 /**
1121  * Computes a session key based on the DH parameters from the host and its peer
1122  *
1123  * \param[in,out]       skc             Shared key credentials structure with
1124  *                                      the session key populated with the
1125  *                                      compute key
1126  * \param[in]           pub_key         Public key returned from peer in
1127  *                                      gss_buffer_desc
1128  * \return      gss error               failure
1129  * \return      GSS_S_COMPLETE          success
1130  */
1131 uint32_t sk_compute_dh_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
1132 {
1133         gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
1134         BIGNUM *remote_pub_key;
1135         int status;
1136         uint32_t rc = GSS_S_FAILURE;
1137
1138         remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
1139         if (!remote_pub_key) {
1140                 printerr(0, "Failed to convert binary to BIGNUM\n");
1141                 return rc;
1142         }
1143
1144         dh_shared->length = DH_size(skc->sc_params);
1145         dh_shared->value = malloc(dh_shared->length);
1146         if (!dh_shared->value) {
1147                 printerr(0, "Failed to allocate memory for computed shared "
1148                          "secret key\n");
1149                 goto out_err;
1150         }
1151
1152         /* This compute the shared key from the DHKE */
1153         status = DH_compute_key(dh_shared->value, remote_pub_key,
1154                                 skc->sc_params);
1155         if (status == -1) {
1156                 printerr(0, "DH_compute_key() failed: %s\n",
1157                          ERR_error_string(ERR_get_error(), NULL));
1158                 goto out_err;
1159         } else if (status < dh_shared->length) {
1160                 printerr(0, "DH_compute_key() returned a short key of %d "
1161                          "bytes, expected: %zu\n", status, dh_shared->length);
1162                 rc = GSS_S_DEFECTIVE_TOKEN;
1163                 goto out_err;
1164         }
1165
1166         rc = GSS_S_COMPLETE;
1167
1168 out_err:
1169         BN_free(remote_pub_key);
1170         return rc;
1171 }
1172
1173 /**
1174  * Creates a serialized buffer for the kernel in the order of struct
1175  * sk_kernel_ctx.
1176  *
1177  * \param[in,out]       skc             Shared key credentials structure
1178  * \param[in,out]       ctx_token       Serialized buffer for kernel.
1179  *                                      Caller must free this buffer.
1180  *
1181  * \return      0       success
1182  * \return      -1      failure
1183  */
1184 int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
1185 {
1186         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1187         char *p, *end;
1188         size_t bufsize;
1189
1190         bufsize = sizeof(*kctx) + kctx->skc_hmac_key.length +
1191                   kctx->skc_encrypt_key.length;
1192
1193         ctx_token->value = malloc(bufsize);
1194         if (!ctx_token->value)
1195                 return -1;
1196         ctx_token->length = bufsize;
1197
1198         p = ctx_token->value;
1199         end = p + ctx_token->length;
1200
1201         if (WRITE_BYTES(&p, end, kctx->skc_version))
1202                 return -1;
1203         if (WRITE_BYTES(&p, end, kctx->skc_hmac_alg))
1204                 return -1;
1205         if (WRITE_BYTES(&p, end, kctx->skc_crypt_alg))
1206                 return -1;
1207         if (WRITE_BYTES(&p, end, kctx->skc_expire))
1208                 return -1;
1209         if (WRITE_BYTES(&p, end, kctx->skc_host_random))
1210                 return -1;
1211         if (WRITE_BYTES(&p, end, kctx->skc_peer_random))
1212                 return -1;
1213         if (write_buffer(&p, end, &kctx->skc_hmac_key))
1214                 return -1;
1215         if (write_buffer(&p, end, &kctx->skc_encrypt_key))
1216                 return -1;
1217
1218         printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
1219
1220         return 0;
1221 }
1222
1223 /**
1224  * Decodes a netstring \a ns into array of gss_buffer_descs at \a bufs
1225  * up to \a numbufs.  Memory is allocated for each value and length
1226  * will be populated with the length
1227  *
1228  * \param[in,out]       bufs    Array of gss_buffer_descs
1229  * \param[in,out]       numbufs number of gss_buffer_desc in array
1230  * \param[in]           ns      netstring to decode
1231  *
1232  * \return      buffers populated       success
1233  * \return      -1                      failure
1234  */
1235 int sk_decode_netstring(gss_buffer_desc *bufs, int numbufs, gss_buffer_desc *ns)
1236 {
1237         char *ptr = ns->value;
1238         size_t remain = ns->length;
1239         unsigned int size;
1240         int digits;
1241         int sep;
1242         int rc;
1243         int i;
1244
1245         for (i = 0; i < numbufs; i++) {
1246                 /* read the size of first buffer */
1247                 rc = sscanf(ptr, "%9u", &size);
1248                 if (rc < 1)
1249                         goto out_err;
1250                 digits = (size) ? ceil(log10(size + 1)) : 1;
1251
1252                 /* sep of current string */
1253                 sep = size + digits + 2;
1254
1255                 /* check to make sure it's valid */
1256                 if (remain < sep || ptr[digits] != ':' ||
1257                     ptr[sep - 1] != ',')
1258                         goto out_err;
1259
1260                 bufs[i].length = size;
1261                 if (size == 0) {
1262                         bufs[i].value = NULL;
1263                 } else {
1264                         bufs[i].value = malloc(size);
1265                         if (!bufs[i].value)
1266                                 goto out_err;
1267                         memcpy(bufs[i].value, &ptr[digits + 1], size);
1268                 }
1269
1270                 remain -= sep;
1271                 ptr += sep;
1272         }
1273
1274         printerr(2, "Decoded netstring of %zu bytes\n", ns->length);
1275         return i;
1276
1277 out_err:
1278         while (i-- > 0) {
1279                 if (bufs[i].value)
1280                         free(bufs[i].value);
1281                 bufs[i].length = 0;
1282         }
1283         return -1;
1284 }
1285
1286 /**
1287  * Creates a netstring in a gss_buffer_desc that consists of all
1288  * the gss_buffer_desc found in \a bufs.  The netstring should be treated
1289  * as binary as it can contain null characters.
1290  *
1291  * \param[in]           bufs            Array of gss_buffer_desc to use as input
1292  * \param[in]           numbufs         Number of buffers in array
1293  * \param[in,out]       ns              Destination gss_buffer_desc to hold
1294  *                                      netstring
1295  *
1296  * \return      -1      failure
1297  * \return      0       success
1298  */
1299 int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
1300                         gss_buffer_desc *ns)
1301 {
1302         unsigned char *ptr;
1303         int size = 0;
1304         int rc;
1305         int i;
1306
1307         /* size of string in decimal, string size, colon, and comma */
1308         for (i = 0; i < numbufs; i++) {
1309
1310                 if (bufs[i].length == 0)
1311                         size += 3;
1312                 else
1313                         size += ceil(log10(bufs[i].length + 1)) +
1314                                 bufs[i].length + 2;
1315         }
1316
1317         ns->length = size;
1318         ns->value = malloc(ns->length);
1319         if (!ns->value) {
1320                 ns->length = 0;
1321                 return -1;
1322         }
1323
1324         ptr = ns->value;
1325         for (i = 0; i < numbufs; i++) {
1326                 /* size */
1327                 rc = snprintf((char *) ptr, size, "%zu:", bufs[i].length);
1328                 ptr += rc;
1329
1330                 /* contents */
1331                 memcpy(ptr, bufs[i].value, bufs[i].length);
1332                 ptr += bufs[i].length;
1333
1334                 /* delimeter */
1335                 *ptr++ = ',';
1336
1337                 size -= bufs[i].length + rc + 1;
1338
1339                 /* should not happen */
1340                 if (size < 0)
1341                         abort();
1342         }
1343
1344         printerr(2, "Encoded netstring of %zu bytes\n", ns->length);
1345         return 0;
1346 }