Whamcloud - gitweb
a0ffa52d744c05ed8ebf09d9076e0655068dfe33
[fs/lustre-release.git] / lustre / utils / gss / svcgssd_proc.c
1 /*
2   svc_in_gssd_proc.c
3
4   Copyright (c) 2000 The Regents of the University of Michigan.
5   All rights reserved.
6
7   Copyright (c) 2002 Bruce Fields <bfields@UMICH.EDU>
8
9   Redistribution and use in source and binary forms, with or without
10   modification, are permitted provided that the following conditions
11   are met:
12
13   1. Redistributions of source code must retain the above copyright
14      notice, this list of conditions and the following disclaimer.
15   2. Redistributions in binary form must reproduce the above copyright
16      notice, this list of conditions and the following disclaimer in the
17      documentation and/or other materials provided with the distribution.
18   3. Neither the name of the University nor the names of its
19      contributors may be used to endorse or promote products derived
20      from this software without specific prior written permission.
21
22   THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
23   WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
24   MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25   DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
26   FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
27   CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
28   SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
29   BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
31   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34 */
35
36 #include <sys/param.h>
37 #include <sys/stat.h>
38
39 #include <pwd.h>
40 #include <stdio.h>
41 #include <unistd.h>
42 #include <ctype.h>
43 #include <string.h>
44 #include <fcntl.h>
45 #include <errno.h>
46 #include <netdb.h>
47
48 #include "svcgssd.h"
49 #include "gss_util.h"
50 #include "err_util.h"
51 #include "context.h"
52 #include "cacheio.h"
53 #include "lsupport.h"
54
55 extern char * mech2file(gss_OID mech);
56 #define SVCGSSD_CONTEXT_CHANNEL "/proc/net/rpc/auth.sptlrpc.context/channel"
57 #define SVCGSSD_INIT_CHANNEL    "/proc/net/rpc/auth.sptlrpc.init/channel"
58
59 #define TOKEN_BUF_SIZE          8192
60
/*
 * Credentials derived from the client principal by get_ids() and sent to
 * the kernel, field by field, by do_svc_downcall().
 */
