Whamcloud - gitweb
7876c04848408e50b126e1eab6fd38d6be2b29ec
[fs/lustre-release.git] / lustre / utils / gss / svcgssd_proc.c
1 /*
2   svc_in_gssd_proc.c
3
4   Copyright (c) 2000 The Regents of the University of Michigan.
5   All rights reserved.
6
7   Copyright (c) 2002 Bruce Fields <bfields@UMICH.EDU>
8
9   Copyright (c) 2014, Intel Corporation.
10
11   Redistribution and use in source and binary forms, with or without
12   modification, are permitted provided that the following conditions
13   are met:
14
15   1. Redistributions of source code must retain the above copyright
16      notice, this list of conditions and the following disclaimer.
17   2. Redistributions in binary form must reproduce the above copyright
18      notice, this list of conditions and the following disclaimer in the
19      documentation and/or other materials provided with the distribution.
20   3. Neither the name of the University nor the names of its
21      contributors may be used to endorse or promote products derived
22      from this software without specific prior written permission.
23
24   THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
25   WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
26   MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27   DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
28   FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
29   CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
30   SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
31   BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
32   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
33   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
34   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35
36 */
37
38 #include <sys/param.h>
39 #include <sys/stat.h>
40
41 #include <pwd.h>
42 #include <stdio.h>
43 #include <unistd.h>
44 #include <ctype.h>
45 #include <string.h>
46 #include <fcntl.h>
47 #include <errno.h>
48 #include <netdb.h>
49
50 #include "svcgssd.h"
51 #include "gss_util.h"
52 #include "err_util.h"
53 #include "context.h"
54 #include "cacheio.h"
55 #include "lsupport.h"
56
/*
 * Returns the mechanism name string sent to the kernel for @mech; callers in
 * this file treat a NULL return as failure.  Defined elsewhere.
 */
extern char * mech2file(gss_OID mech);
/* Kernel channel that receives established contexts (written by do_svc_downcall). */
#define SVCGSSD_CONTEXT_CHANNEL "/proc/net/rpc/auth.sptlrpc.context/channel"
/* Kernel channel that receives replies to init upcalls (written by send_response). */
#define SVCGSSD_INIT_CHANNEL    "/proc/net/rpc/auth.sptlrpc.init/channel"

/* Size of the input-token buffer; send_response() uses twice this for a reply. */
#define TOKEN_BUF_SIZE          8192
62
/*
 * Credential description for an authenticated client.  It is filled in by
 * get_ids() and written to the kernel by do_svc_downcall().
 */
struct svc_cred {
        uint32_t cr_remote;     /* 1 if the principal's realm is not the MDS local realm */
        uint32_t cr_usr_root;   /* 1 if the principal is the Lustre root user */
        uint32_t cr_usr_mds;    /* 1 if the principal is the MDS service */
        uint32_t cr_usr_oss;    /* 1 if the principal is the OSS service */
        uid_t    cr_uid;        /* resolved local uid, or (uid_t)-1 */
        uid_t    cr_mapped_uid; /* uid from the NID/principal mapping, or (uid_t)-1 */
        gid_t    cr_gid;        /* group id; a gid, so typed gid_t (was uid_t) */
};
72
73 static int
74 do_svc_downcall(gss_buffer_desc *out_handle, struct svc_cred *cred,
75                 gss_OID mech, gss_buffer_desc *context_token)
76 {
77         FILE *f;
78         char *fname = NULL;
79         int err;
80
81         printerr(2, "doing downcall\n");
82         if ((fname = mech2file(mech)) == NULL)
83                 goto out_err;
84         f = fopen(SVCGSSD_CONTEXT_CHANNEL, "w");
85         if (f == NULL) {
86                 printerr(0, "WARNING: unable to open downcall channel "
87                              "%s: %s\n",
88                              SVCGSSD_CONTEXT_CHANNEL, strerror(errno));
89                 goto out_err;
90         }
91         qword_printhex(f, out_handle->value, out_handle->length);
92         /* XXX are types OK for the rest of this? */
93         qword_printint(f, 0x7fffffff); /*XXX need a better timeout */
94         qword_printint(f, cred->cr_remote);
95         qword_printint(f, cred->cr_usr_root);
96         qword_printint(f, cred->cr_usr_mds);
97         qword_printint(f, cred->cr_usr_oss);
98         qword_printint(f, cred->cr_mapped_uid);
99         qword_printint(f, cred->cr_uid);
100         qword_printint(f, cred->cr_gid);
101         qword_print(f, fname);
102         qword_printhex(f, context_token->value, context_token->length);
103         err = qword_eol(f);
104         fclose(f);
105         return err;
106 out_err:
107         printerr(0, "WARNING: downcall failed\n");
108         return -1;
109 }
110
/*
 * RPCSEC_GSS verifier: flavour plus opaque body.
 * NOTE(review): not referenced anywhere in this file; it looks like a leftover
 * and should be confirmed unused elsewhere before removal.
 */
