Whamcloud - gitweb
b3c9889126a0b09d7d0b679ec4fa43662b92c49e
[fs/lustre-release.git] / lustre / utils / gss / svcgssd_proc.c
1 /*
2  * svc_in_gssd_proc.c
3  *
4  * Copyright (c) 2000 The Regents of the University of Michigan.
5  * All rights reserved.
6  *
7  * Copyright (c) 2002 Bruce Fields <bfields@UMICH.EDU>
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * 1. Redistributions of source code must retain the above copyright
14  *    notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  * 3. Neither the name of the University nor the names of its
19  *    contributors may be used to endorse or promote products derived
20  *    from this software without specific prior written permission.
21  *
22  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
23  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
24  * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25  * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
26  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
27  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
28  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
29  * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
31  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33  */
34
35 #include <sys/param.h>
36 #include <sys/stat.h>
37
38 #include <inttypes.h>
39 #include <pwd.h>
40 #include <stdio.h>
41 #include <unistd.h>
42 #include <ctype.h>
43 #include <string.h>
44 #include <fcntl.h>
45 #include <errno.h>
46 #ifdef HAVE_NETDB_H
47 # include <netdb.h>
48 #endif
49
50 #include <stdbool.h>
51
52 #include "svcgssd.h"
53 #include "gss_util.h"
54 #include "err_util.h"
55 #include "context.h"
56 #include "cacheio.h"
57 #include "lsupport.h"
58 #include "gss_oids.h"
59 #include <time.h>
60 #include <linux/lustre/lustre_idl.h>
61 #include "sk_utils.h"
62 #include <sys/time.h>
63 #include <gssapi/gssapi_krb5.h>
64 #include <libcfs/util/param.h>
65
/*
 * procfs channel files named after the sptlrpc "context" and "init"
 * entries under /proc/net/rpc.
 */
#define SVCGSSD_CONTEXT_CHANNEL	"/proc/net/rpc/auth.sptlrpc.context/channel"
#define SVCGSSD_INIT_CHANNEL	"/proc/net/rpc/auth.sptlrpc.init/channel"
68
/*
 * Identity and role information gathered for an authenticated peer.
 * do_svc_downcall() turns the four flag fields into RSC_DATA_FLAG_* bits
 * and copies the three id fields into the downcall record.
 */
