Whamcloud - gitweb
LU-10937 mgc: restore mgc binding for sptlrpc
[fs/lustre-release.git] / lustre / utils / gss / sk_utils.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Copyright (c) 2016, 2017, Intel Corporation.
26  *
27  * Author: Jeremy Filizetti <jfilizet@iu.edu>
28  */
29
30 #include <fcntl.h>
31 #include <limits.h>
32 #include <math.h>
33 #include <string.h>
34 #include <stdbool.h>
35 #include <unistd.h>
36 #include <openssl/dh.h>
37 #include <openssl/engine.h>
38 #include <openssl/err.h>
39 #include <openssl/hmac.h>
40 #include <sys/types.h>
41 #include <sys/stat.h>
42
43 #include "sk_utils.h"
44 #include "write_bytes.h"
45
46 #define SK_PBKDF2_ITERATIONS 10000
47
48 static struct sk_crypt_type sk_crypt_types[] = {
49         [SK_CRYPT_AES256_CTR] = {
50                 .sct_name = "ctr(aes)",
51                 .sct_bytes = 32,
52         },
53 };
54
55 static struct sk_hmac_type sk_hmac_types[] = {
56         [SK_HMAC_SHA256] = {
57                 .sht_name = "hmac(sha256)",
58                 .sht_bytes = 32,
59         },
60         [SK_HMAC_SHA512] = {
61                 .sht_name = "hmac(sha512)",
62                 .sht_bytes = 64,
63         },
64 };
65
66 #ifdef _NEW_BUILD_
67 # include "lgss_utils.h"
68 #else
69 # include "gss_util.h"
70 # include "gss_oids.h"
71 # include "err_util.h"
72 #endif
73
74 #ifdef _ERR_UTIL_H_
75 /**
76  * Initializes logging
77  * \param[in]   program         Program name to output
78  * \param[in]   verbose         Verbose flag
79  * \param[in]   fg              Whether or not to run in foreground
80  *
81  */
82 void sk_init_logging(char *program, int verbose, int fg)
83 {
84         initerr(program, verbose, fg);
85 }
86 #endif
87
88 /**
89  * Loads the key from \a filename and returns the struct sk_keyfile_config.
90  * It should be freed by the caller.
91  *
92  * \param[in]   filename                Disk or key payload data
93  *
94  * \return      sk_keyfile_config       sucess
95  * \return      NULL                    failure
96  */
97 struct sk_keyfile_config *sk_read_file(char *filename)
98 {
99         struct sk_keyfile_config *config;
100         char *ptr;
101         size_t rc;
102         size_t remain;
103         int fd;
104
105         config = malloc(sizeof(*config));
106         if (!config) {
107                 printerr(0, "Failed to allocate memory for config\n");
108                 return NULL;
109         }
110
111         /* allow standard input override */
112         if (strcmp(filename, "-") == 0)
113                 fd = STDIN_FILENO;
114         else
115                 fd = open(filename, O_RDONLY);
116
117         if (fd == -1) {
118                 printerr(0, "Error opening key file '%s': %s\n", filename,
119                          strerror(errno));
120                 goto out_free;
121         } else if (fd != STDIN_FILENO) {
122                 struct stat st;
123
124                 rc = fstat(fd, &st);
125                 if (rc == 0 && (st.st_mode & ~(S_IFREG | 0600)))
126                         fprintf(stderr, "warning: "
127                                 "secret key '%s' has insecure file mode %#o\n",
128                                 filename, st.st_mode);
129         }
130
131         ptr = (char *)config;
132         remain = sizeof(*config);
133         while (remain > 0) {
134                 rc = read(fd, ptr, remain);
135                 if (rc == -1) {
136                         if (errno == EINTR)
137                                 continue;
138                         printerr(0, "read() failed on %s: %s\n", filename,
139                                  strerror(errno));
140                         goto out_close;
141                 } else if (rc == 0) {
142                         printerr(0, "File %s does not have a complete key\n",
143                                  filename);
144                         goto out_close;
145                 }
146                 ptr += rc;
147                 remain -= rc;
148         }
149
150         if (fd != STDIN_FILENO)
151                 close(fd);
152         sk_config_disk_to_cpu(config);
153         return config;
154
155 out_close:
156         close(fd);
157 out_free:
158         free(config);
159         return NULL;
160 }
161
162 /**
163  * Checks if a key matching \a description is found in the keyring for
164  * logging purposes and then attempts to load \a payload of \a psize into a key
165  * with \a description.
166  *
167  * \param[in]   payload         Key payload
168  * \param[in]   psize           Payload size
169  * \param[in]   description     Description used for key in keyring
170  *
171  * \return      0       sucess
172  * \return      -1      failure
173  */
174 static key_serial_t sk_load_key(const struct sk_keyfile_config *skc,
175                                 const char *description)
176 {
177         struct sk_keyfile_config payload;
178         key_serial_t key;
179
180         memcpy(&payload, skc, sizeof(*skc));
181
182         /* In the keyring use the disk layout so keyctl pipe can be used */
183         sk_config_cpu_to_disk(&payload);
184
185         /* Check to see if a key is already loaded matching description */
186         key = keyctl_search(KEY_SPEC_USER_KEYRING, "user", description, 0);
187         if (key != -1)
188                 printerr(2, "Key %d found in session keyring, replacing\n",
189                          key);
190
191         key = add_key("user", description, &payload, sizeof(payload),
192                       KEY_SPEC_USER_KEYRING);
193         if (key != -1)
194                 printerr(2, "Added key %d with description %s\n", key,
195                          description);
196         else
197                 printerr(0, "Failed to add key with %s\n", description);
198
199         return key;
200 }
201
202 /**
203  * Reads the key from \a path, verifies it and loads into the session keyring
204  * using a description determined by the the \a type.  Existing keys with the
205  * same description are replaced.
206  *
207  * \param[in]   path    Path to key file
208  * \param[in]   type    Type of key to load which determines the description
209  *
210  * \return      0       sucess
211  * \return      -1      failure
212  */
213 int sk_load_keyfile(char *path)
214 {
215         struct sk_keyfile_config *config;
216         char description[SK_DESCRIPTION_SIZE + 1];
217         struct stat buf;
218         int i;
219         int rc;
220         int rc2 = -1;
221
222         rc = stat(path, &buf);
223         if (rc == -1) {
224                 printerr(0, "stat() failed for file %s: %s\n", path,
225                          strerror(errno));
226                 return rc2;
227         }
228
229         config = sk_read_file(path);
230         if (!config)
231                 return rc2;
232
233         /* Similar to ssh, require adequate care of key files */
234         if (buf.st_mode & (S_IRGRP | S_IWGRP | S_IWOTH | S_IXOTH)) {
235                 printerr(0, "Shared key files must be read/writeable only by "
236                          "owner\n");
237                 return -1;
238         }
239
240         if (sk_validate_config(config))
241                 goto out;
242
243         /* The server side can have multiple key files per file system so
244          * the nodemap name is appended to the key description to uniquely
245          * identify it */
246         if (config->skc_type & SK_TYPE_MGS) {
247                 /* Any key can be an MGS key as long as we are told to use it */
248                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:MGS:%s",
249                               config->skc_nodemap);
250                 if (rc >= SK_DESCRIPTION_SIZE)
251                         goto out;
252                 if (sk_load_key(config, description) == -1)
253                         goto out;
254         }
255         if (config->skc_type & SK_TYPE_SERVER) {
256                 /* Server keys need to have the file system name in the key */
257                 if (!config->skc_fsname) {
258                         printerr(0, "Key configuration has no file system "
259                                  "attribute.  Can't load as server type\n");
260                         goto out;
261                 }
262                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s:%s",
263                               config->skc_fsname, config->skc_nodemap);
264                 if (rc >= SK_DESCRIPTION_SIZE)
265                         goto out;
266                 if (sk_load_key(config, description) == -1)
267                         goto out;
268         }
269         if (config->skc_type & SK_TYPE_CLIENT) {
270                 /* Load client file system key */
271                 if (config->skc_fsname) {
272                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
273                                       "lustre:%s", config->skc_fsname);
274                         if (rc >= SK_DESCRIPTION_SIZE)
275                                 goto out;
276                         if (sk_load_key(config, description) == -1)
277                                 goto out;
278                 }
279
280                 /* Load client MGC keys */
281                 for (i = 0; i < MAX_MGSNIDS; i++) {
282                         if (config->skc_mgsnids[i] == LNET_NID_ANY)
283                                 continue;
284                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
285                                       "lustre:MGC%s",
286                                       libcfs_nid2str(config->skc_mgsnids[i]));
287                         if (rc >= SK_DESCRIPTION_SIZE)
288                                 goto out;
289                         if (sk_load_key(config, description) == -1)
290                                 goto out;
291                 }
292         }
293
294         rc2 = 0;
295
296 out:
297         free(config);
298         return rc2;
299 }
300
301 /**
302  * Byte swaps config from cpu format to disk
303  *
304  * \param[in,out]       config          sk_keyfile_config to swap
305  */
306 void sk_config_cpu_to_disk(struct sk_keyfile_config *config)
307 {
308         int i;
309
310         if (!config)
311                 return;
312
313         config->skc_version = htobe32(config->skc_version);
314         config->skc_hmac_alg = htobe16(config->skc_hmac_alg);
315         config->skc_crypt_alg = htobe16(config->skc_crypt_alg);
316         config->skc_expire = htobe32(config->skc_expire);
317         config->skc_shared_keylen = htobe32(config->skc_shared_keylen);
318         config->skc_prime_bits = htobe32(config->skc_prime_bits);
319
320         for (i = 0; i < MAX_MGSNIDS; i++)
321                 config->skc_mgsnids[i] = htobe64(config->skc_mgsnids[i]);
322
323         return;
324 }
325
326 /**
327  * Byte swaps config from disk format to cpu
328  *
329  * \param[in,out]       config          sk_keyfile_config to swap
330  */
331 void sk_config_disk_to_cpu(struct sk_keyfile_config *config)
332 {
333         int i;
334
335         if (!config)
336                 return;
337
338         config->skc_version = be32toh(config->skc_version);
339         config->skc_hmac_alg = be16toh(config->skc_hmac_alg);
340         config->skc_crypt_alg = be16toh(config->skc_crypt_alg);
341         config->skc_expire = be32toh(config->skc_expire);
342         config->skc_shared_keylen = be32toh(config->skc_shared_keylen);
343         config->skc_prime_bits = be32toh(config->skc_prime_bits);
344
345         for (i = 0; i < MAX_MGSNIDS; i++)
346                 config->skc_mgsnids[i] = be64toh(config->skc_mgsnids[i]);
347
348         return;
349 }
350
351 /**
352  * Verifies the on key payload format is valid
353  *
354  * \param[in]   config          sk_keyfile_config
355  *
356  * \return      -1      failure
357  * \return      0       success
358  */
359 int sk_validate_config(const struct sk_keyfile_config *config)
360 {
361         int i;
362
363         if (!config) {
364                 printerr(0, "Null configuration passed\n");
365                 return -1;
366         }
367         if (config->skc_version != SK_CONF_VERSION) {
368                 printerr(0, "Invalid version\n");
369                 return -1;
370         }
371         if (config->skc_hmac_alg >= SK_HMAC_MAX) {
372                 printerr(0, "Invalid HMAC algorithm\n");
373                 return -1;
374         }
375         if (config->skc_crypt_alg >= SK_CRYPT_MAX) {
376                 printerr(0, "Invalid crypt algorithm\n");
377                 return -1;
378         }
379         if (config->skc_expire < 60 || config->skc_expire > INT_MAX) {
380                 /* Try to limit key expiration to some reasonable minimum and
381                  * also prevent values over INT_MAX because there appears
382                  * to be a type conversion issue */
383                 printerr(0, "Invalid expiration time should be between %d "
384                          "and %d\n", 60, INT_MAX);
385                 return -1;
386         }
387         if (config->skc_prime_bits % 8 != 0 ||
388             config->skc_prime_bits > SK_MAX_P_BYTES * 8) {
389                 printerr(0, "Invalid session key length must be a multiple of 8"
390                          " and less then %d bits\n",
391                          SK_MAX_P_BYTES * 8);
392                 return -1;
393         }
394         if (config->skc_shared_keylen % 8 != 0 ||
395             config->skc_shared_keylen > SK_MAX_KEYLEN_BYTES * 8){
396                 printerr(0, "Invalid shared key max length must be a multiple "
397                          "of 8 and less then %d bits\n",
398                          SK_MAX_KEYLEN_BYTES * 8);
399                 return -1;
400         }
401
402         /* Check for terminating nulls on strings */
403         for (i = 0; i < sizeof(config->skc_fsname) &&
404              config->skc_fsname[i] != '\0';  i++)
405                 ; /* empty loop */
406         if (i == sizeof(config->skc_fsname)) {
407                 printerr(0, "File system name not null terminated\n");
408                 return -1;
409         }
410
411         for (i = 0; i < sizeof(config->skc_nodemap) &&
412              config->skc_nodemap[i] != '\0';  i++)
413                 ; /* empty loop */
414         if (i == sizeof(config->skc_nodemap)) {
415                 printerr(0, "Nodemap name not null terminated\n");
416                 return -1;
417         }
418
419         if (config->skc_type == SK_TYPE_INVALID) {
420                 printerr(0, "Invalid key type\n");
421                 return -1;
422         }
423
424         return 0;
425 }
426
427 /**
428  * Hashes \a string and places the hash in \a hash
429  * at \a hash
430  *
431  * \param[in]           string          Null terminated string to hash
432  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
433  * \param[in,out]       hash            gss_buffer_desc to hold the result
434  *
435  * \return      -1      failure
436  * \return      0       success
437  */
438 static int sk_hash_string(const char *string, const EVP_MD *hash_alg,
439                           gss_buffer_desc *hash)
440 {
441         EVP_MD_CTX *ctx = EVP_MD_CTX_create();
442         size_t len = strlen(string);
443         unsigned int hashlen;
444
445         if (!hash->value || hash->length < EVP_MD_size(hash_alg))
446                 goto out_err;
447         if (!EVP_DigestInit_ex(ctx, hash_alg, NULL))
448                 goto out_err;
449         if (!EVP_DigestUpdate(ctx, string, len))
450                 goto out_err;
451         if (!EVP_DigestFinal_ex(ctx, hash->value, &hashlen))
452                 goto out_err;
453
454         EVP_MD_CTX_destroy(ctx);
455         hash->length = hashlen;
456         return 0;
457
458 out_err:
459         EVP_MD_CTX_destroy(ctx);
460         return -1;
461 }
462
463 /**
464  * Hashes \a string and verifies the resulting hash matches the value
465  * in \a current_hash
466  *
467  * \param[in]           string          Null terminated string to hash
468  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
469  * \param[in,out]       current_hash    gss_buffer_desc to compare to
470  *
471  * \return      gss error       failure
472  * \return      GSS_S_COMPLETE  success
473  */
474 uint32_t sk_verify_hash(const char *string, const EVP_MD *hash_alg,
475                         const gss_buffer_desc *current_hash)
476 {
477         gss_buffer_desc hash;
478         unsigned char hashbuf[EVP_MAX_MD_SIZE];
479
480         hash.value = hashbuf;
481         hash.length = sizeof(hashbuf);
482
483         if (sk_hash_string(string, hash_alg, &hash))
484                 return GSS_S_FAILURE;
485         if (current_hash->length != hash.length)
486                 return GSS_S_DEFECTIVE_TOKEN;
487         if (memcmp(current_hash->value, hash.value, hash.length))
488                 return GSS_S_BAD_SIG;
489
490         return GSS_S_COMPLETE;
491 }
492
493 static inline int sk_config_has_mgsnid(struct sk_keyfile_config *config,
494                                        const char *mgsnid)
495 {
496         lnet_nid_t nid;
497         int i;
498
499         nid = libcfs_str2nid(mgsnid);
500         if (nid == LNET_NID_ANY)
501                 return 0;
502
503         for (i = 0; i < MAX_MGSNIDS; i++)
504                 if  (config->skc_mgsnids[i] == nid)
505                         return 1;
506         return 0;
507 }
508
509 /**
510  * Create an sk_cred structure populated with initial configuration info and the
511  * key.  \a tgt and \a nodemap are used in determining the expected key
512  * description so the key can be found by searching the keyring.
513  * This is done because there is no easy way to pass keys from the mount command
514  * all the way to the request_key call.  In addition any keys can be dynamically
515  * added to the keyrings and still found.  The keyring that needs to be used
516  * must be the session keyring.
517  *
518  * \param[in]   tgt             Target file system
519  * \param[in]   nodemap         Cluster name for the key.  This correlates to
520  *                              the nodemap name and is used by the server side.
521  *                              For the client this will be NULL.
522  * \param[in]   flags           Flags for the credentials
523  *
524  * \return      sk_cred Allocated struct sk_cred on success
525  * \return      NULL    failure
526  */
527 struct sk_cred *sk_create_cred(const char *tgt, const char *nodemap,
528                                const uint32_t flags)
529 {
530         struct sk_keyfile_config *config;
531         struct sk_kernel_ctx *kctx;
532         struct sk_cred *skc = NULL;
533         char description[SK_DESCRIPTION_SIZE + 1];
534         char fsname[MTI_NAME_MAXLEN + 1];
535         const char *mgsnid = NULL;
536         char *ptr;
537         long sk_key;
538         int keylen;
539         int len;
540         int rc;
541
542         printerr(2, "Creating credentials for target: %s with nodemap: %s\n",
543                  tgt, nodemap);
544
545         memset(description, 0, sizeof(description));
546         memset(fsname, 0, sizeof(fsname));
547
548         /* extract the file system name from target */
549         ptr = index(tgt, '-');
550         if (!ptr) {
551                 len = strlen(tgt);
552
553                 /* This must be an MGC target */
554                 if (strncmp(tgt, "MGC", 3) || len <= 3) {
555                         printerr(0, "Invalid target name\n");
556                         return NULL;
557                 }
558                 mgsnid = tgt + 3;
559         } else {
560                 len = ptr - tgt;
561         }
562
563         if (len > MTI_NAME_MAXLEN) {
564                 printerr(0, "Invalid target name\n");
565                 return NULL;
566         }
567         memcpy(fsname, tgt, len);
568
569         if (nodemap) {
570                 if (mgsnid)
571                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
572                                       "lustre:MGS:%s", nodemap);
573                 else
574                         rc = snprintf(description, SK_DESCRIPTION_SIZE,
575                                       "lustre:%s:%s", fsname, nodemap);
576         } else {
577                 rc = snprintf(description, SK_DESCRIPTION_SIZE, "lustre:%s",
578                               fsname);
579         }
580
581         if (rc >= SK_DESCRIPTION_SIZE) {
582                 printerr(0, "Invalid key description\n");
583                 return NULL;
584         }
585
586         /* It may be a good idea to move Lustre keys to the gss_keyring
587          * (lgssc) type so that they expire when Lustre modules are removed.
588          * Unfortunately it can't be done at mount time because the mount
589          * syscall could trigger the Lustre modules to load and until that
590          * point we don't have a lgssc key type.
591          *
592          * TODO: Query the community for a consensus here  */
593         printerr(2, "Searching for key with description: %s\n", description);
594         sk_key = keyctl_search(KEY_SPEC_USER_KEYRING, "user",
595                                description, 0);
596         if (sk_key == -1) {
597                 printerr(1, "No key found for %s\n", description);
598                 return NULL;
599         }
600
601         keylen = keyctl_read_alloc(sk_key, (void **)&config);
602         if (keylen == -1) {
603                 printerr(0, "keyctl_read() failed for key %ld: %s\n", sk_key,
604                          strerror(errno));
605                 return NULL;
606         } else if (keylen != sizeof(*config)) {
607                 printerr(0, "Unexpected key size: %d returned for key %ld, "
608                          "expected %zu bytes\n",
609                          keylen, sk_key, sizeof(*config));
610                 goto out_err;
611         }
612
613         sk_config_disk_to_cpu(config);
614
615         if (sk_validate_config(config)) {
616                 printerr(0, "Invalid key configuration for key: %ld\n", sk_key);
617                 goto out_err;
618         }
619
620         if (mgsnid && !sk_config_has_mgsnid(config, mgsnid)) {
621                 printerr(0, "Target name does not match key's MGS NIDs\n");
622                 goto out_err;
623         }
624
625         if (!mgsnid && strcmp(fsname, config->skc_fsname)) {
626                 printerr(0, "Target name does not match key's file system\n");
627                 goto out_err;
628         }
629
630         skc = malloc(sizeof(*skc));
631         if (!skc) {
632                 printerr(0, "Failed to allocate memory for sk_cred\n");
633                 goto out_err;
634         }
635
636         /* this initializes all gss_buffer_desc to empty as well */
637         memset(skc, 0, sizeof(*skc));
638
639         skc->sc_flags = flags;
640         skc->sc_tgt.length = strlen(tgt) + 1;
641         skc->sc_tgt.value = malloc(skc->sc_tgt.length);
642         if (!skc->sc_tgt.value) {
643                 printerr(0, "Failed to allocate memory for target\n");
644                 goto out_err;
645         }
646         memcpy(skc->sc_tgt.value, tgt, skc->sc_tgt.length);
647
648         skc->sc_nodemap_hash.length = EVP_MD_size(EVP_sha256());
649         skc->sc_nodemap_hash.value = malloc(skc->sc_nodemap_hash.length);
650         if (!skc->sc_nodemap_hash.value) {
651                 printerr(0, "Failed to allocate memory for nodemap hash\n");
652                 goto out_err;
653         }
654
655         if (sk_hash_string(config->skc_nodemap, EVP_sha256(),
656                            &skc->sc_nodemap_hash)) {
657                 printerr(0, "Failed to generate hash for nodemap name\n");
658                 goto out_err;
659         }
660
661         kctx = &skc->sc_kctx;
662         kctx->skc_version = config->skc_version;
663         kctx->skc_hmac_alg = config->skc_hmac_alg;
664         kctx->skc_crypt_alg = config->skc_crypt_alg;
665         kctx->skc_expire = config->skc_expire;
666
667         /* key payload format is in bits, convert to bytes */
668         kctx->skc_shared_key.length = config->skc_shared_keylen / 8;
669         kctx->skc_shared_key.value = malloc(kctx->skc_shared_key.length);
670         if (!kctx->skc_shared_key.value) {
671                 printerr(0, "Failed to allocate memory for shared key\n");
672                 goto out_err;
673         }
674         memcpy(kctx->skc_shared_key.value, config->skc_shared_key,
675                kctx->skc_shared_key.length);
676
677         skc->sc_p.length = config->skc_prime_bits / 8;
678         skc->sc_p.value = malloc(skc->sc_p.length);
679         if (!skc->sc_p.value) {
680                 printerr(0, "Failed to allocate p\n");
681                 goto out_err;
682         }
683         memcpy(skc->sc_p.value, config->skc_p, skc->sc_p.length);
684
685         free(config);
686
687         return skc;
688
689 out_err:
690         sk_free_cred(skc);
691
692         free(config);
693         return NULL;
694 }
695
696 /**
697  * Populates the DH parameters for the DHKE
698  *
699  * \param[in,out]       skc             Shared key credentials structure to
700  *                                      populate with DH parameters
701  *
702  * \retval      GSS_S_COMPLETE  success
703  * \retval      GSS_S_FAILURE   failure
704  */
705 uint32_t sk_gen_params(struct sk_cred *skc)
706 {
707         uint32_t random;
708         int rc;
709
710         /* Random value used by both the request and response as part of the
711          * key binding material.  This also should ensure we have unqiue
712          * tokens that are sent to the remote server which is important because
713          * the token is hashed for the sunrpc cache lookups and a failure there
714          * would cause connection attempts to fail indefinitely due to the large
715          * timeout value on the server side */
716         if (RAND_bytes((unsigned char *)&random, sizeof(random)) != 1) {
717                 printerr(0, "Failed to get data for random parameter: %s\n",
718                          ERR_error_string(ERR_get_error(), NULL));
719                 return GSS_S_FAILURE;
720         }
721
722         /* The random value will always be used in byte range operations
723          * so we keep it as big endian from this point on */
724         skc->sc_kctx.skc_host_random = random;
725
726         /* Populate DH parameters */
727         skc->sc_params = DH_new();
728         if (!skc->sc_params) {
729                 printerr(0, "Failed to allocate DH\n");
730                 return GSS_S_FAILURE;
731         }
732
733         skc->sc_params->p = BN_bin2bn(skc->sc_p.value, skc->sc_p.length, NULL);
734         if (!skc->sc_params->p) {
735                 printerr(0, "Failed to convert binary to BIGNUM\n");
736                 return GSS_S_FAILURE;
737         }
738
739         /* We use a static generator for shared key */
740         skc->sc_params->g = BN_new();
741         if (!skc->sc_params->g) {
742                 printerr(0, "Failed to allocate new BIGNUM\n");
743                 return GSS_S_FAILURE;
744         }
745         if (BN_set_word(skc->sc_params->g, SK_GENERATOR) != 1) {
746                 printerr(0, "Failed to set g value for DH params\n");
747                 return GSS_S_FAILURE;
748         }
749
750         /* Verify that we have a safe prime and valid generator */
751         if (DH_check(skc->sc_params, &rc) != 1) {
752                 printerr(0, "DH_check() failed: %d\n", rc);
753                 return GSS_S_FAILURE;
754         } else if (rc) {
755                 printerr(0, "DH_check() returned error codes: 0x%x\n", rc);
756                 return GSS_S_FAILURE;
757         }
758
759         if (DH_generate_key(skc->sc_params) != 1) {
760                 printerr(0, "Failed to generate public DH key: %s\n",
761                          ERR_error_string(ERR_get_error(), NULL));
762                 return GSS_S_FAILURE;
763         }
764
765         skc->sc_pub_key.length = BN_num_bytes(skc->sc_params->pub_key);
766         skc->sc_pub_key.value = malloc(skc->sc_pub_key.length);
767         if (!skc->sc_pub_key.value) {
768                 printerr(0, "Failed to allocate memory for public key\n");
769                 return GSS_S_FAILURE;
770         }
771
772         BN_bn2bin(skc->sc_params->pub_key, skc->sc_pub_key.value);
773
774         return GSS_S_COMPLETE;
775 }
776
777 /**
778  * Convert SK hash algorithm into openssl message digest
779  *
780  * \param[in,out]       alg             SK hash algorithm
781  *
782  * \retval              EVP_MD
783  */
784 static inline const EVP_MD *sk_hash_to_evp_md(enum sk_hmac_alg alg)
785 {
786         switch (alg) {
787         case SK_HMAC_SHA256:
788                 return EVP_sha256();
789         case SK_HMAC_SHA512:
790                 return EVP_sha512();
791         default:
792                 return EVP_md_null();
793         }
794 }
795
796 /**
797  * Signs (via HMAC) the parameters used only in the key initialization protocol.
798  *
799  * \param[in]           key             Key to use for HMAC
800  * \param[in]           bufs            Array of gss_buffer_desc to generate
801  *                                      HMAC for
802  * \param[in]           numbufs         Number of buffers in array
803  * \param[in]           hash_alg        OpenSSL EVP_MD to use for hash
804  * \param[in,out]       hmac            HMAC of buffers is allocated and placed
805  *                                      in this gss_buffer_desc.  Caller must
806  *                                      free this.
807  *
808  * \retval      0       success
809  * \retval      -1      failure
810  */
811 int sk_sign_bufs(gss_buffer_desc *key, gss_buffer_desc *bufs, const int numbufs,
812                  const EVP_MD *hash_alg, gss_buffer_desc *hmac)
813 {
814         HMAC_CTX hctx;
815         unsigned int hashlen = EVP_MD_size(hash_alg);
816         int i;
817         int rc = -1;
818
819         if (hash_alg == EVP_md_null()) {
820                 printerr(0, "Invalid hash algorithm\n");
821                 return -1;
822         }
823
824         HMAC_CTX_init(&hctx);
825
826         hmac->length = hashlen;
827         hmac->value = malloc(hashlen);
828         if (!hmac->value) {
829                 printerr(0, "Failed to allocate memory for HMAC\n");
830                 goto out;
831         }
832
833         if (HMAC_Init_ex(&hctx, key->value, key->length, hash_alg, NULL) != 1) {
834                 printerr(0, "Failed to init HMAC\n");
835                 goto out;
836         }
837
838         for (i = 0; i < numbufs; i++) {
839                 if (HMAC_Update(&hctx, bufs[i].value, bufs[i].length) != 1) {
840                         printerr(0, "Failed to update HMAC\n");
841                         goto out;
842                 }
843         }
844
845         /* The result gets populated in hmac */
846         if (HMAC_Final(&hctx, hmac->value, &hashlen) != 1) {
847                 printerr(0, "Failed to finalize HMAC\n");
848                 goto out;
849         }
850
851         if (hmac->length != hashlen) {
852                 printerr(0, "HMAC size does not match expected\n");
853                 goto out;
854         }
855
856         rc = 0;
857 out:
858         HMAC_CTX_cleanup(&hctx);
859         return rc;
860 }
861
862 /**
863  * Generates an HMAC for gss_buffer_desc array in \a bufs of \a numbufs
864  * and verifies against \a hmac.
865  *
866  * \param[in]   skc             Shared key credentials
867  * \param[in]   bufs            Array of gss_buffer_desc to generate HMAC for
868  * \param[in]   numbufs         Number of buffers in array
869  * \param[in]   hash_alg        OpenSSL EVP_MD to use for hash
870  * \param[in]   hmac            HMAC to verify against
871  *
872  * \retval      GSS_S_COMPLETE  success (match)
873  * \retval      gss error       failure
874  */
875 uint32_t sk_verify_hmac(struct sk_cred *skc, gss_buffer_desc *bufs,
876                         const int numbufs, const EVP_MD *hash_alg,
877                         gss_buffer_desc *hmac)
878 {
879         gss_buffer_desc bufs_hmac;
880         int rc;
881
882         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, numbufs, hash_alg,
883                          &bufs_hmac)) {
884                 printerr(0, "Failed to sign buffers to verify HMAC\n");
885                 if (bufs_hmac.value)
886                         free(bufs_hmac.value);
887                 return GSS_S_FAILURE;
888         }
889
890         if (hmac->length != bufs_hmac.length) {
891                 printerr(0, "Invalid HMAC size\n");
892                 free(bufs_hmac.value);
893                 return GSS_S_BAD_SIG;
894         }
895
896         rc = memcmp(hmac->value, bufs_hmac.value, bufs_hmac.length);
897         free(bufs_hmac.value);
898
899         if (rc)
900                 return GSS_S_BAD_SIG;
901
902         return GSS_S_COMPLETE;
903 }
904
905 /**
906  * Cleanup an sk_cred freeing any resources
907  *
908  * \param[in,out]       skc     Shared key credentials to free
909  */
910 void sk_free_cred(struct sk_cred *skc)
911 {
912         if (!skc)
913                 return;
914
915         if (skc->sc_p.value)
916                 free(skc->sc_p.value);
917         if (skc->sc_pub_key.value)
918                 free(skc->sc_pub_key.value);
919         if (skc->sc_tgt.value)
920                 free(skc->sc_tgt.value);
921         if (skc->sc_nodemap_hash.value)
922                 free(skc->sc_nodemap_hash.value);
923         if (skc->sc_hmac.value)
924                 free(skc->sc_hmac.value);
925
926         /* Overwrite keys and IV before freeing */
927         if (skc->sc_dh_shared_key.value) {
928                 memset(skc->sc_dh_shared_key.value, 0,
929                        skc->sc_dh_shared_key.length);
930                 free(skc->sc_dh_shared_key.value);
931         }
932         if (skc->sc_kctx.skc_hmac_key.value) {
933                 memset(skc->sc_kctx.skc_hmac_key.value, 0,
934                        skc->sc_kctx.skc_hmac_key.length);
935                 free(skc->sc_kctx.skc_hmac_key.value);
936         }
937         if (skc->sc_kctx.skc_encrypt_key.value) {
938                 memset(skc->sc_kctx.skc_encrypt_key.value, 0,
939                        skc->sc_kctx.skc_encrypt_key.length);
940                 free(skc->sc_kctx.skc_encrypt_key.value);
941         }
942         if (skc->sc_kctx.skc_shared_key.value) {
943                 memset(skc->sc_kctx.skc_shared_key.value, 0,
944                        skc->sc_kctx.skc_shared_key.length);
945                 free(skc->sc_kctx.skc_shared_key.value);
946         }
947         if (skc->sc_kctx.skc_session_key.value) {
948                 memset(skc->sc_kctx.skc_session_key.value, 0,
949                        skc->sc_kctx.skc_session_key.length);
950                 free(skc->sc_kctx.skc_session_key.value);
951         }
952
953         if (skc->sc_params)
954                 DH_free(skc->sc_params);
955
956         free(skc);
957         skc = NULL;
958 }
959
960 /* This function handles key derivation using the hash algorithm specified in
961  * \a hash_alg, buffers in \a key_binding_bufs, and original key in
962  * \a origin_key to produce a \a derived_key.  The first element of the
963  * key_binding_bufs array is reserved for the counter used in the KDF.  The
964  * derived key in \a derived_key could differ in size from \a origin_key and
965  * must be populated with the expected size and a valid buffer to hold the
966  * contents.
967  *
968  * If the derived key size is greater than the HMAC algorithm size it will be
969  * a done using several iterations of a counter and the key binding bufs.
970  *
971  * If the size is smaller it will take copy the first N bytes necessary to
972  * fill the derived key. */
973 int sk_kdf(gss_buffer_desc *derived_key , gss_buffer_desc *origin_key,
974            gss_buffer_desc *key_binding_bufs, int numbufs, int hmac_alg)
975 {
976         size_t remain;
977         size_t bytes;
978         uint32_t counter;
979         char *keydata;
980         gss_buffer_desc tmp_hash;
981         int i;
982         int rc;
983
984         if (numbufs < 1)
985                 return -EINVAL;
986
987         /* Use a counter as the first buffer followed by the key binding
988          * buffers in the event we need more than one a single cycle to
989          * produced a symmetric key large enough in size */
990         key_binding_bufs[0].value = &counter;
991         key_binding_bufs[0].length = sizeof(counter);
992
993         remain = derived_key->length;
994         keydata = derived_key->value;
995         i = 0;
996         while (remain > 0) {
997                 counter = htobe32(i++);
998                 rc = sk_sign_bufs(origin_key, key_binding_bufs, numbufs,
999                                   sk_hash_to_evp_md(hmac_alg), &tmp_hash);
1000                 if (rc) {
1001                         if (tmp_hash.value)
1002                                 free(tmp_hash.value);
1003                         return rc;
1004                 }
1005
1006                 if (sk_hmac_types[hmac_alg].sht_bytes != tmp_hash.length) {
1007                         free(tmp_hash.value);
1008                         return -EINVAL;
1009                 }
1010
1011                 bytes = (remain < tmp_hash.length) ? remain : tmp_hash.length;
1012                 memcpy(keydata, tmp_hash.value, bytes);
1013                 free(tmp_hash.value);
1014                 remain -= bytes;
1015                 keydata += bytes;
1016         }
1017
1018         return 0;
1019 }
1020
1021 /* Populates the sk_cred's session_key using the a Key Derviation Function (KDF)
1022  * based on the recommendations in NIST Special Publication SP 800-56B Rev 1
1023  * (Sep 2014) Section 5.5.1
1024  *
1025  * \param[in,out]       skc             Shared key credentials structure with
1026  *
1027  * \return      -1              failure
1028  * \return      0               success
1029  */
1030 int sk_session_kdf(struct sk_cred *skc, lnet_nid_t client_nid,
1031                    gss_buffer_desc *client_token, gss_buffer_desc *server_token)
1032 {
1033         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1034         gss_buffer_desc *session_key = &kctx->skc_session_key;
1035         gss_buffer_desc bufs[5];
1036         int rc = -1;
1037
1038         session_key->length = sk_crypt_types[kctx->skc_crypt_alg].sct_bytes;
1039         session_key->value = malloc(session_key->length);
1040         if (!session_key->value) {
1041                 printerr(0, "Failed to allocate memory for session key\n");
1042                 return rc;
1043         }
1044
1045         /* Key binding info ordering
1046          * 1. Reserved for counter
1047          * 1. DH shared key
1048          * 2. Client's NIDs
1049          * 3. Client's token
1050          * 4. Server's token */
1051         bufs[0].value = NULL;
1052         bufs[0].length = 0;
1053         bufs[1] = skc->sc_dh_shared_key;
1054         bufs[2].value = &client_nid;
1055         bufs[2].length = sizeof(client_nid);
1056         bufs[3] = *client_token;
1057         bufs[4] = *server_token;
1058
1059         return sk_kdf(&kctx->skc_session_key, &kctx->skc_shared_key, bufs,
1060                       5, kctx->skc_hmac_alg);
1061 }
1062
1063 /* Uses the session key to create an HMAC key and encryption key.  In
1064  * integrity mode the session key used to generate the HMAC key uses
1065  * session information which is available on the wire but by creating
1066  * a session based HMAC key we can prevent potential replay as both the
1067  * client and server have random numbers used as part of the key creation.
1068  *
1069  * The keys used for integrity and privacy are formulated as below using
1070  * the session key that is the output of the key derivation function.  The
1071  * HMAC algorithm is determined by the shared key algorithm selected in the
1072  * key file.
1073  *
1074  * For ski mode:
1075  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1076  *
1077  * For skpi mode:
1078  * Session HMAC Key = PBKDF2("Integrity", KDF derived Session Key)
1079  * Session Encryption Key = PBKDF2("Encrypt", KDF derived Session Key)
1080  *
1081  * \param[in,out]       skc             Shared key credentials structure with
1082  *
1083  * \return      -1              failure
1084  * \return      0               success
1085  */
1086 int sk_compute_keys(struct sk_cred *skc)
1087 {
1088         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1089         gss_buffer_desc *session_key = &kctx->skc_session_key;
1090         gss_buffer_desc *hmac_key = &kctx->skc_hmac_key;
1091         gss_buffer_desc *encrypt_key = &kctx->skc_encrypt_key;
1092         char *encrypt = "Encrypt";
1093         char *integrity = "Integrity";
1094         int rc;
1095
1096         hmac_key->length = sk_hmac_types[kctx->skc_hmac_alg].sht_bytes;
1097         hmac_key->value = malloc(hmac_key->length);
1098         if (!hmac_key->value)
1099                 return -ENOMEM;
1100
1101         rc = PKCS5_PBKDF2_HMAC(integrity, -1, session_key->value,
1102                                session_key->length, SK_PBKDF2_ITERATIONS,
1103                                sk_hash_to_evp_md(kctx->skc_hmac_alg),
1104                                hmac_key->length, hmac_key->value);
1105         if (rc == 0)
1106                 return -EINVAL;
1107
1108         /* Encryption key is only populated in privacy mode */
1109         if ((skc->sc_flags & LGSS_SVC_PRIV) == 0)
1110                 return 0;
1111
1112         encrypt_key->length = sk_crypt_types[kctx->skc_crypt_alg].sct_bytes;
1113         encrypt_key->value = malloc(encrypt_key->length);
1114         if (!encrypt_key->value)
1115                 return -ENOMEM;
1116
1117         rc = PKCS5_PBKDF2_HMAC(encrypt, -1, session_key->value,
1118                                session_key->length, SK_PBKDF2_ITERATIONS,
1119                                sk_hash_to_evp_md(kctx->skc_hmac_alg),
1120                                encrypt_key->length, encrypt_key->value);
1121         if (rc == 0)
1122                 return -EINVAL;
1123
1124         return 0;
1125 }
1126
1127 /**
1128  * Computes a session key based on the DH parameters from the host and its peer
1129  *
1130  * \param[in,out]       skc             Shared key credentials structure with
1131  *                                      the session key populated with the
1132  *                                      compute key
1133  * \param[in]           pub_key         Public key returned from peer in
1134  *                                      gss_buffer_desc
1135  * \return      gss error               failure
1136  * \return      GSS_S_COMPLETE          success
1137  */
1138 uint32_t sk_compute_dh_key(struct sk_cred *skc, const gss_buffer_desc *pub_key)
1139 {
1140         gss_buffer_desc *dh_shared = &skc->sc_dh_shared_key;
1141         BIGNUM *remote_pub_key;
1142         int status;
1143         uint32_t rc = GSS_S_FAILURE;
1144
1145         remote_pub_key = BN_bin2bn(pub_key->value, pub_key->length, NULL);
1146         if (!remote_pub_key) {
1147                 printerr(0, "Failed to convert binary to BIGNUM\n");
1148                 return rc;
1149         }
1150
1151         dh_shared->length = DH_size(skc->sc_params);
1152         dh_shared->value = malloc(dh_shared->length);
1153         if (!dh_shared->value) {
1154                 printerr(0, "Failed to allocate memory for computed shared "
1155                          "secret key\n");
1156                 goto out_err;
1157         }
1158
1159         /* This compute the shared key from the DHKE */
1160         status = DH_compute_key(dh_shared->value, remote_pub_key,
1161                                 skc->sc_params);
1162         if (status == -1) {
1163                 printerr(0, "DH_compute_key() failed: %s\n",
1164                          ERR_error_string(ERR_get_error(), NULL));
1165                 goto out_err;
1166         } else if (status < dh_shared->length) {
1167                 printerr(0, "DH_compute_key() returned a short key of %d "
1168                          "bytes, expected: %zu\n", status, dh_shared->length);
1169                 rc = GSS_S_DEFECTIVE_TOKEN;
1170                 goto out_err;
1171         }
1172
1173         rc = GSS_S_COMPLETE;
1174
1175 out_err:
1176         BN_free(remote_pub_key);
1177         return rc;
1178 }
1179
1180 /**
1181  * Creates a serialized buffer for the kernel in the order of struct
1182  * sk_kernel_ctx.
1183  *
1184  * \param[in,out]       skc             Shared key credentials structure
1185  * \param[in,out]       ctx_token       Serialized buffer for kernel.
1186  *                                      Caller must free this buffer.
1187  *
1188  * \return      0       success
1189  * \return      -1      failure
1190  */
1191 int sk_serialize_kctx(struct sk_cred *skc, gss_buffer_desc *ctx_token)
1192 {
1193         struct sk_kernel_ctx *kctx = &skc->sc_kctx;
1194         char *p, *end;
1195         size_t bufsize;
1196
1197         bufsize = sizeof(*kctx) + kctx->skc_hmac_key.length +
1198                   kctx->skc_encrypt_key.length;
1199
1200         ctx_token->value = malloc(bufsize);
1201         if (!ctx_token->value)
1202                 return -1;
1203         ctx_token->length = bufsize;
1204
1205         p = ctx_token->value;
1206         end = p + ctx_token->length;
1207
1208         if (WRITE_BYTES(&p, end, kctx->skc_version))
1209                 return -1;
1210         if (WRITE_BYTES(&p, end, kctx->skc_hmac_alg))
1211                 return -1;
1212         if (WRITE_BYTES(&p, end, kctx->skc_crypt_alg))
1213                 return -1;
1214         if (WRITE_BYTES(&p, end, kctx->skc_expire))
1215                 return -1;
1216         if (WRITE_BYTES(&p, end, kctx->skc_host_random))
1217                 return -1;
1218         if (WRITE_BYTES(&p, end, kctx->skc_peer_random))
1219                 return -1;
1220         if (write_buffer(&p, end, &kctx->skc_hmac_key))
1221                 return -1;
1222         if (write_buffer(&p, end, &kctx->skc_encrypt_key))
1223                 return -1;
1224
1225         printerr(2, "Serialized buffer of %zu bytes for kernel\n", bufsize);
1226
1227         return 0;
1228 }
1229
1230 /**
1231  * Decodes a netstring \a ns into array of gss_buffer_descs at \a bufs
1232  * up to \a numbufs.  Memory is allocated for each value and length
1233  * will be populated with the length
1234  *
1235  * \param[in,out]       bufs    Array of gss_buffer_descs
1236  * \param[in,out]       numbufs number of gss_buffer_desc in array
1237  * \param[in]           ns      netstring to decode
1238  *
1239  * \return      buffers populated       success
1240  * \return      -1                      failure
1241  */
1242 int sk_decode_netstring(gss_buffer_desc *bufs, int numbufs, gss_buffer_desc *ns)
1243 {
1244         char *ptr = ns->value;
1245         size_t remain = ns->length;
1246         unsigned int size;
1247         int digits;
1248         int sep;
1249         int rc;
1250         int i;
1251
1252         for (i = 0; i < numbufs; i++) {
1253                 /* read the size of first buffer */
1254                 rc = sscanf(ptr, "%9u", &size);
1255                 if (rc < 1)
1256                         goto out_err;
1257                 digits = (size) ? ceil(log10(size + 1)) : 1;
1258
1259                 /* sep of current string */
1260                 sep = size + digits + 2;
1261
1262                 /* check to make sure it's valid */
1263                 if (remain < sep || ptr[digits] != ':' ||
1264                     ptr[sep - 1] != ',')
1265                         goto out_err;
1266
1267                 bufs[i].length = size;
1268                 if (size == 0) {
1269                         bufs[i].value = NULL;
1270                 } else {
1271                         bufs[i].value = malloc(size);
1272                         if (!bufs[i].value)
1273                                 goto out_err;
1274                         memcpy(bufs[i].value, &ptr[digits + 1], size);
1275                 }
1276
1277                 remain -= sep;
1278                 ptr += sep;
1279         }
1280
1281         printerr(2, "Decoded netstring of %zu bytes\n", ns->length);
1282         return i;
1283
1284 out_err:
1285         while (i-- > 0) {
1286                 if (bufs[i].value)
1287                         free(bufs[i].value);
1288                 bufs[i].length = 0;
1289         }
1290         return -1;
1291 }
1292
1293 /**
1294  * Creates a netstring in a gss_buffer_desc that consists of all
1295  * the gss_buffer_desc found in \a bufs.  The netstring should be treated
1296  * as binary as it can contain null characters.
1297  *
1298  * \param[in]           bufs            Array of gss_buffer_desc to use as input
1299  * \param[in]           numbufs         Number of buffers in array
1300  * \param[in,out]       ns              Destination gss_buffer_desc to hold
1301  *                                      netstring
1302  *
1303  * \return      -1      failure
1304  * \return      0       success
1305  */
1306 int sk_encode_netstring(gss_buffer_desc *bufs, int numbufs,
1307                         gss_buffer_desc *ns)
1308 {
1309         unsigned char *ptr;
1310         int size = 0;
1311         int rc;
1312         int i;
1313
1314         /* size of string in decimal, string size, colon, and comma */
1315         for (i = 0; i < numbufs; i++) {
1316
1317                 if (bufs[i].length == 0)
1318                         size += 3;
1319                 else
1320                         size += ceil(log10(bufs[i].length + 1)) +
1321                                 bufs[i].length + 2;
1322         }
1323
1324         ns->length = size;
1325         ns->value = malloc(ns->length);
1326         if (!ns->value) {
1327                 ns->length = 0;
1328                 return -1;
1329         }
1330
1331         ptr = ns->value;
1332         for (i = 0; i < numbufs; i++) {
1333                 /* size */
1334                 rc = snprintf((char *) ptr, size, "%zu:", bufs[i].length);
1335                 ptr += rc;
1336
1337                 /* contents */
1338                 memcpy(ptr, bufs[i].value, bufs[i].length);
1339                 ptr += bufs[i].length;
1340
1341                 /* delimeter */
1342                 *ptr++ = ',';
1343
1344                 size -= bufs[i].length + rc + 1;
1345
1346                 /* should not happen */
1347                 if (size < 0)
1348                         abort();
1349         }
1350
1351         printerr(2, "Encoded netstring of %zu bytes\n", ns->length);
1352         return 0;
1353 }