struct gss_verifier {
        u_int32_t       flav;
        gss_buffer_desc body;
};

/* NOTE(review): RPCSEC_GSS sequence window size; not referenced in this file. */
#define RPCSEC_GSS_SEQ_WIN      5
117
118 static int
119 send_response(FILE *f, gss_buffer_desc *in_handle, gss_buffer_desc *in_token,
120               u_int32_t maj_stat, u_int32_t min_stat,
121               gss_buffer_desc *out_handle, gss_buffer_desc *out_token)
122 {
123         char buf[2 * TOKEN_BUF_SIZE];
124         char *bp = buf;
125         int blen = sizeof(buf);
126         /* XXXARG: */
127         int g;
128
129         printerr(2, "sending null reply\n");
130
131         qword_addhex(&bp, &blen, in_handle->value, in_handle->length);
132         qword_addhex(&bp, &blen, in_token->value, in_token->length);
133         qword_addint(&bp, &blen, 0x7fffffff); /*XXX need a better timeout */
134         qword_adduint(&bp, &blen, maj_stat);
135         qword_adduint(&bp, &blen, min_stat);
136         qword_addhex(&bp, &blen, out_handle->value, out_handle->length);
137         qword_addhex(&bp, &blen, out_token->value, out_token->length);
138         qword_addeol(&bp, &blen);
139         if (blen <= 0) {
140                 printerr(0, "WARNING: send_respsonse: message too long\n");
141                 return -1;
142         }
143         g = open(SVCGSSD_INIT_CHANNEL, O_WRONLY);
144         if (g == -1) {
145                 printerr(0, "WARNING: open %s failed: %s\n",
146                                 SVCGSSD_INIT_CHANNEL, strerror(errno));
147                 return -1;
148         }
149         *bp = '\0';
150         printerr(3, "writing message: %s", buf);
151         if (write(g, buf, bp - buf) == -1) {
152                 printerr(0, "WARNING: failed to write message\n");
153                 close(g);
154                 return -1;
155         }
156         close(g);
157         return 0;
158 }
159
/*
 * RPC authentication status codes.  The values appear to mirror the ONC RPC
 * auth_stat values (0-5) and the RPCSEC_GSS credproblem/ctxproblem errors
 * (13/14).
 * NOTE(review): none of these macros is referenced in this file.
 */
#define rpc_auth_ok                     0
#define rpc_autherr_badcred             1
#define rpc_autherr_rejectedcred        2
#define rpc_autherr_badverf             3
#define rpc_autherr_rejectedverf        4
#define rpc_autherr_tooweak             5
#define rpcsec_gsserr_credproblem       13
#define rpcsec_gsserr_ctxproblem        14
168
/*
 * Disabled legacy code: fills cred's supplementary group list for @name
 * using nfs4_gss_princ_to_grouplist().
 * NOTE(review): it uses cred->cr_ngroups and cred->cr_groups, which
 * struct svc_cred in this file does not define, so it would not compile if
 * re-enabled as-is.  The realloc() result is also unchecked.
 */
#if 0
static void
add_supplementary_groups(char *secname, char *name, struct svc_cred *cred)
{
        int ret;
        static gid_t *groups = NULL;

        cred->cr_ngroups = NGROUPS;
        ret = nfs4_gss_princ_to_grouplist(secname, name,
                        cred->cr_groups, &cred->cr_ngroups);
        if (ret < 0) {
                /* presumably the first call reported the needed count: retry with a heap buffer */
                groups = realloc(groups, cred->cr_ngroups*sizeof(gid_t));
                ret = nfs4_gss_princ_to_grouplist(secname, name,
                                groups, &cred->cr_ngroups);
                if (ret < 0)
                        cred->cr_ngroups = 0;
                else {
                        /* clamp to the fixed-size cr_groups array */
                        if (cred->cr_ngroups > NGROUPS)
                                cred->cr_ngroups = NGROUPS;
                        memcpy(cred->cr_groups, groups,
                                        cred->cr_ngroups*sizeof(gid_t));
                }
        }
}
#endif
194
/*
 * Disabled legacy get_ids(): maps the GSS display name to uid/gid through
 * nfs4_gss_princ_to_ids() and then adds supplementary groups.
 * NOTE(review): it shares its name with the active get_ids() in this file
 * but has a different signature, so both cannot be compiled together.
 * It also sets cred->cr_ngroups, which struct svc_cred here lacks, and
 * prints the size_t name.length + 1 with %d.
 */