struct svc_cred {
	uint32_t	cr_remote;	/* non-zero: RSC_DATA_FLAG_REMOTE */
	uint32_t	cr_usr_root;	/* non-zero: RSC_DATA_FLAG_ROOT */
	uint32_t	cr_usr_mds;	/* non-zero: RSC_DATA_FLAG_MDS */
	uint32_t	cr_usr_oss;	/* non-zero: RSC_DATA_FLAG_OSS */
	uid_t		cr_uid;		/* -1 when no uid was resolved */
	uid_t		cr_mapped_uid;	/* -1 when no mapping was found */
	uid_t		cr_gid;		/* group id, stored as a uid_t */
};
78
79 struct svc_nego_data {
80         /* kernel data*/
81         uint32_t        lustre_svc;
82         lnet_nid_t      nid;
83         uint64_t        handle_seq;
84         char            nm_name[LUSTRE_NODEMAP_NAME_LENGTH + 1];
85         gss_buffer_desc in_tok;
86         gss_buffer_desc out_tok;
87         gss_buffer_desc in_handle;
88         gss_buffer_desc out_handle;
89         uint32_t        maj_stat;
90         uint32_t        min_stat;
91
92         /* userspace data */
93         gss_OID                 mech;
94         gss_ctx_id_t            ctx;
95         gss_buffer_desc         ctx_token;
96 };
97
98 static int do_svc_downcall(gss_buffer_desc *out_handle, struct svc_cred *cred,
99                            gss_OID mechoid, gss_buffer_desc *ctx_token)
100 {
101         struct rsc_downcall_data *rsc_dd;
102         int blen, fd, size, rc = -1;
103         const char *mechname;
104         glob_t path;
105         char *bp;
106
107         printerr(LL_INFO, "doing downcall\n");
108
109         size = out_handle->length + sizeof(__u32) +
110                 ctx_token->length + sizeof(__u32);
111         blen = size;
112
113         size += offsetof(struct rsc_downcall_data, scd_val[0]);
114         rsc_dd = calloc(1, size);
115         if (!rsc_dd) {
116                 printerr(LL_ERR, "malloc downcall data (%d) failed\n", size);
117                 return -ENOMEM;
118         }
119         rsc_dd->scd_magic = RSC_DOWNCALL_MAGIC;
120         rsc_dd->scd_err = 0;
121
122         rsc_dd->scd_flags =
123                 (cred->cr_remote ? RSC_DATA_FLAG_REMOTE : 0) |
124                 (cred->cr_usr_root ? RSC_DATA_FLAG_ROOT : 0) |
125                 (cred->cr_usr_mds ? RSC_DATA_FLAG_MDS : 0) |
126                 (cred->cr_usr_oss ? RSC_DATA_FLAG_OSS : 0);
127         rsc_dd->scd_mapped_uid = cred->cr_mapped_uid;
128         rsc_dd->scd_uid = cred->cr_uid;
129         rsc_dd->scd_gid = cred->cr_gid;
130         mechname = gss_OID_mech_name(mechoid);
131         if (mechname == NULL)
132                 goto out;
133         if (snprintf(rsc_dd->scd_mechname, sizeof(rsc_dd->scd_mechname),
134                      "%s", mechname) >= sizeof(rsc_dd->scd_mechname))
135                 goto out;
136
137         bp = rsc_dd->scd_val;
138         gss_buffer_write(&bp, &blen, out_handle->value, out_handle->length);
139         gss_buffer_write(&bp, &blen, ctx_token->value, ctx_token->length);
140         if (blen < 0) {
141                 printerr(LL_ERR, "ERROR: %s: message too long > %d\n",
142                          __func__, size);
143                 rc = -EMSGSIZE;
144                 goto out;
145         }
146         rsc_dd->scd_len = bp - rsc_dd->scd_val;
147
148         rc = cfs_get_param_paths(&path, RSC_DOWNCALL_PATH);
149         if (rc != 0) {
150                 rc = -errno;
151                 goto out;
152         }
153
154         fd = open(path.gl_pathv[0], O_WRONLY);
155         if (fd == -1) {
156                 rc = -errno;
157                 printerr(LL_ERR, "ERROR: %s: open %s failed: %s\n",
158                          __func__, RSC_DOWNCALL_PATH, strerror(-rc));
159                 goto out_path;
160         }
161         size = offsetof(struct rsc_downcall_data,
162                         scd_val[bp - rsc_dd->scd_val]);
163         printerr(LL_DEBUG, "writing downcall data, size %d\n", size);
164         if (write(fd, rsc_dd, size) == -1) {
165                 rc = -errno;
166                 printerr(LL_ERR, "ERROR: %s failed: %s\n",
167                          __func__, strerror(-rc));
168         }
169         printerr(LL_DEBUG, "downcall data written ok\n");
170
171         close(fd);
172 out_path:
173         cfs_free_param_data(&path);
174 out:
175         free(rsc_dd);
176         if (rc)
177                 printerr(LL_ERR, "ERROR: downcall failed\n");
178         return rc;
179 }
180
181 struct gss_verifier {
182         u_int32_t       flav;
183         gss_buffer_desc body;
184 };
185
186 #define RPCSEC_GSS_SEQ_WIN      5
187
188 static int send_response(int auth_res, uint64_t hash,
189                         gss_buffer_desc *in_handle, gss_buffer_desc *in_token,
190                         u_int32_t maj_stat, u_int32_t min_stat,
191                         gss_buffer_desc *out_handle, gss_buffer_desc *out_token)
192 {
193         struct rsi_downcall_data *rsi_dd;
194         int blen, fd, size, rc = 0;
195         glob_t path;
196         char *bp;
197
198         printerr(LL_INFO, "sending reply\n");
199
200         size = in_handle->length + sizeof(__u32) +
201                 in_token->length + sizeof(__u32) +
202                 sizeof(__u32) + sizeof(__u32);
203         if (!auth_res)
204                 size += out_handle->length + out_token->length;
205         blen = size;
206
207         size += offsetof(struct rsi_downcall_data, sid_val[0]);
208         rsi_dd = calloc(1, size);
209         if (!rsi_dd) {
210                 printerr(LL_ERR, "malloc downcall data (%d) failed\n", size);
211                 return -ENOMEM;
212         }
213         rsi_dd->sid_magic = RSI_DOWNCALL_MAGIC;
214         rsi_dd->sid_hash = hash;
215         rsi_dd->sid_maj_stat = maj_stat;
216         rsi_dd->sid_min_stat = min_stat;
217
218         bp = rsi_dd->sid_val;
219         gss_buffer_write(&bp, &blen, in_handle->value, in_handle->length);
220         gss_buffer_write(&bp, &blen, in_token->value, in_token->length);
221         if (!auth_res) {
222                 gss_buffer_write(&bp, &blen, out_handle->value,
223                                  out_handle->length);
224                 gss_buffer_write(&bp, &blen, out_token->value,
225                                  out_token->length);
226         } else {
227                 rsi_dd->sid_err = -EACCES;
228                 gss_buffer_write(&bp, &blen, NULL, 0);
229                 gss_buffer_write(&bp, &blen, NULL, 0);
230         }
231         if (blen < 0) {
232                 printerr(LL_ERR, "ERROR: %s: message too long > %d\n",
233                          __func__, size);
234                 rc = -EMSGSIZE;
235                 goto out;
236         }
237         rsi_dd->sid_len = bp - rsi_dd->sid_val;
238
239         rc = cfs_get_param_paths(&path, RSI_DOWNCALL_PATH);
240         if (rc != 0) {
241                 rc = -errno;
242                 printerr(LL_ERR, "ERROR: %s: cannot get param path %s: %s\n",
243                          __func__, RSI_DOWNCALL_PATH, strerror(-rc));
244                 goto out;
245         }
246
247         fd = open(path.gl_pathv[0], O_WRONLY);
248         if (fd == -1) {
249                 rc = -errno;
250                 printerr(LL_ERR, "ERROR: %s: open %s failed: %s\n",
251                          __func__, RSI_DOWNCALL_PATH, strerror(-rc));
252                 goto out_path;
253         }
254         size = offsetof(struct rsi_downcall_data,
255                         sid_val[bp - rsi_dd->sid_val]);
256         printerr(LL_DEBUG, "writing response, size %d\n", size);
257         if (write(fd, rsi_dd, size) == -1) {
258                 rc = -errno;
259                 printerr(LL_ERR, "ERROR: %s failed: %s\n",
260                          __func__, strerror(-rc));
261         }
262         printerr(LL_DEBUG, "response written ok\n");
263
264         close(fd);
265 out_path:
266         cfs_free_param_data(&path);
267 out:
268         free(rsi_dd);
269         return rc;
270 }
271
/*
 * Authentication status codes.
 * NOTE(review): the numbering appears to follow the ONC RPC auth_stat
 * values (including the RPCSEC_GSS additions 13 and 14) - confirm.
 */