struct svc_cred {
        uint32_t cr_remote;     /* MDS: principal's realm is not mds_local_realm */
        uint32_t cr_usr_root;   /* principal is LUSTRE_ROOT_NAME */
        uint32_t cr_usr_mds;    /* principal is the MDS service */
        uint32_t cr_usr_oss;    /* principal is the OSS service (MGS only) */
        uid_t    cr_uid;        /* resolved local uid, or -1 if none */
        uid_t    cr_mapped_uid; /* uid from lookup_mapping() (MDS), else -1 */
        gid_t    cr_gid;        /* group id; get_ids() currently leaves it -1 */
};
70
71 static int
72 do_svc_downcall(gss_buffer_desc *out_handle, struct svc_cred *cred,
73                 gss_OID mech, gss_buffer_desc *context_token)
74 {
75         FILE *f;
76         char *fname = NULL;
77         int err;
78
79         printerr(2, "doing downcall\n");
80         if ((fname = mech2file(mech)) == NULL)
81                 goto out_err;
82         f = fopen(SVCGSSD_CONTEXT_CHANNEL, "w");
83         if (f == NULL) {
84                 printerr(0, "WARNING: unable to open downcall channel "
85                              "%s: %s\n",
86                              SVCGSSD_CONTEXT_CHANNEL, strerror(errno));
87                 goto out_err;
88         }
89         qword_printhex(f, out_handle->value, out_handle->length);
90         /* XXX are types OK for the rest of this? */
91         qword_printint(f, 0x7fffffff); /*XXX need a better timeout */
92         qword_printint(f, cred->cr_remote);
93         qword_printint(f, cred->cr_usr_root);
94         qword_printint(f, cred->cr_usr_mds);
95         qword_printint(f, cred->cr_usr_oss);
96         qword_printint(f, cred->cr_mapped_uid);
97         qword_printint(f, cred->cr_uid);
98         qword_printint(f, cred->cr_gid);
99         qword_print(f, fname);
100         qword_printhex(f, context_token->value, context_token->length);
101         err = qword_eol(f);
102         fclose(f);
103         return err;
104 out_err:
105         printerr(0, "WARNING: downcall failed\n");
106         return -1;
107 }
108
109 struct gss_verifier {
110         u_int32_t       flav;
111         gss_buffer_desc body;
112 };
113
114 #define RPCSEC_GSS_SEQ_WIN      5
115
116 static int
117 send_response(FILE *f, gss_buffer_desc *in_handle, gss_buffer_desc *in_token,
118               u_int32_t maj_stat, u_int32_t min_stat,
119               gss_buffer_desc *out_handle, gss_buffer_desc *out_token)
120 {
121         char buf[2 * TOKEN_BUF_SIZE];
122         char *bp = buf;
123         int blen = sizeof(buf);
124         /* XXXARG: */
125         int g;
126
127         printerr(2, "sending null reply\n");
128
129         qword_addhex(&bp, &blen, in_handle->value, in_handle->length);
130         qword_addhex(&bp, &blen, in_token->value, in_token->length);
131         qword_addint(&bp, &blen, 0x7fffffff); /*XXX need a better timeout */
132         qword_adduint(&bp, &blen, maj_stat);
133         qword_adduint(&bp, &blen, min_stat);
134         qword_addhex(&bp, &blen, out_handle->value, out_handle->length);
135         qword_addhex(&bp, &blen, out_token->value, out_token->length);
136         qword_addeol(&bp, &blen);
137         if (blen <= 0) {
138                 printerr(0, "WARNING: send_respsonse: message too long\n");
139                 return -1;
140         }
141         g = open(SVCGSSD_INIT_CHANNEL, O_WRONLY);
142         if (g == -1) {
143                 printerr(0, "WARNING: open %s failed: %s\n",
144                                 SVCGSSD_INIT_CHANNEL, strerror(errno));
145                 return -1;
146         }
147         *bp = '\0';
148         printerr(3, "writing message: %s", buf);
149         if (write(g, buf, bp - buf) == -1) {
150                 printerr(0, "WARNING: failed to write message\n");
151                 close(g);
152                 return -1;
153         }
154         close(g);
155         return 0;
156 }
157
/*
 * Authentication status codes.
 *
 * These values appear to mirror the ONC RPC auth_stat values (0-5) and the
 * RPCSEC_GSS credential/context problem codes (13, 14).  None of them is
 * referenced elsewhere in this file.
 */
#define rpc_auth_ok                     0
#define rpc_autherr_badcred             1
#define rpc_autherr_rejectedcred        2
#define rpc_autherr_badverf             3
#define rpc_autherr_rejectedverf        4
#define rpc_autherr_tooweak             5

#define rpcsec_gsserr_credproblem       13
#define rpcsec_gsserr_ctxproblem        14
166
/*
 * Disabled (#if 0): fills supplementary groups for @cred through
 * nfs4_gss_princ_to_grouplist().
 * NOTE(review): this references cred->cr_ngroups and cred->cr_groups,
 * which struct svc_cred in this file does not declare, so it cannot be
 * re-enabled as-is.  The realloc() result is also not checked.
 */
#if 0
static void
add_supplementary_groups(char *secname, char *name, struct svc_cred *cred)
{
        int ret;
        static gid_t *groups = NULL;

        cred->cr_ngroups = NGROUPS;
        ret = nfs4_gss_princ_to_grouplist(secname, name,
                        cred->cr_groups, &cred->cr_ngroups);
        if (ret < 0) {
                /* first call failed: retry into a buffer sized by the
                 * (possibly updated) cr_ngroups, then clamp to NGROUPS */
                groups = realloc(groups, cred->cr_ngroups*sizeof(gid_t));
                ret = nfs4_gss_princ_to_grouplist(secname, name,
                                groups, &cred->cr_ngroups);
                if (ret < 0)
                        cred->cr_ngroups = 0;
                else {
                        if (cred->cr_ngroups > NGROUPS)
                                cred->cr_ngroups = NGROUPS;
                        memcpy(cred->cr_groups, groups,
                                        cred->cr_ngroups*sizeof(gid_t));
                }
        }
}
#endif
192
/*
 * Disabled (#if 0): older name-mapping get_ids() based on
 * nfs4_gss_princ_to_ids().
 * NOTE(review): it has the same name as, but a different signature from,
 * the active get_ids() later in this file.  It also uses
 * cred->cr_ngroups, which struct svc_cred here does not have, so enabling
 * it would not compile.
 */