#if 0
static int
get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred)
{
        u_int32_t       maj_stat, min_stat;
        gss_buffer_desc name;
        char            *sname;
        int             res = -1;
        uid_t           uid, gid;
        gss_OID         name_type = GSS_C_NO_OID;
        char            *secname;

        maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
        if (maj_stat != GSS_S_COMPLETE) {
                pgsserr("get_ids: gss_display_name",
                        maj_stat, min_stat, mech);
                goto out;
        }
        if (name.length >= 0xffff || /* be certain name.length+1 doesn't overflow */
            !(sname = calloc(name.length + 1, 1))) {
                printerr(0, "WARNING: get_ids: error allocating %d bytes "
                        "for sname\n", name.length + 1);
                gss_release_buffer(&min_stat, &name);
                goto out;
        }
        memcpy(sname, name.value, name.length);
        printerr(1, "sname = %s\n", sname);
        gss_release_buffer(&min_stat, &name);

        res = -EINVAL;
        if ((secname = mech2file(mech)) == NULL) {
                printerr(0, "WARNING: get_ids: error mapping mech to "
                        "file for name '%s'\n", sname);
                goto out_free;
        }
        nfs4_init_name_mapping(NULL); /* XXX: should only do this once */
        res = nfs4_gss_princ_to_ids(secname, sname, &uid, &gid);
        if (res < 0) {
                /*
                 * -ENOENT means there was no mapping, any other error
                 * value means there was an error trying to do the
                 * mapping.
                 * If there was no mapping, we send down the value -1
                 * to indicate that the anonuid/anongid for the export
                 * should be used.
                 */
                if (res == -ENOENT) {
                        cred->cr_uid = -1;
                        cred->cr_gid = -1;
                        cred->cr_ngroups = 0;
                        res = 0;
                        goto out_free;
                }
                printerr(0, "WARNING: get_ids: failed to map name '%s' "
                        "to uid/gid: %s\n", sname, strerror(-res));
                goto out_free;
        }
        cred->cr_uid = uid;
        cred->cr_gid = gid;
        add_supplementary_groups(secname, sname, cred);
        res = 0;
out_free:
        free(sname);
out:
        return res;
}
#endif
262
/*
 * Disabled local copy of print_hexl(): logs a hex + printable-ASCII dump of
 * @length bytes at @cp, 16 bytes per row, at priority @pri.
 * NOTE(review): handle_nullreq() still calls print_hexl(), so presumably a
 * definition is provided by another translation unit.  Confirm this before
 * deleting or re-enabling this copy, because enabling it could clash.
 */