#define rpc_auth_ok			0
#define rpc_autherr_badcred		1
#define rpc_autherr_rejectedcred	2
#define rpc_autherr_badverf		3
#define rpc_autherr_rejectedverf	4
#define rpc_autherr_tooweak		5
#define rpcsec_gsserr_credproblem	13
#define rpcsec_gsserr_ctxproblem	14
280
281 static int lookup_localname(gss_name_t client_name, char *princ, lnet_nid_t nid,
282                             uid_t *uid)
283 {
284         u_int32_t maj_stat, min_stat;
285         gss_buffer_desc localname;
286         char *sname;
287         int rc = -1;
288
289         *uid = -1;
290         maj_stat = gss_localname(&min_stat, client_name, GSS_C_NO_OID,
291                                  &localname);
292         if (maj_stat != GSS_S_COMPLETE) {
293                 printerr(LL_INFO, "no local name for %s/%#Lx\n", princ, nid);
294                 return rc;
295         }
296
297         sname = calloc(localname.length + 1, 1);
298         if (!sname) {
299                 printerr(LL_ERR, "%s: error allocating %zu bytes\n",
300                          __func__, localname.length + 1);
301                 goto free;
302         }
303         memcpy(sname, localname.value, localname.length);
304         sname[localname.length] = '\0';
305
306         *uid = parse_uid(sname);
307         free(sname);
308         printerr(LL_WARN, "found local uid: %s ==> %d\n", princ, *uid);
309         rc = 0;
310
311 free:
312         gss_release_buffer(&min_stat, &localname);
313         return rc;
314 }
315
316 static int lookup_id(gss_name_t client_name, char *princ, lnet_nid_t nid,
317                      uid_t *uid)
318 {
319         if (!mapping_empty())
320                 return lookup_mapping(princ, nid, uid);
321
322         return lookup_localname(client_name, princ, nid, uid);
323 }
324
325 static int
326 get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred,
327         lnet_nid_t nid, uint32_t lustre_svc)
328 {
329         u_int32_t       maj_stat, min_stat;
330         gss_buffer_desc name;
331         char            *sname, *host, *realm;
332         const int       namebuf_size = 512;
333         char            namebuf[namebuf_size];
334         int             res = -1;
335         gss_OID         name_type = GSS_C_NO_OID;
336         struct passwd   *pw;
337
338         cred->cr_remote = 0;
339         cred->cr_usr_root = cred->cr_usr_mds = cred->cr_usr_oss = 0;
340         cred->cr_uid = cred->cr_mapped_uid = cred->cr_gid = -1;
341
342         maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
343         if (maj_stat != GSS_S_COMPLETE) {
344                 pgsserr("get_ids: gss_display_name",
345                         maj_stat, min_stat, mech);
346                 return -1;
347         }
348         /* be certain name.length+1 doesn't overflow */
349         if (name.length >= 0xffff ||
350             !(sname = calloc(name.length + 1, 1))) {
351                 printerr(LL_ERR,
352                          "ERROR: %s: error allocating %zu bytes for sname\n",
353                          __func__, name.length + 1);
354                 gss_release_buffer(&min_stat, &name);
355                 return -1;
356         }
357         memcpy(sname, name.value, name.length);
358         sname[name.length] = '\0';
359         gss_release_buffer(&min_stat, &name);
360
361         if (lustre_svc == LUSTRE_GSS_SVC_MDS &&
362             lookup_id(client_name, sname, nid, &cred->cr_mapped_uid))
363                 printerr(LL_DEBUG, "no id found for %s\n", sname);
364
365         realm = strchr(sname, '@');
366         if (realm) {
367                 *realm++ = '\0';
368         } else {
369                 printerr(LL_ERR, "ERROR: %s has no realm name\n", sname);
370                 goto out_free;
371         }
372
373         host = strchr(sname, '/');
374         if (host)
375                 *host++ = '\0';
376
377         if (strcmp(sname, GSSD_SERVICE_MGS) == 0) {
378                 printerr(LL_ERR, "forbid %s as a user name\n", sname);
379                 goto out_free;
380         }
381
382         /* 1. check host part */
383         if (host) {
384                 if (lnet_nid2hostname(nid, namebuf, namebuf_size)) {
385                         printerr(LL_ERR,
386                                  "ERROR: failed to resolve hostname for %s/%s@%s from %016llx\n",
387                                  sname, host, realm, nid);
388                         goto out_free;
389                 }
390
391                 if (strcasecmp(host, namebuf)) {
392                         printerr(LL_ERR,
393                                  "ERROR: %s/%s@%s claimed hostname doesn't match %s, nid %016llx\n",
394                                  sname, host, realm,
395                                  namebuf, nid);
396                         goto out_free;
397                 }
398         } else {
399                 if (!strcmp(sname, GSSD_SERVICE_MDS) ||
400                     !strcmp(sname, GSSD_SERVICE_OSS)) {
401                         printerr(LL_ERR,
402                                  "ERROR: %s@%s from %016llx doesn't bind with hostname\n",
403                                  sname, realm, nid);
404                         goto out_free;
405                 }
406         }
407
408         /* 2. check realm and user */
409         switch (lustre_svc) {
410         case LUSTRE_GSS_SVC_MDS:
411                 if (strcasecmp(mds_local_realm, realm) != 0) {
412                         /* Remote realm case */
413                         cred->cr_remote = 1;
414
415                         /* Prevent access to unmapped user from remote realm */
416                         if (cred->cr_mapped_uid == -1) {
417                                 printerr(LL_ERR,
418                                          "ERROR: %s%s%s@%s from %016llx is remote but without mapping\n",
419                                          sname, host ? "/" : "",
420                                          host ? host : "", realm, nid);
421                                 break;
422                         }
423                         goto valid;
424                 }
425
426                 /* Now we know we are dealing with a local realm */
427
428                 if (!strcmp(sname, LUSTRE_ROOT_NAME) ||
429                     !strcmp(sname, GSSD_SERVICE_HOST)) {
430                         cred->cr_uid = 0;
431                         cred->cr_usr_root = 1;
432                         goto valid;
433                 }
434                 if (!strcmp(sname, GSSD_SERVICE_MDS)) {
435                         cred->cr_uid = 0;
436                         cred->cr_usr_mds = 1;
437                         goto valid;
438                 }
439                 if (!strcmp(sname, GSSD_SERVICE_OSS)) {
440                         cred->cr_uid = 0;
441                         cred->cr_usr_oss = 1;
442                         goto valid;
443                 }
444                 if (cred->cr_mapped_uid != -1) {
445                         printerr(LL_INFO,
446                                  "user %s from %016llx is mapped to %u\n",
447                                  sname, nid,
448                                  cred->cr_mapped_uid);
449                         goto valid;
450                 }
451                 pw = getpwnam(sname);
452                 if (pw != NULL) {
453                         cred->cr_uid = pw->pw_uid;
454                         printerr(LL_INFO, "%s resolve to uid %u\n",
455                                  sname, cred->cr_uid);
456                         goto valid;
457                 }
458                 printerr(LL_ERR, "ERROR: invalid user, %s/%s@%s from %016llx\n",
459                          sname, host, realm, nid);
460                 break;
461
462 valid:
463                 res = 0;
464                 break;
465         case LUSTRE_GSS_SVC_MGS:
466                 if (!strcmp(sname, GSSD_SERVICE_OSS)) {
467                         cred->cr_uid = 0;
468                         cred->cr_usr_oss = 1;
469                 }
470                 fallthrough;
471         case LUSTRE_GSS_SVC_OSS:
472                 if (!strcmp(sname, LUSTRE_ROOT_NAME) ||
473                     !strcmp(sname, GSSD_SERVICE_HOST)) {
474                         cred->cr_uid = 0;
475                         cred->cr_usr_root = 1;
476                 } else if (!strcmp(sname, GSSD_SERVICE_MDS)) {
477                         cred->cr_uid = 0;
478                         cred->cr_usr_mds = 1;
479                 }
480                 if (cred->cr_uid == -1) {
481                         printerr(LL_ERR,
482                                  "ERROR: svc %d doesn't accept user %s from %016llx\n",
483                                  lustre_svc, sname, nid);
484                         break;
485                 }
486                 res = 0;
487                 break;
488         default:
489                 assert(0);
490         }
491
492 out_free:
493         if (!res)
494                 printerr(LL_WARN, "%s: authenticated %s%s%s@%s from %016llx\n",
495                          lustre_svc_name[lustre_svc], sname,
496                          host ? "/" : "", host ? host : "", realm, nid);
497         free(sname);
498         return res;
499 }
500
501 typedef struct gss_union_ctx_id_t {
502         gss_OID         mech_type;
503         gss_ctx_id_t    internal_ctx_id;
504 } gss_union_ctx_id_desc, *gss_union_ctx_id_t;
505
506 int handle_sk(struct svc_nego_data *snd)
507 {
508 #ifdef HAVE_OPENSSL_SSK
509         struct sk_cred *skc = NULL;
510         struct svc_cred cred;
511         gss_buffer_desc bufs[SK_INIT_BUFFERS];
512         gss_buffer_desc remote_pub_key = GSS_C_EMPTY_BUFFER;
513         char *target;
514         uint32_t rc = GSS_S_DEFECTIVE_TOKEN;
515         uint32_t version;
516         uint32_t flags;
517         int i;
518         int attempts = 0;
519
520         printerr(LL_DEBUG, "Handling sk request\n");
521         memset(bufs, 0, sizeof(gss_buffer_desc) * SK_INIT_BUFFERS);
522
523         /* See lgss_sk_using_cred() for client side token formation.
524          * Decoding initiator buffers */
525         i = sk_decode_netstring(bufs, SK_INIT_BUFFERS, &snd->in_tok);
526         if (i < SK_INIT_BUFFERS) {
527                 printerr(LL_ERR,
528                          "Invalid netstring token received from peer\n");
529                 goto cleanup_buffers;
530         }
531
532         /* Allowing for a larger length first buffer in the future */
533         if (bufs[SK_INIT_VERSION].length < sizeof(version)) {
534                 printerr(LL_ERR, "Invalid version received (wrong size)\n");
535                 goto cleanup_buffers;
536         }
537         memcpy(&version, bufs[SK_INIT_VERSION].value, sizeof(version));
538         version = be32toh(version);
539         if (version != SK_MSG_VERSION) {
540                 printerr(LL_ERR, "Invalid version received: %d\n", version);
541                 goto cleanup_buffers;
542         }
543
544         rc = GSS_S_FAILURE;
545
546         /* target must be a null terminated string */
547         i = bufs[SK_INIT_TARGET].length - 1;
548         target = bufs[SK_INIT_TARGET].value;
549         if (i >= 0 && target[i] != '\0') {
550                 printerr(LL_ERR, "Invalid target from netstring\n");
551                 goto cleanup_buffers;
552         }
553
554         if (bufs[SK_INIT_FLAGS].length != sizeof(flags)) {
555                 printerr(LL_ERR, "Invalid flags from netstring\n");
556                 goto cleanup_buffers;
557         }
558         memcpy(&flags, bufs[SK_INIT_FLAGS].value, sizeof(flags));
559
560         skc = sk_create_cred(target, snd->nm_name, be32toh(flags));
561         if (!skc) {
562                 printerr(LL_ERR, "Failed to create sk credentials\n");
563                 goto cleanup_buffers;
564         }
565
566         /* Verify that the peer has used a prime size greater or equal to
567          * the size specified in the key file which may contain only zero
568          * fill but the size specifies the mimimum supported size on
569          * servers */
570         if (skc->sc_flags & LGSS_SVC_PRIV &&
571             bufs[SK_INIT_P].length < skc->sc_p.length) {
572                 printerr(LL_ERR,
573                          "Peer DHKE prime does not meet the size required by keyfile: %zd bits\n",
574                          skc->sc_p.length * 8);
575                 goto cleanup_buffers;
576         }
577
578         /* Throw out the p from the server and use the wire data */
579         free(skc->sc_p.value);
580         skc->sc_p.value = NULL;
581         skc->sc_p.length = 0;
582
583         /* Take control of all the allocated buffers from decoding */
584         if (bufs[SK_INIT_RANDOM].length !=
585             sizeof(skc->sc_kctx.skc_peer_random)) {
586                 printerr(LL_ERR, "Invalid size for client random\n");
587                 goto cleanup_buffers;
588         }
589
590         memcpy(&skc->sc_kctx.skc_peer_random, bufs[SK_INIT_RANDOM].value,
591                sizeof(skc->sc_kctx.skc_peer_random));
592         skc->sc_p = bufs[SK_INIT_P];
593         remote_pub_key = bufs[SK_INIT_PUB_KEY];
594         skc->sc_nodemap_hash = bufs[SK_INIT_NODEMAP];
595         skc->sc_hmac = bufs[SK_INIT_HMAC];
596
597         /* Verify HMAC from peer.  Ideally this would happen before anything
598          * else but we don't have enough information to lookup key without the
599          * token (fsname and cluster_hash) so it's done after. */
600         rc = sk_verify_hmac(skc, bufs, SK_INIT_BUFFERS - 1, EVP_sha256(),
601                             &skc->sc_hmac);
602         if (rc != GSS_S_COMPLETE) {
603                 printerr(LL_ERR, "HMAC verification error: 0x%x from peer %s\n",
604                          rc, libcfs_nid2str((lnet_nid_t)snd->nid));
605                 goto cleanup_partial;
606         }
607
608         /* Check that the cluster hash matches the hash of nodemap name */
609         rc = sk_verify_hash(snd->nm_name, EVP_sha256(), &skc->sc_nodemap_hash);
610         if (rc != GSS_S_COMPLETE) {
611                 printerr(LL_ERR, "Cluster hash failed validation: 0x%x\n", rc);
612                 goto cleanup_partial;
613         }
614
615 redo:
616         rc = sk_gen_params(skc, sk_dh_checks);
617         if (rc != GSS_S_COMPLETE) {
618                 printerr(LL_ERR,
619                          "Failed to generate DH params for responder\n");
620                 goto cleanup_partial;
621         }
622         rc = sk_compute_dh_key(skc, &remote_pub_key);
623         if (rc == GSS_S_BAD_QOP && attempts < 2) {
624                 /* GSS_S_BAD_QOP means the generated shared key was shorter
625                  * than expected. Just retry twice before giving up.
626                  */
627                 attempts++;
628                 if (skc->sc_params) {
629                         EVP_PKEY_free(skc->sc_params);
630                         skc->sc_params = NULL;
631                 }
632                 if (skc->sc_pub_key.value) {
633                         free(skc->sc_pub_key.value);
634                         skc->sc_pub_key.value = NULL;
635                 }
636                 skc->sc_pub_key.length = 0;
637                 if (skc->sc_dh_shared_key.value) {
638                         /* erase secret key before freeing memory */
639                         memset(skc->sc_dh_shared_key.value, 0,
640                                skc->sc_dh_shared_key.length);
641                         free(skc->sc_dh_shared_key.value);
642                         skc->sc_dh_shared_key.value = NULL;
643                 }
644                 skc->sc_dh_shared_key.length = 0;
645                 goto redo;
646         } else if (rc != GSS_S_COMPLETE) {
647                 printerr(LL_ERR,
648                          "Failed to compute session key from DH params\n");
649                 goto cleanup_partial;
650         }
651
652         /* Cleanup init buffers we have copied or don't need anymore */
653         free(bufs[SK_INIT_VERSION].value);
654         free(bufs[SK_INIT_RANDOM].value);
655         free(bufs[SK_INIT_TARGET].value);
656         free(bufs[SK_INIT_FLAGS].value);
657
658         /* Server reply contains the servers public key, random,  and HMAC */
659         version = htobe32(SK_MSG_VERSION);
660         bufs[SK_RESP_VERSION].value = &version;
661         bufs[SK_RESP_VERSION].length = sizeof(version);
662         bufs[SK_RESP_RANDOM].value = &skc->sc_kctx.skc_host_random;
663         bufs[SK_RESP_RANDOM].length = sizeof(skc->sc_kctx.skc_host_random);
664         bufs[SK_RESP_PUB_KEY] = skc->sc_pub_key;
665         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs,
666                          SK_RESP_BUFFERS - 1, EVP_sha256(),
667                          &skc->sc_hmac)) {
668                 printerr(LL_ERR, "Failed to sign parameters\n");
669                 goto out_err;
670         }
671         bufs[SK_RESP_HMAC] = skc->sc_hmac;
672         if (sk_encode_netstring(bufs, SK_RESP_BUFFERS, &snd->out_tok)) {
673                 printerr(LL_ERR, "Failed to encode netstring for token\n");
674                 goto out_err;
675         }
676         printerr(LL_INFO, "Created netstring of %zd bytes\n",
677                  snd->out_tok.length);
678
679         if (sk_session_kdf(skc, snd->nid, &snd->in_tok, &snd->out_tok)) {
680                 printerr(LL_ERR, "Failed to calculate derived session key\n");
681                 goto out_err;
682         }
683         if (sk_compute_keys(skc)) {
684                 printerr(LL_ERR,
685                          "Failed to compute HMAC and encryption keys\n");
686                 goto out_err;
687         }
688         if (sk_serialize_kctx(skc, &snd->ctx_token)) {
689                 printerr(LL_ERR, "Failed to serialize context for kernel\n");
690                 goto out_err;
691         }
692
693         snd->out_handle.length = sizeof(snd->handle_seq);
694         memcpy(snd->out_handle.value, &snd->handle_seq,
695                sizeof(snd->handle_seq));
696         snd->maj_stat = GSS_S_COMPLETE;
697
698         /* fix credentials */
699         memset(&cred, 0, sizeof(cred));
700         cred.cr_mapped_uid = -1;
701
702         if (skc->sc_flags & LGSS_ROOT_CRED_ROOT)
703                 cred.cr_usr_root = 1;
704         if (skc->sc_flags & LGSS_ROOT_CRED_MDT)
705                 cred.cr_usr_mds = 1;
706         if (skc->sc_flags & LGSS_ROOT_CRED_OST)
707                 cred.cr_usr_oss = 1;
708
709         do_svc_downcall(&snd->out_handle, &cred, snd->mech, &snd->ctx_token);
710
711         /* cleanup ctx_token, out_tok is cleaned up in handle_channel_request */
712         if (remote_pub_key.length != 0) {
713                 free(remote_pub_key.value);
714                 remote_pub_key.value = NULL;
715                 remote_pub_key.length = 0;
716         }
717         if (snd->ctx_token.value) {
718                 free(snd->ctx_token.value);
719                 snd->ctx_token.value = NULL;
720                 snd->ctx_token.length = 0;
721         }
722
723         printerr(LL_DEBUG, "sk returning success\n");
724         return 0;
725
726 cleanup_buffers:
727         for (i = 0; i < SK_INIT_BUFFERS; i++)
728                 free(bufs[i].value);
729         sk_free_cred(skc);
730         snd->maj_stat = rc;
731         return -1;
732
733 cleanup_partial:
734         free(bufs[SK_INIT_VERSION].value);
735         free(bufs[SK_INIT_RANDOM].value);
736         free(bufs[SK_INIT_TARGET].value);
737         free(bufs[SK_INIT_FLAGS].value);
738         if (remote_pub_key.length != 0) {
739                 free(remote_pub_key.value);
740                 remote_pub_key.value = NULL;
741                 remote_pub_key.length = 0;
742         }
743         sk_free_cred(skc);
744         snd->maj_stat = rc;
745         return -1;
746
747 out_err:
748         snd->maj_stat = rc;
749         if (snd->ctx_token.value) {
750                 free(snd->ctx_token.value);
751                 snd->ctx_token.value = NULL;
752                 snd->ctx_token.length = 0;
753         }
754         if (remote_pub_key.length != 0) {
755                 free(remote_pub_key.value);
756                 remote_pub_key.value = NULL;
757                 remote_pub_key.length = 0;
758         }
759         sk_free_cred(skc);
760         printerr(LL_DEBUG, "sk returning failure\n");
761 #else /* !HAVE_OPENSSL_SSK */
762         printerr(LL_ERR, "ERROR: shared key subflavour is not enabled\n");
763 #endif /* HAVE_OPENSSL_SSK */
764         return -1;
765 }
766
767 int handle_null(struct svc_nego_data *snd)
768 {
769         struct svc_cred cred;
770         uint64_t tmp;
771         uint32_t flags;
772
773         /* null just uses the same token as the return token and for
774          * for sending to the kernel.  It is a single uint64_t. */
775         if (snd->in_tok.length != sizeof(uint64_t)) {
776                 snd->maj_stat = GSS_S_DEFECTIVE_TOKEN;
777                 printerr(LL_ERR, "Invalid token size (%zd) received\n",
778                          snd->in_tok.length);
779                 return -1;
780         }
781         snd->out_tok.length = snd->in_tok.length;
782         snd->out_tok.value = malloc(snd->out_tok.length);
783         if (!snd->out_tok.value) {
784                 snd->maj_stat = GSS_S_FAILURE;
785                 printerr(LL_ERR, "Failed to allocate out_tok\n");
786                 return -1;
787         }
788
789         snd->ctx_token.length = snd->in_tok.length;
790         snd->ctx_token.value = malloc(snd->ctx_token.length);
791         if (!snd->ctx_token.value) {
792                 snd->maj_stat = GSS_S_FAILURE;
793                 printerr(LL_ERR, "Failed to allocate ctx_token\n");
794                 return -1;
795         }
796
797         snd->out_handle.length = sizeof(snd->handle_seq);
798         memcpy(snd->out_handle.value, &snd->handle_seq,
799                sizeof(snd->handle_seq));
800         snd->maj_stat = GSS_S_COMPLETE;
801
802         memcpy(&tmp, snd->in_tok.value, sizeof(tmp));
803         tmp = be64toh(tmp);
804         flags = (uint32_t)(tmp & 0x00000000ffffffff);
805         memset(&cred, 0, sizeof(cred));
806         cred.cr_mapped_uid = -1;
807
808         if (flags & LGSS_ROOT_CRED_ROOT)
809                 cred.cr_usr_root = 1;
810         if (flags & LGSS_ROOT_CRED_MDT)
811                 cred.cr_usr_mds = 1;
812         if (flags & LGSS_ROOT_CRED_OST)
813                 cred.cr_usr_oss = 1;
814
815         do_svc_downcall(&snd->out_handle, &cred, snd->mech, &snd->ctx_token);
816
817         /* cleanup ctx_token, out_tok is cleaned up in handle_channel_req */
818         free(snd->ctx_token.value);
819         snd->ctx_token.length = 0;
820
821         return 0;
822 }
823
824 static int handle_krb(struct svc_nego_data *snd)
825 {
826         u_int32_t               ret_flags;
827         gss_name_t              client_name;
828         gss_buffer_desc         ignore_out_tok = {.value = NULL};
829         gss_OID                 mech = GSS_C_NO_OID;
830         gss_cred_id_t           svc_cred;
831         u_int32_t               ignore_min_stat;
832         struct svc_cred         cred;
833
834         svc_cred = gssd_select_svc_cred(snd->lustre_svc);
835         if (!svc_cred) {
836                 printerr(LL_ERR, "no service credential for svc %u\n",
837                          snd->lustre_svc);
838                 goto out_err;
839         }
840
841         snd->maj_stat = gss_accept_sec_context(&snd->min_stat, &snd->ctx,
842                                                svc_cred, &snd->in_tok,
843                                                GSS_C_NO_CHANNEL_BINDINGS,
844                                                &client_name, &mech,
845                                                &snd->out_tok, &ret_flags, NULL,
846                                                NULL);
847
848         if (snd->maj_stat == GSS_S_CONTINUE_NEEDED) {
849                 printerr(LL_WARN,
850                          "gss_accept_sec_context GSS_S_CONTINUE_NEEDED\n");
851
852                 /* Save the context handle for future calls */
853                 snd->out_handle.length = sizeof(snd->ctx);
854                 memcpy(snd->out_handle.value, &snd->ctx, sizeof(snd->ctx));
855                 return 0;
856         } else if (snd->maj_stat != GSS_S_COMPLETE) {
857                 printerr(LL_ERR, "ERROR: gss_accept_sec_context failed\n");
858                 pgsserr("handle_krb: gss_accept_sec_context",
859                         snd->maj_stat, snd->min_stat, mech);
860                 goto out_err;
861         }
862
863         if (get_ids(client_name, mech, &cred, snd->nid, snd->lustre_svc)) {
864                 /* get_ids() prints error msg */
865                 snd->maj_stat = GSS_S_BAD_NAME; /* XXX ? */
866                 gss_release_name(&ignore_min_stat, &client_name);
867                 goto out_err;
868         }
869         gss_release_name(&ignore_min_stat, &client_name);
870
871         /* Context complete. Pass handle_seq in out_handle to use
872          * for context lookup in the kernel. */
873         snd->out_handle.length = sizeof(snd->handle_seq);
874         memcpy(snd->out_handle.value, &snd->handle_seq,
875                sizeof(snd->handle_seq));
876
877         /* kernel needs ctx to calculate verifier on null response, so
878          * must give it context before doing null call: */
879         if (serialize_context_for_kernel(snd->ctx, &snd->ctx_token, mech)) {
880                 printerr(LL_ERR,
881                          "ERROR: %s: serialize_context_for_kernel failed\n",
882                         __func__);
883                 snd->maj_stat = GSS_S_FAILURE;
884                 goto out_err;
885         }
886         /* We no longer need the gss context */
887         gss_delete_sec_context(&ignore_min_stat, &snd->ctx, &ignore_out_tok);
888         do_svc_downcall(&snd->out_handle, &cred, mech, &snd->ctx_token);
889         /* We no longer need the context token */
890         if (snd->ctx_token.value) {
891                 free(snd->ctx_token.value);
892                 snd->ctx_token.value = NULL;
893                 snd->ctx_token.length = 0;
894         }
895         return 0;
896
897 out_err:
898         if (snd->ctx != GSS_C_NO_CONTEXT)
899                 gss_delete_sec_context(&ignore_min_stat, &snd->ctx,
900                                        &ignore_out_tok);
901
902         return 1;
903 }
904
905 int handle_channel_request(int fd)
906 {
907         char in_handle_buf[15];
908         char out_handle_buf[15];
909         uint32_t lustre_mech;
910         static char *lbuf;
911         static int lbuflen;
912         static char *cp;
913         int get_len;
914         int rc;
915         u_int32_t ignore_min_stat;
916         struct svc_nego_data snd = {
917                 .in_tok.value           = NULL,
918                 .in_handle.value        = in_handle_buf,
919                 .out_handle.value       = out_handle_buf,
920                 .maj_stat               = GSS_S_FAILURE,
921                 .ctx                    = GSS_C_NO_CONTEXT,
922         };
923         uint64_t hash = 0;
924
925         printerr(LL_INFO, "handling request\n");
926         if (readline(fd, &lbuf, &lbuflen) != 1) {
927                 printerr(LL_ERR, "ERROR: failed reading request\n");
928                 return -1;
929         }
930
931         cp = lbuf;
932
933         /* see rsi_do_upcall() for the format of data being input here */
934         rc = gss_u64_read_string(&cp, (__u64 *)&hash);
935         if (rc < 0) {
936                 printerr(LL_ERR, "ERROR: failed parsing request: hash\n");
937                 goto out_err;
938         }
939         rc = gss_u64_read_string(&cp, (__u64 *)&snd.lustre_svc);
940         if (rc < 0) {
941                 printerr(LL_ERR, "ERROR: failed parsing request: lustre svc\n");
942                 goto out_err;
943         }
944         /* lustre_svc is the svc and gss subflavor */
945         lustre_mech = (snd.lustre_svc & LUSTRE_GSS_MECH_MASK) >>
946                 LUSTRE_GSS_MECH_SHIFT;
947         snd.lustre_svc = snd.lustre_svc & LUSTRE_GSS_SVC_MASK;
948         switch (lustre_mech) {
949         case LGSS_MECH_KRB5:
950                 if (!krb_enabled) {
951                         static time_t next_krb;
952
953                         if (time(NULL) > next_krb) {
954                                 printerr(LL_WARN,
955                                          "warning: Request for kerberos but service support not enabled\n");
956                                 next_krb = time(NULL) + 3600;
957                         }
958                         goto ignore;
959                 }
960                 snd.mech = &krb5oid;
961                 break;
962         case LGSS_MECH_NULL:
963                 if (!null_enabled) {
964                         static time_t next_null;
965
966                         if (time(NULL) > next_null) {
967                                 printerr(LL_WARN,
968                                          "warning: Request for gssnull but service support not enabled\n");
969                                 next_null = time(NULL) + 3600;
970                         }
971                         goto ignore;
972                 }
973                 snd.mech = &nulloid;
974                 break;
975         case LGSS_MECH_SK:
976                 if (!sk_enabled) {
977                         static time_t next_ssk;
978
979                         if (time(NULL) > next_ssk) {
980                                 printerr(LL_WARN,
981                                          "warning: Request for SSK but service support not %s\n",
982 #ifdef HAVE_OPENSSL_SSK
983                                          "enabled"
984 #else
985                                          "included"
986 #endif
987                                         );
988                                 next_ssk = time(NULL) + 3600;
989                         }
990
991                         goto ignore;
992                 }
993                 snd.mech = &skoid;
994                 break;
995         default:
996                 printerr(LL_ERR, "WARNING: invalid mechanism recevied: %d\n",
997                          lustre_mech);
998                 goto out_err;
999                 break;
1000         }
1001
1002         rc = gss_u64_read_string(&cp, (__u64 *)&snd.nid);
1003         if (rc < 0) {
1004                 printerr(LL_ERR, "ERROR: failed parsing request: source nid\n");
1005                 goto out_err;
1006         }
1007         rc = gss_u64_read_string(&cp, (__u64 *)&snd.handle_seq);
1008         if (rc < 0) {
1009                 printerr(LL_ERR, "ERROR: failed parsing request: handle seq\n");
1010                 goto out_err;
1011         }
1012         get_len = gss_string_read(&cp, snd.nm_name, sizeof(snd.nm_name), 0);
1013         if (get_len <= 0) {
1014                 printerr(LL_ERR,
1015                          "ERROR: failed parsing request: nodemap name\n");
1016                 goto out_err;
1017         }
1018         snd.nm_name[get_len] = '\0';
1019         printerr(LL_INFO,
1020                  "handling req: svc %u, nid %016llx, idx %"PRIx64" nodemap %s\n",
1021                  snd.lustre_svc, snd.nid, snd.handle_seq, snd.nm_name);
1022
1023         get_len = gss_base64url_decode(&cp, snd.in_handle.value,
1024                                        sizeof(in_handle_buf));
1025         if (get_len < 0) {
1026                 printerr(LL_ERR, "ERROR: failed parsing request: in handle\n");
1027                 goto out_err;
1028         }
1029         snd.in_handle.length = (size_t)get_len;
1030
1031         printerr(LL_DEBUG, "in_handle:\n");
1032         print_hexl(3, snd.in_handle.value, snd.in_handle.length);
1033
1034         snd.in_tok.value = malloc(strlen(cp));
1035         if (!snd.in_tok.value) {
1036                 printerr(LL_ERR, "ERROR: failed alloc for in token\n");
1037                 goto out_err;
1038         }
1039         get_len = gss_base64url_decode(&cp, snd.in_tok.value, strlen(cp));
1040         if (get_len < 0) {
1041                 printerr(LL_ERR, "ERROR: failed parsing request: in token\n");
1042                 goto out_err;
1043         }
1044         snd.in_tok.length = (size_t)get_len;
1045
1046         printerr(LL_DEBUG, "in_tok:\n");
1047         print_hexl(3, snd.in_tok.value, snd.in_tok.length);
1048
1049         if (snd.in_handle.length != 0) { /* CONTINUE_INIT case */
1050                 if (snd.in_handle.length != sizeof(snd.ctx)) {
1051                         printerr(LL_ERR,
1052                                  "ERROR: input handle has unexpected length %zu\n",
1053                                  snd.in_handle.length);
1054                         goto out_err;
1055                 }
1056                 /* in_handle is the context id stored in the out_handle
1057                  * for the GSS_S_CONTINUE_NEEDED case below.  */
1058                 memcpy(&snd.ctx, snd.in_handle.value, snd.in_handle.length);
1059         }
1060
1061         rc = -1;
1062         if (lustre_mech == LGSS_MECH_KRB5)
1063                 rc = handle_krb(&snd);
1064         else if (lustre_mech == LGSS_MECH_SK)
1065                 rc = handle_sk(&snd);
1066         else if (lustre_mech == LGSS_MECH_NULL)
1067                 rc = handle_null(&snd);
1068         else
1069                 printerr(LL_ERR,
1070                          "ERROR: Received or request for subflavor that is not enabled: %d\n",
1071                          lustre_mech);
1072
1073 out_err:
1074         /* Failures send a null token */
1075         rc = send_response(rc, hash, &snd.in_handle, &snd.in_tok,
1076                            snd.maj_stat, snd.min_stat,
1077                            &snd.out_handle, &snd.out_tok);
1078
1079         /* cleanup buffers */
1080         if (snd.in_tok.value)
1081                 free(snd.in_tok.value);
1082         if (snd.out_tok.value != NULL)
1083                 gss_release_buffer(&ignore_min_stat, &snd.out_tok);
1084
1085         /* For junk wire data just ignore */
1086 ignore:
1087         return rc;
1088 }