#if 0
static int
get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred)
{
        u_int32_t       maj_stat, min_stat;
        gss_buffer_desc name;
        char            *sname;
        int             res = -1;
        uid_t           uid, gid;
        gss_OID         name_type = GSS_C_NO_OID;
        char            *secname;

        maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
        if (maj_stat != GSS_S_COMPLETE) {
                pgsserr("get_ids: gss_display_name",
                        maj_stat, min_stat, mech);
                goto out;
        }
        if (name.length >= 0xffff || /* be certain name.length+1 doesn't overflow */
            !(sname = calloc(name.length + 1, 1))) {
                printerr(0, "WARNING: get_ids: error allocating %d bytes "
                        "for sname\n", name.length + 1);
                gss_release_buffer(&min_stat, &name);
                goto out;
        }
        memcpy(sname, name.value, name.length);
        printerr(1, "sname = %s\n", sname);
        gss_release_buffer(&min_stat, &name);

        res = -EINVAL;
        if ((secname = mech2file(mech)) == NULL) {
                printerr(0, "WARNING: get_ids: error mapping mech to "
                        "file for name '%s'\n", sname);
                goto out_free;
        }
        nfs4_init_name_mapping(NULL); /* XXX: should only do this once */
        res = nfs4_gss_princ_to_ids(secname, sname, &uid, &gid);
        if (res < 0) {
                /*
                 * -ENOENT means there was no mapping, any other error
                 * value means there was an error trying to do the
                 * mapping.
                 * If there was no mapping, we send down the value -1
                 * to indicate that the anonuid/anongid for the export
                 * should be used.
                 */
                if (res == -ENOENT) {
                        cred->cr_uid = -1;
                        cred->cr_gid = -1;
                        cred->cr_ngroups = 0;
                        res = 0;
                        goto out_free;
                }
                printerr(0, "WARNING: get_ids: failed to map name '%s' "
                        "to uid/gid: %s\n", sname, strerror(-res));
                goto out_free;
        }
        cred->cr_uid = uid;
        cred->cr_gid = gid;
        add_supplementary_groups(secname, sname, cred);
        res = 0;
out_free:
        free(sname);
out:
        return res;
}
#endif
260
/*
 * Disabled (#if 0): hex + printable-ASCII dump of @length bytes at @cp,
 * 16 bytes per row, emitted through printerr() at priority @pri.
 * NOTE(review): handle_nullreq() in this file still calls print_hexl(),
 * so the definition actually linked presumably comes from another file.
 * Confirm before re-enabling, to avoid a duplicate symbol.
 */