#if 0
void
print_hexl(int pri, unsigned char *cp, int length)
{
        int i, j, jm;
        unsigned char c;

        printerr(pri, "length %d\n",length);
        printerr(pri, "\n");

        for (i = 0; i < length; i += 0x10) {
                printerr(pri, "  %04x: ", (u_int)i);
                /* bytes in this row: at most 16 */
                jm = length - i;
                jm = jm > 16 ? 16 : jm;

                /* hex column, grouped two bytes per space */
                for (j = 0; j < jm; j++) {
                        if ((j % 2) == 1)
                                printerr(pri,"%02x ", (u_int)cp[i+j]);
                        else
                                printerr(pri,"%02x", (u_int)cp[i+j]);
                }
                /* pad a short final row so the ASCII column lines up */
                for (; j < 16; j++) {
                        if ((j % 2) == 1)
                                printerr(pri,"   ");
                        else
                                printerr(pri,"  ");
                }
                printerr(pri," ");

                /* ASCII column: non-printable bytes shown as '.' */
                for (j = 0; j < jm; j++) {
                        c = cp[i+j];
                        c = isprint(c) ? c : '.';
                        printerr(pri,"%c", c);
                }
                printerr(pri,"\n");
        }
}
#endif
301
302 static int
303 get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred,
304         lnet_nid_t nid, uint32_t lustre_svc)
305 {
306         u_int32_t       maj_stat, min_stat;
307         gss_buffer_desc name;
308         char            *sname, *host, *realm;
309         const int       namebuf_size = 512;
310         char            namebuf[namebuf_size];
311         int             res = -1;
312         gss_OID         name_type = GSS_C_NO_OID;
313         struct passwd   *pw;
314
315         cred->cr_remote = 0;
316         cred->cr_usr_root = cred->cr_usr_mds = cred->cr_usr_oss = 0;
317         cred->cr_uid = cred->cr_mapped_uid = cred->cr_gid = -1;
318
319         maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
320         if (maj_stat != GSS_S_COMPLETE) {
321                 pgsserr("get_ids: gss_display_name",
322                         maj_stat, min_stat, mech);
323                 return -1;
324         }
325         if (name.length >= 0xffff || /* be certain name.length+1 doesn't overflow */
326             !(sname = calloc(name.length + 1, 1))) {
327                 printerr(0, "WARNING: get_ids: error allocating %d bytes "
328                         "for sname\n", name.length + 1);
329                 gss_release_buffer(&min_stat, &name);
330                 return -1;
331         }
332         memcpy(sname, name.value, name.length);
333         sname[name.length] = '\0';
334         gss_release_buffer(&min_stat, &name);
335
336         if (lustre_svc == LUSTRE_GSS_SVC_MDS)
337                 lookup_mapping(sname, nid, &cred->cr_mapped_uid);
338         else
339                 cred->cr_mapped_uid = -1;
340
341         realm = strchr(sname, '@');
342         if (realm) {
343                 *realm++ = '\0';
344         } else {
345                 printerr(0, "ERROR: %s has no realm name\n", sname);
346                 goto out_free;
347         }
348
349         host = strchr(sname, '/');
350         if (host)
351                 *host++ = '\0';
352
353         if (strcmp(sname, GSSD_SERVICE_MGS) == 0) {
354                 printerr(0, "forbid %s as a user name\n", sname);
355                 goto out_free;
356         }
357
358         /* 1. check host part */
359         if (host) {
360                 if (lnet_nid2hostname(nid, namebuf, namebuf_size)) {
361                         printerr(0, "ERROR: failed to resolve hostname for "
362                                  "%s/%s@%s from %016llx\n",
363                                  sname, host, realm, nid);
364                         goto out_free;
365                 }
366
367                 if (strcasecmp(host, namebuf)) {
368                         printerr(0, "ERROR: %s/%s@%s claimed hostname doesn't "
369                                  "match %s, nid %016llx\n", sname, host, realm,
370                                  namebuf, nid);
371                         goto out_free;
372                 }
373         } else {
374                 if (!strcmp(sname, GSSD_SERVICE_MDS) ||
375                     !strcmp(sname, GSSD_SERVICE_OSS)) {
376                         printerr(0, "ERROR: %s@%s from %016llx doesn't "
377                                  "bind with hostname\n", sname, realm, nid);
378                         goto out_free;
379                 }
380         }
381
382         /* 2. check realm and user */
383         switch (lustre_svc) {
384         case LUSTRE_GSS_SVC_MDS:
385                 if (strcasecmp(mds_local_realm, realm)) {
386                         cred->cr_remote = 1;
387
388                         /* only allow mapped user from remote realm */
389                         if (cred->cr_mapped_uid == -1) {
390                                 printerr(0, "ERROR: %s%s%s@%s from %016llx "
391                                          "is remote but without mapping\n",
392                                          sname, host ? "/" : "",
393                                          host ? host : "", realm, nid);
394                                 break;
395                         }
396                 } else {
397                         if (!strcmp(sname, LUSTRE_ROOT_NAME)) {
398                                 cred->cr_uid = 0;
399                                 cred->cr_usr_root = 1;
400                         } else if (!strcmp(sname, GSSD_SERVICE_MDS)) {
401                                 cred->cr_uid = 0;
402                                 cred->cr_usr_mds = 1;
403                         } else if (!strcmp(sname, GSSD_SERVICE_OSS)) {
404                                 printerr(0, "ERROR: MDS doesn't accept "
405                                          "user "GSSD_SERVICE_OSS"\n");
406                                 break;
407                         } else {
408                                 pw = getpwnam(sname);
409                                 if (pw != NULL) {
410                                         cred->cr_uid = pw->pw_uid;
411                                         printerr(2, "%s resolve to uid %u\n",
412                                                  sname, cred->cr_uid);
413                                 } else if (cred->cr_mapped_uid != -1) {
414                                         printerr(2, "user %s from %016llx is "
415                                                  "mapped to %u\n", sname, nid,
416                                                  cred->cr_mapped_uid);
417                                 } else {
418                                         printerr(0, "ERROR: invalid user, "
419                                                  "%s/%s@%s from %016llx\n",
420                                                  sname, host, realm, nid);
421                                         break;
422                                 }
423                         }
424                 }
425
426                 res = 0;
427                 break;
428         case LUSTRE_GSS_SVC_MGS:
429                 if (!strcmp(sname, GSSD_SERVICE_OSS)) {
430                         cred->cr_uid = 0;
431                         cred->cr_usr_oss = 1;
432                 }
433                 /* fall through */
434         case LUSTRE_GSS_SVC_OSS:
435                 if (!strcmp(sname, LUSTRE_ROOT_NAME)) {
436                         cred->cr_uid = 0;
437                         cred->cr_usr_root = 1;
438                 } else if (!strcmp(sname, GSSD_SERVICE_MDS)) {
439                         cred->cr_uid = 0;
440                         cred->cr_usr_mds = 1;
441                 } else {
442                         printerr(0, "ERROR: svc %d doesn't accept user %s"
443                                  "from %016llx\n", lustre_svc, sname, nid);
444                         break;
445                 }
446                 res = 0;
447                 break;
448         default:
449                 assert(0);
450         }
451
452 out_free:
453         if (!res)
454                 printerr(1, "%s: authenticated %s%s%s@%s from %016llx\n",
455                          lustre_svc_name[lustre_svc], sname,
456                          host ? "/" : "", host ? host : "", realm, nid);
457         free(sname);
458         return res;
459 }
460
/*
 * Pairs a mechanism OID with a mechanism-internal context id.
 * NOTE(review): presumably mirrors a GSS-API mechglue "union" context
 * layout.  It is not referenced in this file; confirm that it is unused
 * before removing.
 */
