Whamcloud - gitweb
LU-8590 utils: remove duplicate code in lgss_sk
[fs/lustre-release.git] / lustre / utils / gss / lgss_sk.c
1 /*
2  * GPL HEADER START
3  *
4  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License version 2 only,
8  * as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope that it will be useful, but
11  * WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * General Public License version 2 for more details (a copy is included
14  * in the LICENSE file that accompanied this code).
15  *
16  * You should have received a copy of the GNU General Public License
17  * version 2 along with this program; If not, see
18  * http://www.gnu.org/licenses/gpl-2.0.html
19  *
20  * GPL HEADER END
21  */
22 /*
23  * Copyright (C) 2015, Trustees of Indiana University
24  *
25  * Author: Jeremy Filizetti <jfilizet@iu.edu>
26  */
27
28 #include <errno.h>
29 #include <fcntl.h>
30 #include <getopt.h>
31 #include <limits.h>
32 #include <stdarg.h>
33 #include <stdbool.h>
34 #include <stdio.h>
35 #include <string.h>
36 #include <stdlib.h>
37 #include <sys/types.h>
38 #include <sys/stat.h>
39 #include <unistd.h>
40 #include <lnet/nidstr.h>
41 #include <lustre/lustre_idl.h>
42
43 #include "sk_utils.h"
44 #include "err_util.h"
45
46 #ifndef _GNU_SOURCE
47 #define _GNU_SOURCE
48 #endif
49
50 /* One week default expiration */
51 #define SK_DEFAULT_EXPIRE 604800
52 #define SK_DEFAULT_SK_KEYLEN 256
53 #define SK_DEFAULT_PRIME_BITS 2048
54 #define SK_DEFAULT_NODEMAP "default"
55
56 /* Names match up with openssl enc and dgst commands */
57 char *sk_crypt2name[] = {
58         [SK_CRYPT_EMPTY] = "NONE",
59         [SK_CRYPT_AES256_CTR] = "AES-256-CTR",
60 };
61
62 char *sk_hmac2name[] = {
63         [SK_HMAC_EMPTY] = "NONE",
64         [SK_HMAC_SHA256] = "SHA256",
65         [SK_HMAC_SHA512] = "SHA512",
66 };
67
68 static int sk_name2crypt(char *name)
69 {
70         int i;
71
72         for (i = 0; i < SK_CRYPT_MAX; i++) {
73                 if (strcasecmp(name, sk_crypt2name[i]) == 0)
74                         return i;
75         }
76
77         return SK_CRYPT_INVALID;
78 }
79
80 static int sk_name2hmac(char *name)
81 {
82         int i;
83
84         for (i = 0; i < SK_HMAC_MAX; i++) {
85                 if (strcasecmp(name, sk_hmac2name[i]) == 0)
86                         return i;
87         }
88
89         return SK_HMAC_INVALID;
90 }
91
92 static void usage(FILE *fp, char *program)
93 {
94         int i;
95
96         fprintf(fp, "Usage %s [OPTIONS] {-l|-m|-r|-w} <keyfile>\n", program);
97         fprintf(fp, "-l|--load       <keyfile>  Load key from file into user's "
98                 "session keyring\n");
99         fprintf(fp, "-m|--modify     <keyfile>  Modify keyfile's attributes\n");
100         fprintf(fp, "-r|--read       <keyfile>  Show keyfile's attributes\n");
101         fprintf(fp, "-w|--write      <keyfile>  Generate keyfile\n\n");
102         fprintf(fp, "Modify/Write Options:\n");
103         fprintf(fp, "-c|--crypt      <num>      Cipher for encryption "
104                 "(Default: AES Counter mode)\n");
105         for (i = 1; i < SK_CRYPT_MAX; i++)
106                 fprintf(fp, "                        %s\n", sk_crypt2name[i]);
107
108         fprintf(fp, "-i|--hmac       <num>      Hash algorithm for integrity "
109                 "(Default: SHA256)\n");
110         for (i = 1; i < SK_HMAC_MAX; i++)
111                 fprintf(fp, "                        %s\n", sk_hmac2name[i]);
112
113         fprintf(fp, "-e|--expire     <num>      Seconds before contexts from "
114                 "key expire (Default: %d seconds (%.3g days))\n",
115                 SK_DEFAULT_EXPIRE, (double)SK_DEFAULT_EXPIRE / 3600 / 24);
116         fprintf(fp, "-f|--fsname     <name>     File system name for key\n");
117         fprintf(fp, "-g|--mgsnids    <nids>     Comma seperated list of MGS "
118                 "NIDs.  Only required when mgssec is used (Default: \"\")\n");
119         fprintf(fp, "-n|--nodemap    <name>     Nodemap name for key "
120                 "(Default: \"%s\")\n", SK_DEFAULT_NODEMAP);
121         fprintf(fp, "-p|--prime-bits <len>      Prime length (p) for DHKE in "
122                 "bits (Default: %d)\n", SK_DEFAULT_PRIME_BITS);
123         fprintf(fp, "-t|--type       <type>     Key type (mgs, server, "
124                 "client)\n");
125         fprintf(fp, "-k|--key-bits   <len>      Shared key length in bits "
126                 "(Default: %d)\n", SK_DEFAULT_SK_KEYLEN);
127         fprintf(fp, "-d|--data       <file>     Key random data source "
128                 "(Default: /dev/random)\n\n");
129         fprintf(fp, "Other Options:\n");
130         fprintf(fp, "-v|--verbose           Increase verbosity for errors\n");
131         exit(EXIT_FAILURE);
132 }
133
134 static ssize_t get_key_data(char *src, void *buffer, size_t bits)
135 {
136         char *ptr = buffer;
137         size_t remain;
138         ssize_t rc;
139         int fd;
140
141         /* convert bits to minimum number of bytes */
142         remain = (bits + 7) / 8;
143
144         printf("Reading random data for shared key from '%s'\n", src);
145         fd = open(src, O_RDONLY);
146         if (fd < 0) {
147                 fprintf(stderr, "error: opening '%s': %s\n", src,
148                         strerror(errno));
149                 return -errno;
150         }
151
152         while (remain > 0) {
153                 rc = read(fd, ptr, remain);
154                 if (rc < 0) {
155                         if (errno == EINTR)
156                                 continue;
157                         fprintf(stderr, "error: reading from '%s': %s\n", src,
158                                 strerror(errno));
159                         rc = -errno;
160                         goto out;
161
162                 } else if (rc == 0) {
163                         fprintf(stderr,
164                                 "error: key source too short for %zd-bit key\n",
165                                 bits);
166                         rc = -ENODATA;
167                         goto out;
168                 }
169                 ptr += rc;
170                 remain -= rc;
171         }
172         rc = 0;
173
174 out:
175         close(fd);
176         return rc;
177 }
178
179 static int write_config_file(char *output_file,
180                              struct sk_keyfile_config *config, bool overwrite)
181 {
182         size_t rc;
183         int fd;
184         int flags = O_WRONLY | O_CREAT;
185
186         if (!overwrite)
187                 flags |= O_EXCL;
188
189         sk_config_cpu_to_disk(config);
190
191         fd = open(output_file, flags, 0400);
192         if (fd < 0) {
193                 fprintf(stderr, "error: opening '%s': %s\n", output_file,
194                         strerror(errno));
195                 return -errno;
196         }
197
198         rc = write(fd, config, sizeof(*config));
199         if (rc < 0) {
200                 fprintf(stderr, "error: writing to '%s': %s\n", output_file,
201                         strerror(errno));
202                 rc = -errno;
203         } else if (rc != sizeof(*config)) {
204                 fprintf(stderr, "error: short write to '%s'\n", output_file);
205                 rc = -ENOSPC;
206
207         } else {
208                 rc = 0;
209         }
210
211         close(fd);
212         return rc;
213 }
214
215 static int print_config(char *filename)
216 {
217         struct sk_keyfile_config *config;
218         int i;
219
220         config = sk_read_file(filename);
221         if (!config)
222                 return EXIT_FAILURE;
223
224         if (sk_validate_config(config)) {
225                 fprintf(stderr, "error: key configuration failed validation\n");
226                 free(config);
227                 return EXIT_FAILURE;
228         }
229
230         printf("Version:        %u\n", config->skc_version);
231         printf("Type:          ");
232         if (config->skc_type & SK_TYPE_MGS)
233                 printf(" mgs");
234         if (config->skc_type & SK_TYPE_SERVER)
235                 printf(" server");
236         if (config->skc_type & SK_TYPE_CLIENT)
237                 printf(" client");
238         printf("\n");
239         printf("HMAC alg:       %s\n", sk_hmac2name[config->skc_hmac_alg]);
240         printf("Crypto alg:     %s\n", sk_crypt2name[config->skc_crypt_alg]);
241         printf("Ctx Expiration: %u seconds\n", config->skc_expire);
242         printf("Shared keylen:  %u bits\n", config->skc_shared_keylen);
243         printf("Prime length:   %u bits\n", config->skc_prime_bits);
244         printf("File system:    %s\n", config->skc_fsname);
245         printf("MGS NIDs:      ");
246         for (i = 0; i < MAX_MGSNIDS; i++) {
247                 if (config->skc_mgsnids[i] == LNET_NID_ANY)
248                         continue;
249                 printf(" %s", libcfs_nid2str(config->skc_mgsnids[i]));
250         }
251         printf("\n");
252         printf("Nodemap name:   %s\n", config->skc_nodemap);
253         printf("Shared key:\n");
254         print_hex(0, config->skc_shared_key, config->skc_shared_keylen / 8);
255
256         /* Don't print empty keys */
257         for (i = 0; i < SK_MAX_P_BYTES; i++)
258                 if (config->skc_p[i] != 0)
259                         break;
260
261         if (i != SK_MAX_P_BYTES) {
262                 printf("Prime (p):\n");
263                 print_hex(0, config->skc_p, config->skc_prime_bits / 8);
264         }
265
266         free(config);
267         return EXIT_SUCCESS;
268 }
269
270 static int parse_mgsnids(char *mgsnids, struct sk_keyfile_config *config)
271 {
272         lnet_nid_t nid;
273         char *ptr;
274         char *sep;
275         char *end;
276         int rc = 0;
277         int i;
278
279         /* replace all old values */
280         for (i = 0; i < MAX_MGSNIDS; i++)
281                 config->skc_mgsnids[i] = LNET_NID_ANY;
282
283         i = 0;
284         end = mgsnids + strlen(mgsnids);
285         ptr = mgsnids;
286         while (ptr < end && i < MAX_MGSNIDS) {
287                 sep = strstr(ptr, ",");
288                 if (sep != NULL)
289                         *sep = '\0';
290
291                 nid = libcfs_str2nid(ptr);
292                 if (nid == LNET_NID_ANY) {
293                         fprintf(stderr, "error: invalid MGS NID: %s\n", ptr);
294                         rc = -EINVAL;
295                         break;
296                 }
297
298                 config->skc_mgsnids[i++] = nid;
299                 ptr += strlen(ptr) + 1;
300         }
301
302         if (i == MAX_MGSNIDS) {
303                 fprintf(stderr, "error: more than %u MGS NIDs provided\n", i);
304                 rc = -E2BIG;
305         }
306
307         return rc;
308 }
309
310 int main(int argc, char **argv)
311 {
312         struct sk_keyfile_config *config;
313         char *datafile = NULL;
314         char *input = NULL;
315         char *load = NULL;
316         char *modify = NULL;
317         char *output = NULL;
318         char *mgsnids = NULL;
319         char *nodemap = NULL;
320         char *fsname = NULL;
321         char *tmp;
322         char *tmp2;
323         int crypt = SK_CRYPT_EMPTY;
324         int hmac = SK_HMAC_EMPTY;
325         int expire = -1;
326         int shared_keylen = -1;
327         int prime_bits = -1;
328         int verbose = 0;
329         int i;
330         int opt;
331         enum sk_key_type  type = SK_TYPE_INVALID;
332         bool generate_prime = false;
333         DH *dh;
334
335         static struct option long_opt[] = {
336                 {"crypt", 1, 0, 'c'},
337                 {"data", 1, 0, 'd'},
338                 {"expire", 1, 0, 'e'},
339                 {"fsname", 1, 0, 'f'},
340                 {"mgsnids", 1, 0, 'g'},
341                 {"help", 0, 0, 'h'},
342                 {"hmac", 1, 0, 'i'},
343                 {"integrity", 1, 0, 'i'},
344                 {"key-bits", 1, 0, 'k'},
345                 {"shared", 1, 0, 'k'},
346                 {"load", 1, 0, 'l'},
347                 {"modify", 1, 0, 'm'},
348                 {"nodemap", 1, 0, 'n'},
349                 {"prime-bits", 1, 0, 'p'},
350                 {"read", 1, 0, 'r'},
351                 {"type", 1, 0, 't'},
352                 {"verbose", 0, 0, 'v'},
353                 {"write", 1, 0, 'w'},
354                 {0, 0, 0, 0},
355         };
356
357         while ((opt = getopt_long(argc, argv,
358                                   "c:d:e:f:g:hi:l:m:n:p:r:s:k:t:w:v", long_opt,
359                                   NULL)) != EOF) {
360                 switch (opt) {
361                 case 'c':
362                         crypt = sk_name2crypt(optarg);
363                         break;
364                 case 'd':
365                         datafile = optarg;
366                         break;
367                 case 'e':
368                         expire = atoi(optarg);
369                         if (expire < 60)
370                                 fprintf(stderr, "warning: using a %us key "
371                                         "expiration may cause issues during "
372                                         "key renegotiation\n", expire);
373                         break;
374                 case 'f':
375                         fsname = optarg;
376                         if (strlen(fsname) > MTI_NAME_MAXLEN) {
377                                 fprintf(stderr,
378                                         "error: file system name longer than "
379                                         "%u characters\n", MTI_NAME_MAXLEN);
380                                 return EXIT_FAILURE;
381                         }
382                         break;
383                 case 'g':
384                         mgsnids = optarg;
385                         break;
386                 case 'h':
387                         usage(stdout, argv[0]);
388                         break;
389                 case 'i':
390                         hmac = sk_name2hmac(optarg);
391                         break;
392                 case 'k':
393                         shared_keylen = atoi(optarg);
394                         break;
395                 case 'l':
396                         load = optarg;
397                         break;
398                 case 'm':
399                         modify = optarg;
400                         break;
401                 case 'n':
402                         nodemap = optarg;
403                         if (strlen(nodemap) > LUSTRE_NODEMAP_NAME_LENGTH) {
404                                 fprintf(stderr,
405                                         "error: nodemap name longer than "
406                                         "%u characters\n",
407                                         LUSTRE_NODEMAP_NAME_LENGTH);
408                                 return EXIT_FAILURE;
409                         }
410                         break;
411                 case 'p':
412                         prime_bits = atoi(optarg);
413                         if (prime_bits <= 0) {
414                                 fprintf(stderr,
415                                         "error: invalid prime length: '%s'\n",
416                                         optarg);
417                                 return EXIT_FAILURE;
418                         }
419                         break;
420                 case 'r':
421                         input = optarg;
422                         break;
423                 case 't':
424                         tmp2 = strdup(optarg);
425                         if (!tmp2) {
426                                 fprintf(stderr,
427                                         "error: failed to allocate type\n");
428                                 return EXIT_FAILURE;
429                         }
430                         tmp = strsep(&tmp2, ",");
431                         while (tmp != NULL) {
432                                 if (strcasecmp(tmp, "server") == 0) {
433                                         type |= SK_TYPE_SERVER;
434                                 } else if (strcasecmp(tmp, "mgs") == 0) {
435                                         type |= SK_TYPE_MGS;
436                                 } else if (strcasecmp(tmp, "client") == 0) {
437                                         type |= SK_TYPE_CLIENT;
438                                 } else {
439                                         fprintf(stderr,
440                                                 "error: invalid type '%s', "
441                                                 "must be mgs, server, or client"
442                                                 "\n", optarg);
443                                         return EXIT_FAILURE;
444                                 }
445                                 tmp = strsep(&tmp2, ",");
446                         }
447                         free(tmp2);
448                         break;
449                 case 'v':
450                         verbose++;
451                         break;
452                 case 'w':
453                         output = optarg;
454                         break;
455                 default:
456                         fprintf(stderr, "error: unknown option: '%c'\n", opt);
457                         return EXIT_FAILURE;
458                         break;
459                 }
460         }
461
462         if (optind != argc) {
463                 fprintf(stderr,
464                         "error: extraneous arguments provided, check usage\n");
465                 return EXIT_FAILURE;
466         }
467
468         if (!input && !output && !load && !modify) {
469                 usage(stderr, argv[0]);
470                 return EXIT_FAILURE;
471         }
472
473         /* init gss logger for foreground (no syslog) which prints to stderr */
474         initerr(NULL, verbose, 1);
475
476         if (input)
477                 return print_config(input);
478
479         if (load) {
480                 if (sk_load_keyfile(load))
481                         return EXIT_FAILURE;
482                 return EXIT_SUCCESS;
483         }
484
485         if (crypt == SK_CRYPT_INVALID) {
486                 fprintf(stderr, "error: invalid crypt algorithm specified\n");
487                 return EXIT_FAILURE;
488         }
489         if (hmac == SK_HMAC_INVALID) {
490                 fprintf(stderr, "error: invalid HMAC algorithm specified\n");
491                 return EXIT_FAILURE;
492         }
493
494         if (modify) {
495                 config = sk_read_file(modify);
496                 if (!config)
497                         return EXIT_FAILURE;
498
499                 if (type != SK_TYPE_INVALID) {
500                         /* generate key when adding client type */
501                         if (!(config->skc_type & SK_TYPE_CLIENT) &&
502                             type & SK_TYPE_CLIENT)
503                                 generate_prime = true;
504                         else if (!(type & SK_TYPE_CLIENT))
505                                 memset(config->skc_p, 0, SK_MAX_P_BYTES);
506
507                         config->skc_type = type;
508                 }
509                 if (prime_bits != -1) {
510                         memset(config->skc_p, 0, SK_MAX_P_BYTES);
511                         if (config->skc_prime_bits != prime_bits &&
512                             config->skc_type & SK_TYPE_CLIENT)
513                                 generate_prime = true;
514                 }
515         } else {
516                 /* write mode for a new key */
517                 if (!fsname && !mgsnids) {
518                         fprintf(stderr,
519                                 "error: missing --fsname or --mgsnids\n");
520                         return EXIT_FAILURE;
521                 }
522
523                 config = calloc(1, sizeof(*config));
524                 if (!config)
525                         return EXIT_FAILURE;
526
527                 /* Set the defaults for new key */
528                 config->skc_version = SK_CONF_VERSION;
529                 config->skc_expire = SK_DEFAULT_EXPIRE;
530                 config->skc_shared_keylen = SK_DEFAULT_SK_KEYLEN;
531                 config->skc_prime_bits = SK_DEFAULT_PRIME_BITS;
532                 config->skc_crypt_alg = SK_CRYPT_AES256_CTR;
533                 config->skc_hmac_alg = SK_HMAC_SHA256;
534                 for (i = 0; i < MAX_MGSNIDS; i++)
535                         config->skc_mgsnids[i] = LNET_NID_ANY;
536
537                 if (type == SK_TYPE_INVALID) {
538                         fprintf(stderr, "error: no type specified for key\n");
539                         goto error;
540                 }
541                 config->skc_type = type;
542                 generate_prime = type & SK_TYPE_CLIENT;
543
544                 strncpy(config->skc_nodemap, SK_DEFAULT_NODEMAP,
545                         strlen(SK_DEFAULT_NODEMAP));
546
547                 if (!datafile)
548                         datafile = "/dev/random";
549         }
550
551         if (crypt != SK_CRYPT_EMPTY)
552                 config->skc_crypt_alg = crypt;
553         if (hmac != SK_HMAC_EMPTY)
554                 config->skc_hmac_alg = hmac;
555         if (expire != -1)
556                 config->skc_expire = expire;
557         if (shared_keylen != -1)
558                 config->skc_shared_keylen = shared_keylen;
559         if (prime_bits != -1)
560                 config->skc_prime_bits = prime_bits;
561         if (fsname)
562                 strncpy(config->skc_fsname, fsname, strlen(fsname));
563         if (nodemap)
564                 strncpy(config->skc_nodemap, nodemap, strlen(nodemap));
565         if (mgsnids && parse_mgsnids(mgsnids, config))
566                 goto error;
567         if (sk_validate_config(config)) {
568                 fprintf(stderr, "error: key configuration failed validation\n");
569                 goto error;
570         }
571
572         if (datafile && get_key_data(datafile, config->skc_shared_key,
573                                      config->skc_shared_keylen)) {
574                 fprintf(stderr, "error: failure getting key data from '%s'\n",
575                         datafile);
576                 goto error;
577         }
578
579         if (generate_prime) {
580                 printf("Generating DH parameters, this can take a while...\n");
581                 dh = DH_generate_parameters(config->skc_prime_bits,
582                                             SK_GENERATOR, NULL, NULL);
583                 if (BN_num_bytes(dh->p) > SK_MAX_P_BYTES) {
584                         fprintf(stderr, "error: cannot generate DH parameters: "
585                                 "requested length %d exceeds maximum %d\n",
586                                 config->skc_prime_bits, SK_MAX_P_BYTES * 8);
587                         goto error;
588                 }
589                 if (BN_bn2bin(dh->p, config->skc_p) != BN_num_bytes(dh->p)) {
590                         fprintf(stderr,
591                                 "error: convert BIGNUM p to binary failed\n");
592                         goto error;
593                 }
594         }
595
596         if (write_config_file(modify ?: output, config, modify))
597                 goto error;
598
599         return EXIT_SUCCESS;
600
601 error:
602         free(config);
603         return EXIT_FAILURE;
604 }