#if 0
void
print_hexl(int pri, unsigned char *cp, int length)
{
        int i, j, jm;
        unsigned char c;

        printerr(pri, "length %d\n",length);
        printerr(pri, "\n");

        for (i = 0; i < length; i += 0x10) {
                printerr(pri, "  %04x: ", (u_int)i);
                /* bytes in this row: 16, or fewer on the last row */
                jm = length - i;
                jm = jm > 16 ? 16 : jm;

                /* hex column, a space after every byte pair */
                for (j = 0; j < jm; j++) {
                        if ((j % 2) == 1)
                                printerr(pri,"%02x ", (u_int)cp[i+j]);
                        else
                                printerr(pri,"%02x", (u_int)cp[i+j]);
                }
                /* pad a short last row so the ASCII column lines up */
                for (; j < 16; j++) {
                        if ((j % 2) == 1)
                                printerr(pri,"   ");
                        else
                                printerr(pri,"  ");
                }
                printerr(pri," ");

                /* ASCII column, '.' for non-printable bytes */
                for (j = 0; j < jm; j++) {
                        c = cp[i+j];
                        c = isprint(c) ? c : '.';
                        printerr(pri,"%c", c);
                }
                printerr(pri,"\n");
        }
}
#endif
299
300 static int
301 get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred,
302         lnet_nid_t nid, uint32_t lustre_svc)
303 {
304         u_int32_t       maj_stat, min_stat;
305         gss_buffer_desc name;
306         char            *sname, *host, *realm;
307         const int       namebuf_size = 512;
308         char            namebuf[namebuf_size];
309         int             res = -1;
310         gss_OID         name_type = GSS_C_NO_OID;
311         struct passwd   *pw;
312
313         cred->cr_remote = 0;
314         cred->cr_usr_root = cred->cr_usr_mds = cred->cr_usr_oss = 0;
315         cred->cr_uid = cred->cr_mapped_uid = cred->cr_gid = -1;
316
317         maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
318         if (maj_stat != GSS_S_COMPLETE) {
319                 pgsserr("get_ids: gss_display_name",
320                         maj_stat, min_stat, mech);
321                 return -1;
322         }
323         if (name.length >= 0xffff || /* be certain name.length+1 doesn't overflow */
324             !(sname = calloc(name.length + 1, 1))) {
325                 printerr(0, "WARNING: get_ids: error allocating %d bytes "
326                         "for sname\n", name.length + 1);
327                 gss_release_buffer(&min_stat, &name);
328                 return -1;
329         }
330         memcpy(sname, name.value, name.length);
331         sname[name.length] = '\0';
332         gss_release_buffer(&min_stat, &name);
333
334         if (lustre_svc == LUSTRE_GSS_SVC_MDS)
335                 lookup_mapping(sname, nid, &cred->cr_mapped_uid);
336         else
337                 cred->cr_mapped_uid = -1;
338
339         realm = strchr(sname, '@');
340         if (realm) {
341                 *realm++ = '\0';
342         } else {
343                 printerr(0, "ERROR: %s has no realm name\n", sname);
344                 goto out_free;
345         }
346
347         host = strchr(sname, '/');
348         if (host)
349                 *host++ = '\0';
350
351         if (strcmp(sname, GSSD_SERVICE_MGS) == 0) {
352                 printerr(0, "forbid %s as a user name\n", sname);
353                 goto out_free;
354         }
355
356         /* 1. check host part */
357         if (host) {
358                 if (lnet_nid2hostname(nid, namebuf, namebuf_size)) {
359                         printerr(0, "ERROR: failed to resolve hostname for "
360                                  "%s/%s@%s from %016llx\n",
361                                  sname, host, realm, nid);
362                         goto out_free;
363                 }
364
365                 if (strcasecmp(host, namebuf)) {
366                         printerr(0, "ERROR: %s/%s@%s claimed hostname doesn't "
367                                  "match %s, nid %016llx\n", sname, host, realm,
368                                  namebuf, nid);
369                         goto out_free;
370                 }
371         } else {
372                 if (!strcmp(sname, GSSD_SERVICE_MDS) ||
373                     !strcmp(sname, GSSD_SERVICE_OSS)) {
374                         printerr(0, "ERROR: %s@%s from %016llx doesn't "
375                                  "bind with hostname\n", sname, realm, nid);
376                         goto out_free;
377                 }
378         }
379
380         /* 2. check realm and user */
381         switch (lustre_svc) {
382         case LUSTRE_GSS_SVC_MDS:
383                 if (strcasecmp(mds_local_realm, realm)) {
384                         cred->cr_remote = 1;
385
386                         /* only allow mapped user from remote realm */
387                         if (cred->cr_mapped_uid == -1) {
388                                 printerr(0, "ERROR: %s%s%s@%s from %016llx "
389                                          "is remote but without mapping\n",
390                                          sname, host ? "/" : "",
391                                          host ? host : "", realm, nid);
392                                 break;
393                         }
394                 } else {
395                         if (!strcmp(sname, LUSTRE_ROOT_NAME)) {
396                                 cred->cr_uid = 0;
397                                 cred->cr_usr_root = 1;
398                         } else if (!strcmp(sname, GSSD_SERVICE_MDS)) {
399                                 cred->cr_uid = 0;
400                                 cred->cr_usr_mds = 1;
401                         } else if (!strcmp(sname, GSSD_SERVICE_OSS)) {
402                                 printerr(0, "ERROR: MDS doesn't accept "
403                                          "user "GSSD_SERVICE_OSS"\n");
404                                 break;
405                         } else {
406                                 pw = getpwnam(sname);
407                                 if (pw != NULL) {
408                                         cred->cr_uid = pw->pw_uid;
409                                         printerr(2, "%s resolve to uid %u\n",
410                                                  sname, cred->cr_uid);
411                                 } else if (cred->cr_mapped_uid != -1) {
412                                         printerr(2, "user %s from %016llx is "
413                                                  "mapped to %u\n", sname, nid,
414                                                  cred->cr_mapped_uid);
415                                 } else {
416                                         printerr(0, "ERROR: invalid user, "
417                                                  "%s/%s@%s from %016llx\n",
418                                                  sname, host, realm, nid);
419                                         break;
420                                 }
421                         }
422                 }
423
424                 res = 0;
425                 break;
426         case LUSTRE_GSS_SVC_MGS:
427                 if (!strcmp(sname, GSSD_SERVICE_OSS)) {
428                         cred->cr_uid = 0;
429                         cred->cr_usr_oss = 1;
430                 }
431                 /* fall through */
432         case LUSTRE_GSS_SVC_OSS:
433                 if (!strcmp(sname, LUSTRE_ROOT_NAME)) {
434                         cred->cr_uid = 0;
435                         cred->cr_usr_root = 1;
436                 } else if (!strcmp(sname, GSSD_SERVICE_MDS)) {
437                         cred->cr_uid = 0;
438                         cred->cr_usr_mds = 1;
439                 } else {
440                         printerr(0, "ERROR: svc %d doesn't accept user %s"
441                                  "from %016llx\n", lustre_svc, sname, nid);
442                         break;
443                 }
444                 res = 0;
445                 break;
446         default:
447                 assert(0);
448         }
449
450 out_free:
451         if (!res)
452                 printerr(1, "%s: authenticated %s%s%s@%s from %016llx\n",
453                          lustre_svc_name[lustre_svc], sname,
454                          host ? "/" : "", host ? host : "", realm, nid);
455         free(sname);
456         return res;
457 }
458
459 typedef struct gss_union_ctx_id_t {
460         gss_OID         mech_type;
461         gss_ctx_id_t    internal_ctx_id;
462 } gss_union_ctx_id_desc, *gss_union_ctx_id_t;
463
464 /*
465  * return -1 only if we detect error during reading from upcall channel,
466  * all other cases return 0.
467  */
468 int
469 handle_nullreq(FILE *f) {
470         uint64_t                handle_seq;
471         char                    in_tok_buf[TOKEN_BUF_SIZE];
472         char                    in_handle_buf[15];
473         char                    out_handle_buf[15];
474         gss_buffer_desc         in_tok = {.value = in_tok_buf},
475                                 out_tok = {.value = NULL},
476                                 in_handle = {.value = in_handle_buf},
477                                 out_handle = {.value = out_handle_buf},
478                                 ctx_token = {.value = NULL},
479                                 ignore_out_tok = {.value = NULL},
480         /* XXX isn't there a define for this?: */
481                                 null_token = {.value = NULL};
482         uint32_t                lustre_svc;
483         lnet_nid_t              nid;
484         u_int32_t               ret_flags;
485         gss_ctx_id_t            ctx = GSS_C_NO_CONTEXT;
486         gss_name_t              client_name;
487         gss_OID                 mech = GSS_C_NO_OID;
488         gss_cred_id_t           svc_cred;
489         u_int32_t               maj_stat = GSS_S_FAILURE, min_stat = 0;
490         u_int32_t               ignore_min_stat;
491         int                     get_len;
492         struct svc_cred         cred;
493         static char             *lbuf = NULL;
494         static int              lbuflen = 0;
495         static char             *cp;
496
497         printerr(2, "handling null request\n");
498
499         if (readline(fileno(f), &lbuf, &lbuflen) != 1) {
500                 printerr(0, "WARNING: handle_nullreq: "
501                             "failed reading request\n");
502                 return -1;
503         }
504
505         cp = lbuf;
506
507         qword_get(&cp, (char *) &lustre_svc, sizeof(lustre_svc));
508         qword_get(&cp, (char *) &nid, sizeof(nid));
509         qword_get(&cp, (char *) &handle_seq, sizeof(handle_seq));
510         printerr(2, "handling req: svc %u, nid %016llx, idx %llx\n",
511                  lustre_svc, nid, handle_seq);
512
513         get_len = qword_get(&cp, in_handle.value, sizeof(in_handle_buf));
514         if (get_len < 0) {
515                 printerr(0, "WARNING: handle_nullreq: "
516                             "failed parsing request\n");
517                 goto out_err;
518         }
519         in_handle.length = (size_t)get_len;
520
521         printerr(3, "in_handle:\n");
522         print_hexl(3, in_handle.value, in_handle.length);
523
524         get_len = qword_get(&cp, in_tok.value, sizeof(in_tok_buf));
525         if (get_len < 0) {
526                 printerr(0, "WARNING: handle_nullreq: "
527                             "failed parsing request\n");
528                 goto out_err;
529         }
530         in_tok.length = (size_t)get_len;
531
532         printerr(3, "in_tok:\n");
533         print_hexl(3, in_tok.value, in_tok.length);
534
535         if (in_handle.length != 0) { /* CONTINUE_INIT case */
536                 if (in_handle.length != sizeof(ctx)) {
537                         printerr(0, "WARNING: handle_nullreq: "
538                                     "input handle has unexpected length %d\n",
539                                     in_handle.length);
540                         goto out_err;
541                 }
542                 /* in_handle is the context id stored in the out_handle
543                  * for the GSS_S_CONTINUE_NEEDED case below.  */
544                 memcpy(&ctx, in_handle.value, in_handle.length);
545         }
546
547         svc_cred = gssd_select_svc_cred(lustre_svc);
548         if (!svc_cred) {
549                 printerr(0, "no service credential for svc %u\n", lustre_svc);
550                 goto out_err;
551         }
552
553         maj_stat = gss_accept_sec_context(&min_stat, &ctx, svc_cred,
554                         &in_tok, GSS_C_NO_CHANNEL_BINDINGS, &client_name,
555                         &mech, &out_tok, &ret_flags, NULL, NULL);
556
557         if (maj_stat == GSS_S_CONTINUE_NEEDED) {
558                 printerr(1, "gss_accept_sec_context GSS_S_CONTINUE_NEEDED\n");
559
560                 /* Save the context handle for future calls */
561                 out_handle.length = sizeof(ctx);
562                 memcpy(out_handle.value, &ctx, sizeof(ctx));
563                 goto continue_needed;
564         }
565         else if (maj_stat != GSS_S_COMPLETE) {
566                 printerr(0, "WARNING: gss_accept_sec_context failed\n");
567                 pgsserr("handle_nullreq: gss_accept_sec_context",
568                         maj_stat, min_stat, mech);
569                 goto out_err;
570         }
571
572         if (get_ids(client_name, mech, &cred, nid, lustre_svc)) {
573                 /* get_ids() prints error msg */
574                 maj_stat = GSS_S_BAD_NAME; /* XXX ? */
575                 gss_release_name(&ignore_min_stat, &client_name);
576                 goto out_err;
577         }
578         gss_release_name(&ignore_min_stat, &client_name);
579
580         /* Context complete. Pass handle_seq in out_handle to use
581          * for context lookup in the kernel. */
582         out_handle.length = sizeof(handle_seq);
583         memcpy(out_handle.value, &handle_seq, sizeof(handle_seq));
584
585         /* kernel needs ctx to calculate verifier on null response, so
586          * must give it context before doing null call: */
587         if (serialize_context_for_kernel(ctx, &ctx_token, mech)) {
588                 printerr(0, "WARNING: handle_nullreq: "
589                             "serialize_context_for_kernel failed\n");
590                 maj_stat = GSS_S_FAILURE;
591                 goto out_err;
592         }
593         /* We no longer need the gss context */
594         gss_delete_sec_context(&ignore_min_stat, &ctx, &ignore_out_tok);
595
596         do_svc_downcall(&out_handle, &cred, mech, &ctx_token);
597 continue_needed:
598         send_response(f, &in_handle, &in_tok, maj_stat, min_stat,
599                         &out_handle, &out_tok);
600 out:
601         if (ctx_token.value != NULL)
602                 free(ctx_token.value);
603         if (out_tok.value != NULL)
604                 gss_release_buffer(&ignore_min_stat, &out_tok);
605         return 0;
606
607 out_err:
608         if (ctx != GSS_C_NO_CONTEXT)
609                 gss_delete_sec_context(&ignore_min_stat, &ctx, &ignore_out_tok);
610         send_response(f, &in_handle, &in_tok, maj_stat, min_stat,
611                         &null_token, &null_token);
612         goto out;
613 }