typedef struct gss_union_ctx_id_t {
        gss_OID         mech_type;
        gss_ctx_id_t    internal_ctx_id;
} gss_union_ctx_id_desc, *gss_union_ctx_id_t;
465
466 /*
467  * return -1 only if we detect error during reading from upcall channel,
468  * all other cases return 0.
469  */
470 int
471 handle_nullreq(FILE *f) {
472         uint64_t                handle_seq;
473         char                    in_tok_buf[TOKEN_BUF_SIZE];
474         char                    in_handle_buf[15];
475         char                    out_handle_buf[15];
476         gss_buffer_desc         in_tok = {.value = in_tok_buf},
477                                 out_tok = {.value = NULL},
478                                 in_handle = {.value = in_handle_buf},
479                                 out_handle = {.value = out_handle_buf},
480                                 ctx_token = {.value = NULL},
481                                 ignore_out_tok = {.value = NULL},
482         /* XXX isn't there a define for this?: */
483                                 null_token = {.value = NULL};
484         uint32_t                lustre_svc;
485         lnet_nid_t              nid;
486         u_int32_t               ret_flags;
487         gss_ctx_id_t            ctx = GSS_C_NO_CONTEXT;
488         gss_name_t              client_name;
489         gss_OID                 mech = GSS_C_NO_OID;
490         gss_cred_id_t           svc_cred;
491         u_int32_t               maj_stat = GSS_S_FAILURE, min_stat = 0;
492         u_int32_t               ignore_min_stat;
493         int                     get_len;
494         struct svc_cred         cred;
495         static char             *lbuf = NULL;
496         static int              lbuflen = 0;
497         static char             *cp;
498
499         printerr(2, "handling null request\n");
500
501         if (readline(fileno(f), &lbuf, &lbuflen) != 1) {
502                 printerr(0, "WARNING: handle_nullreq: "
503                             "failed reading request\n");
504                 return -1;
505         }
506
507         cp = lbuf;
508
509         qword_get(&cp, (char *) &lustre_svc, sizeof(lustre_svc));
510         qword_get(&cp, (char *) &nid, sizeof(nid));
511         qword_get(&cp, (char *) &handle_seq, sizeof(handle_seq));
512         printerr(2, "handling req: svc %u, nid %016llx, idx %llx\n",
513                  lustre_svc, nid, handle_seq);
514
515         get_len = qword_get(&cp, in_handle.value, sizeof(in_handle_buf));
516         if (get_len < 0) {
517                 printerr(0, "WARNING: handle_nullreq: "
518                             "failed parsing request\n");
519                 goto out_err;
520         }
521         in_handle.length = (size_t)get_len;
522
523         printerr(3, "in_handle:\n");
524         print_hexl(3, in_handle.value, in_handle.length);
525
526         get_len = qword_get(&cp, in_tok.value, sizeof(in_tok_buf));
527         if (get_len < 0) {
528                 printerr(0, "WARNING: handle_nullreq: "
529                             "failed parsing request\n");
530                 goto out_err;
531         }
532         in_tok.length = (size_t)get_len;
533
534         printerr(3, "in_tok:\n");
535         print_hexl(3, in_tok.value, in_tok.length);
536
537         if (in_handle.length != 0) { /* CONTINUE_INIT case */
538                 if (in_handle.length != sizeof(ctx)) {
539                         printerr(0, "WARNING: handle_nullreq: "
540                                     "input handle has unexpected length %d\n",
541                                     in_handle.length);
542                         goto out_err;
543                 }
544                 /* in_handle is the context id stored in the out_handle
545                  * for the GSS_S_CONTINUE_NEEDED case below.  */
546                 memcpy(&ctx, in_handle.value, in_handle.length);
547         }
548
549         svc_cred = gssd_select_svc_cred(lustre_svc);
550         if (!svc_cred) {
551                 printerr(0, "no service credential for svc %u\n", lustre_svc);
552                 goto out_err;
553         }
554
555         maj_stat = gss_accept_sec_context(&min_stat, &ctx, svc_cred,
556                         &in_tok, GSS_C_NO_CHANNEL_BINDINGS, &client_name,
557                         &mech, &out_tok, &ret_flags, NULL, NULL);
558
559         if (maj_stat == GSS_S_CONTINUE_NEEDED) {
560                 printerr(1, "gss_accept_sec_context GSS_S_CONTINUE_NEEDED\n");
561
562                 /* Save the context handle for future calls */
563                 out_handle.length = sizeof(ctx);
564                 memcpy(out_handle.value, &ctx, sizeof(ctx));
565                 goto continue_needed;
566         }
567         else if (maj_stat != GSS_S_COMPLETE) {
568                 printerr(0, "WARNING: gss_accept_sec_context failed\n");
569                 pgsserr("handle_nullreq: gss_accept_sec_context",
570                         maj_stat, min_stat, mech);
571                 goto out_err;
572         }
573
574         if (get_ids(client_name, mech, &cred, nid, lustre_svc)) {
575                 /* get_ids() prints error msg */
576                 maj_stat = GSS_S_BAD_NAME; /* XXX ? */
577                 gss_release_name(&ignore_min_stat, &client_name);
578                 goto out_err;
579         }
580         gss_release_name(&ignore_min_stat, &client_name);
581
582         /* Context complete. Pass handle_seq in out_handle to use
583          * for context lookup in the kernel. */
584         out_handle.length = sizeof(handle_seq);
585         memcpy(out_handle.value, &handle_seq, sizeof(handle_seq));
586
587         /* kernel needs ctx to calculate verifier on null response, so
588          * must give it context before doing null call: */
589         if (serialize_context_for_kernel(ctx, &ctx_token, mech)) {
590                 printerr(0, "WARNING: handle_nullreq: "
591                             "serialize_context_for_kernel failed\n");
592                 maj_stat = GSS_S_FAILURE;
593                 goto out_err;
594         }
595         /* We no longer need the gss context */
596         gss_delete_sec_context(&ignore_min_stat, &ctx, &ignore_out_tok);
597
598         do_svc_downcall(&out_handle, &cred, mech, &ctx_token);
599 continue_needed:
600         send_response(f, &in_handle, &in_tok, maj_stat, min_stat,
601                         &out_handle, &out_tok);
602 out:
603         if (ctx_token.value != NULL)
604                 free(ctx_token.value);
605         if (out_tok.value != NULL)
606                 gss_release_buffer(&ignore_min_stat, &out_tok);
607         return 0;
608
609 out_err:
610         if (ctx != GSS_C_NO_CONTEXT)
611                 gss_delete_sec_context(&ignore_min_stat, &ctx, &ignore_out_tok);
612         send_response(f, &in_handle, &in_tok, maj_stat, min_stat,
613                         &null_token, &null_token);
614         goto out;
615 }