Whamcloud - gitweb
LU-6490 gss: 3.1x kernels adjustments for gssapi code
[fs/lustre-release.git] / lustre / utils / gss / gssd_proc.c
1 /*
2   gssd_proc.c
3
4   Copyright (c) 2000-2004 The Regents of the University of Michigan.
5   All rights reserved.
6
7   Copyright (c) 2000 Dug Song <dugsong@UMICH.EDU>.
8   Copyright (c) 2001 Andy Adamson <andros@UMICH.EDU>.
9   Copyright (c) 2002 Marius Aamodt Eriksen <marius@UMICH.EDU>.
10   Copyright (c) 2002 Bruce Fields <bfields@UMICH.EDU>
11   Copyright (c) 2004 Kevin Coffman <kwc@umich.edu>
12   All rights reserved, all wrongs reversed.
13
14   Redistribution and use in source and binary forms, with or without
15   modification, are permitted provided that the following conditions
16   are met:
17
18   1. Redistributions of source code must retain the above copyright
19      notice, this list of conditions and the following disclaimer.
20   2. Redistributions in binary form must reproduce the above copyright
21      notice, this list of conditions and the following disclaimer in the
22      documentation and/or other materials provided with the distribution.
23   3. Neither the name of the University nor the names of its
24      contributors may be used to endorse or promote products derived
25      from this software without specific prior written permission.
26
27   THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
28   WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
29   MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30   DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
31   FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
32   CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
33   SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
34   BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
35   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
36   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
37   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
38
39 */
40
41 #ifndef _GNU_SOURCE
42 #define _GNU_SOURCE
43 #endif
44 #include "config.h"
45 #include <sys/param.h>
46 #include <sys/stat.h>
47 #include <sys/socket.h>
48 #include <arpa/inet.h>
49 #include <sys/fsuid.h>
50
51 #include <stdio.h>
52 #include <stdlib.h>
53 #include <pwd.h>
54 #include <grp.h>
55 #include <string.h>
56 #include <dirent.h>
57 #include <poll.h>
58 #include <fcntl.h>
59 #include <signal.h>
60 #include <unistd.h>
61 #include <errno.h>
62 #include <gssapi/gssapi.h>
63 #ifdef HAVE_NETDB_H
64 # include <netdb.h>
65 #endif
66
67 #include "gssd.h"
68 #include "err_util.h"
69 #include "gss_util.h"
70 #include "gss_oids.h"
71 #include "krb5_util.h"
72 #include "context.h"
73 #include "lsupport.h"
74
75 /*
76  * pollarray:
77  *      array of struct pollfd suitable to pass to poll. initialized to
78  *      zero - a zero struct is ignored by poll() because the events mask is 0.
79  *
80  * clnt_list:
81  *      linked list of struct clnt_info which associates a clntXXX directory
82  *      with an index into pollarray[], and other basic data about that client.
83  *
84  * Directory structure: created by the kernel nfs client
85  *      {pipefs_nfsdir}/clntXX             : one per rpc_clnt struct in the kernel
86  *      {pipefs_nfsdir}/clntXX/krb5        : read uid for which kernel wants
87  *                                          a context, write the resulting context
88  *      {pipefs_nfsdir}/clntXX/info        : stores info such as server name
89  *
90  * Algorithm:
91  *      Poll all {pipefs_nfsdir}/clntXX/krb5 files.  When ready, data read
92  *      is a uid; performs rpcsec_gss context initialization protocol to
93  *      get a cred for that user.  Writes result to corresponding krb5 file
94  *      in a form the kernel code will understand.
95  *      In addition, we make sure we are notified whenever anything is
96  *      created or destroyed in {pipefs_nfsdir} or in any of the clntXX
97  *      directories, and rescan the whole {pipefs_nfsdir} when this happens.
98  */
99
/* Array of pollfd slots; allocated in init_client_list(). A slot whose
 * events mask is 0 is treated as free by get_poll_index(). */
struct pollfd * pollarray;

int pollsize;  /* the size of pollarray (in pollfd's) */
103
104 static void
105 destroy_client(struct clnt_info *clp)
106 {
107         printerr(3, "clp %p: dirname %s, krb5fd %d\n", clp, clp->dirname, clp->krb5_fd);
108         if (clp->krb5_poll_index != -1)
109                 memset(&pollarray[clp->krb5_poll_index], 0,
110                                         sizeof(struct pollfd));
111         if (clp->spkm3_poll_index != -1)
112                 memset(&pollarray[clp->spkm3_poll_index], 0,
113                                         sizeof(struct pollfd));
114         if (clp->dir_fd != -1) close(clp->dir_fd);
115         if (clp->krb5_fd != -1) close(clp->krb5_fd);
116         if (clp->spkm3_fd != -1) close(clp->spkm3_fd);
117         if (clp->dirname) free(clp->dirname);
118         if (clp->servicename) free(clp->servicename);
119         free(clp);
120 }
121
122 static struct clnt_info *
123 insert_new_clnt(void)
124 {
125         struct clnt_info        *clp = NULL;
126
127         if (!(clp = (struct clnt_info *)calloc(1,sizeof(struct clnt_info)))) {
128                 printerr(0, "ERROR: can't malloc clnt_info: %s\n",
129                          strerror(errno));
130                 goto out;
131         }
132         clp->krb5_poll_index = -1;
133         clp->spkm3_poll_index = -1;
134         clp->krb5_fd = -1;
135         clp->spkm3_fd = -1;
136         clp->dir_fd = -1;
137
138         TAILQ_INSERT_HEAD(&clnt_list, clp, list);
139 out:
140         return clp;
141 }
142
143 static int
144 process_clnt_dir_files(struct clnt_info * clp)
145 {
146         char    kname[32];
147         char    sname[32];
148
149         if (clp->krb5_fd == -1) {
150                 snprintf(kname, sizeof(kname), "%s/krb5", clp->dirname);
151                 clp->krb5_fd = open(kname, O_RDWR);
152         }
153         if (clp->spkm3_fd == -1) {
154                 snprintf(sname, sizeof(sname), "%s/spkm3", clp->dirname);
155                 clp->spkm3_fd = open(sname, O_RDWR);
156         }
157         if((clp->krb5_fd == -1) && (clp->spkm3_fd == -1))
158                 return -1;
159         return 0;
160 }
161
162 static int
163 get_poll_index(int *ind)
164 {
165         int i;
166
167         *ind = -1;
168         for (i=0; i<FD_ALLOC_BLOCK; i++) {
169                 if (pollarray[i].events == 0) {
170                         *ind = i;
171                         break;
172                 }
173         }
174         if (*ind == -1) {
175                 printerr(0, "ERROR: No pollarray slots open\n");
176                 return -1;
177         }
178         return 0;
179 }
180
181
182 static int
183 insert_clnt_poll(struct clnt_info *clp)
184 {
185         if ((clp->krb5_fd != -1) && (clp->krb5_poll_index == -1)) {
186                 if (get_poll_index(&clp->krb5_poll_index)) {
187                         printerr(0, "ERROR: Too many krb5 clients\n");
188                         return -1;
189                 }
190                 pollarray[clp->krb5_poll_index].fd = clp->krb5_fd;
191                 pollarray[clp->krb5_poll_index].events |= POLLIN;
192                 printerr(2, "monitoring krb5 channel under %s\n",
193                          clp->dirname);
194         }
195
196         if ((clp->spkm3_fd != -1) && (clp->spkm3_poll_index == -1)) {
197                 if (get_poll_index(&clp->spkm3_poll_index)) {
198                         printerr(0, "ERROR: Too many spkm3 clients\n");
199                         return -1;
200                 }
201                 pollarray[clp->spkm3_poll_index].fd = clp->spkm3_fd;
202                 pollarray[clp->spkm3_poll_index].events |= POLLIN;
203         }
204
205         return 0;
206 }
207
208 static void
209 process_clnt_dir(char *dir)
210 {
211         struct clnt_info *      clp;
212
213         if (!(clp = insert_new_clnt()))
214                 goto fail_destroy_client;
215
216         if (!(clp->dirname = calloc(strlen(dir) + 1, 1))) {
217                 goto fail_destroy_client;
218         }
219         memcpy(clp->dirname, dir, strlen(dir));
220         if ((clp->dir_fd = open(clp->dirname, O_RDONLY)) == -1) {
221                 printerr(0, "ERROR: can't open %s: %s\n",
222                          clp->dirname, strerror(errno));
223                 goto fail_destroy_client;
224         }
225         fcntl(clp->dir_fd, F_SETSIG, DNOTIFY_SIGNAL);
226         fcntl(clp->dir_fd, F_NOTIFY, DN_CREATE | DN_DELETE | DN_MULTISHOT);
227
228         if (process_clnt_dir_files(clp))
229                 goto fail_keep_client;
230
231         if (insert_clnt_poll(clp))
232                 goto fail_destroy_client;
233
234         return;
235
236 fail_destroy_client:
237         if (clp) {
238                 TAILQ_REMOVE(&clnt_list, clp, list);
239                 destroy_client(clp);
240         }
241 fail_keep_client:
242         /* We couldn't find some subdirectories, but we keep the client
243          * around in case we get a notification on the directory when the
244          * subdirectories are created. */
245         return;
246 }
247
248 void
249 init_client_list(void)
250 {
251         TAILQ_INIT(&clnt_list);
252         /* Eventually plan to grow/shrink poll array: */
253         pollsize = FD_ALLOC_BLOCK;
254         pollarray = calloc(pollsize, sizeof(struct pollfd));
255 }
256
257 /*
258  * This is run after a DNOTIFY signal, and should clear up any
259  * directories that are no longer around, and re-scan any existing
260  * directories, since the DNOTIFY could have been in there.
261  */
262 static void
263 update_old_clients(struct dirent **namelist, int size)
264 {
265         struct clnt_info *clp;
266         void *saveprev;
267         int i, stillhere;
268
269         for (clp = clnt_list.tqh_first; clp != NULL; clp = clp->list.tqe_next) {
270                 stillhere = 0;
271                 for (i=0; i < size; i++) {
272                         if (!strcmp(clp->dirname, namelist[i]->d_name)) {
273                                 stillhere = 1;
274                                 break;
275                         }
276                 }
277                 if (!stillhere) {
278                         printerr(2, "destroying client %s\n", clp->dirname);
279                         saveprev = clp->list.tqe_prev;
280                         TAILQ_REMOVE(&clnt_list, clp, list);
281                         destroy_client(clp);
282                         clp = saveprev;
283                 }
284         }
285         for (clp = clnt_list.tqh_first; clp != NULL; clp = clp->list.tqe_next) {
286                 if (!process_clnt_dir_files(clp))
287                         insert_clnt_poll(clp);
288         }
289 }
290
291 /* Search for a client by directory name, return 1 if found, 0 otherwise */
292 static int
293 find_client(char *dirname)
294 {
295         struct clnt_info        *clp;
296
297         for (clp = clnt_list.tqh_first; clp != NULL; clp = clp->list.tqe_next)
298                 if (!strcmp(clp->dirname, dirname))
299                         return 1;
300         return 0;
301 }
302
303 /* Used to read (and re-read) list of clients, set up poll array. */
304 int
305 update_client_list(void)
306 {
307         char lustre_dir[PATH_MAX];
308         struct dirent lustre_dirent = { .d_name = "lustre" };
309         struct dirent *namelist[1];
310         struct stat statbuf;
311         int i, j;
312
313         if (chdir(pipefs_dir) < 0) {
314                 printerr(0, "ERROR: can't chdir to %s: %s\n",
315                          pipefs_dir, strerror(errno));
316                 return -1;
317         }
318
319         snprintf(lustre_dir, sizeof(lustre_dir), "%s/%s", pipefs_dir, "lustre");
320         if (stat(lustre_dir, &statbuf) == 0) {
321                 namelist[0] = &lustre_dirent;
322                 j = 1;
323                 printerr(2, "re-processing lustre directory\n");
324         } else {
325                 namelist[0] = NULL;
326                 j = 0;
327                 printerr(2, "lustre directory not exist\n");
328         }
329
330         update_old_clients(namelist, j);
331         for (i=0; i < j; i++) {
332                 if (i < FD_ALLOC_BLOCK && !find_client(namelist[i]->d_name))
333                         process_clnt_dir(namelist[i]->d_name);
334         }
335
336         chdir("/");
337         return 0;
338 }
339
/*
 * Context creation response.
 *
 * Filled in by do_negotiation() from the reply buffer returned for one
 * negotiation round; gr_ctx.value and gr_token.value are heap copies made
 * there.
 */
struct lustre_gss_init_res {
        gss_buffer_desc gr_ctx;         /* context handle */
        unsigned int    gr_major;       /* major status */
        unsigned int    gr_minor;       /* minor status */
        unsigned int    gr_win;         /* sequence window */
        gss_buffer_desc gr_token;       /* token */
};
348
/*
 * State of one GSS context negotiation: the request parameters set up by
 * gssd_create_lgd(), the GSS handles driven by gssd_refresh_lgd(), and the
 * error codes recorded by do_negotiation().
 */
struct lustre_gss_data {
        int             lgd_established; /* set to 1 by gssd_refresh_lgd()
                                          * once the context is complete */
        int             lgd_lustre_svc; /* mds/oss */
        int             lgd_uid;        /* uid */
        char           *lgd_uuid;       /* client device uuid */
        gss_name_t      lgd_name;       /* service name */

        gss_OID         lgd_mech;       /* mech OID */
        unsigned int    lgd_req_flags;  /* request flags */
        gss_cred_id_t   lgd_cred;       /* credential */
        gss_ctx_id_t    lgd_ctx;        /* session context */
        gss_buffer_desc lgd_rmt_ctx;    /* remote handle of context */
        uint32_t        lgd_seq_win;    /* sequence window */

        /* set by do_negotiation() when the reply carries a non-zero status */
        int             lgd_rpc_err;
        int             lgd_gss_err;
};
366
367 static int
368 do_downcall(int k5_fd, struct lgssd_upcall_data *updata,
369             struct lustre_gss_data *lgd, gss_buffer_desc *context_token)
370 {
371         char    *buf = NULL, *p = NULL, *end = NULL;
372         unsigned int timeout = 0; /* XXX decide on a reasonable value */
373         unsigned int buf_size = 0;
374
375         printerr(2, "doing downcall\n");
376         buf_size = sizeof(updata->seq) + sizeof(timeout) +
377                 sizeof(lgd->lgd_seq_win) +
378                 sizeof(lgd->lgd_rmt_ctx.length) + lgd->lgd_rmt_ctx.length +
379                 sizeof(context_token->length) + context_token->length;
380         p = buf = malloc(buf_size);
381         end = buf + buf_size;
382
383         if (WRITE_BYTES(&p, end, updata->seq)) goto out_err;
384         /* Not setting any timeout for now: */
385         if (WRITE_BYTES(&p, end, timeout)) goto out_err;
386         if (WRITE_BYTES(&p, end, lgd->lgd_seq_win)) goto out_err;
387         if (write_buffer(&p, end, &lgd->lgd_rmt_ctx)) goto out_err;
388         if (write_buffer(&p, end, context_token)) goto out_err;
389
390         lgssd_mutex_get(lgssd_mutex_downcall);
391         if (write(k5_fd, buf, p - buf) < p - buf) {
392                 lgssd_mutex_put(lgssd_mutex_downcall);
393                 goto out_err;
394         }
395         lgssd_mutex_put(lgssd_mutex_downcall);
396
397         if (buf) free(buf);
398         return 0;
399 out_err:
400         if (buf) free(buf);
401         printerr(0, "ERROR: Failed to write downcall!\n");
402         return -1;
403 }
404
405 static int
406 do_error_downcall(int k5_fd, uint32_t seq, int rpc_err, int gss_err)
407 {
408         char    buf[1024];
409         char    *p = buf, *end = buf + 1024;
410         unsigned int timeout = 0;
411         int     zero = 0;
412
413         printerr(1, "doing error downcall\n");
414
415         if (WRITE_BYTES(&p, end, seq)) goto out_err;
416         if (WRITE_BYTES(&p, end, timeout)) goto out_err;
417         /* use seq_win = 0 to indicate an error: */
418         if (WRITE_BYTES(&p, end, zero)) goto out_err;
419         if (WRITE_BYTES(&p, end, rpc_err)) goto out_err;
420         if (WRITE_BYTES(&p, end, gss_err)) goto out_err;
421
422         lgssd_mutex_get(lgssd_mutex_downcall);
423         if (write(k5_fd, buf, p - buf) < p - buf) {
424                 lgssd_mutex_put(lgssd_mutex_downcall);
425                 goto out_err;
426         }
427         lgssd_mutex_put(lgssd_mutex_downcall);
428         return 0;
429 out_err:
430         printerr(0, "Failed to write error downcall!\n");
431         return -1;
432 }
433
434 #if 0
/*
 * Create an RPC connection and establish an authenticated
 * gss context with a server.
 *
 * NOTE: this function is compiled out by the surrounding #if 0.
 *
 * clp          client entry; its protocol, servername, servicename, prog
 *              and vers fields are read here
 * clnt_return  on success receives the new RPC client handle
 * auth_return  on success receives the new AUTH handle
 * uid          user on whose behalf the context is created
 * authtype     AUTHTYPE_KRB5 or AUTHTYPE_SPKM3
 *
 * Returns 0 on success and -1 on failure.  On failure any RPC client
 * created here is destroyed; on success the caller owns both handles.
 */
int create_auth_rpc_client(struct clnt_info *clp,
                           CLIENT **clnt_return,
                           AUTH **auth_return,
                           uid_t uid,
                           int authtype)
{
        CLIENT                  *rpc_clnt = NULL;
        struct rpc_gss_sec      sec;
        AUTH                    *auth = NULL;
        uid_t                   save_uid = -1;
        int                     retval = -1;
        int                     errcode;
        OM_uint32               min_stat;
        char                    rpc_errmsg[1024];
        int                     sockp = RPC_ANYSOCK;
        int                     sendsz = 32768, recvsz = 32768;
        struct addrinfo         ai_hints, *a = NULL;
        char                    service[64];
        char                    *at_sign;

        /* Create the context as the user (not as root) */
        /* NOTE(review): setfsuid(2) is documented to return the previous
         * fsuid, not 0 on success, and save_uid is taken from geteuid()
         * rather than from the fsuid -- confirm these checks before
         * re-enabling this code. */
        save_uid = geteuid();
        if (setfsuid(uid) != 0) {
                printerr(0, "WARNING: Failed to setfsuid for "
                            "user with uid %d\n", uid);
                goto out_fail;
        }
        printerr(2, "creating context using fsuid %d (save_uid %d)\n",
                        uid, save_uid);

        /* security parameters later passed to authgss_create_default() */
        sec.qop = GSS_C_QOP_DEFAULT;
        sec.svc = RPCSEC_GSS_SVC_NONE;
        sec.cred = GSS_C_NO_CREDENTIAL;
        sec.req_flags = 0;
        if (authtype == AUTHTYPE_KRB5) {
                sec.mech = (gss_OID)&krb5oid;
                sec.req_flags = GSS_C_MUTUAL_FLAG;
        }
        else if (authtype == AUTHTYPE_SPKM3) {
                sec.mech = (gss_OID)&spkm3oid;
                /* XXX sec.req_flags = GSS_C_ANON_FLAG;
                 * Need a way to switch....
                 */
                sec.req_flags = GSS_C_MUTUAL_FLAG;
        }
        else {
                printerr(0, "ERROR: Invalid authentication type (%d) "
                        "in create_auth_rpc_client\n", authtype);
                goto out_fail;
        }


        if (authtype == AUTHTYPE_KRB5) {
#ifdef HAVE_SET_ALLOWABLE_ENCTYPES
                /*
                 * Do this before creating rpc connection since we won't need
                 * rpc connection if it fails!
                 */
                if (limit_krb5_enctypes(&sec, uid)) {
                        printerr(1, "WARNING: Failed while limiting krb5 "
                                    "encryption types for user with uid %d\n",
                                 uid);
                        goto out_fail;
                }
#endif
        }

        /* create an rpc connection to the nfs server */

        printerr(2, "creating %s client for server %s\n", clp->protocol,
                        clp->servername);

        /* resolve IPv4 only, with socket type chosen by clp->protocol */
        memset(&ai_hints, '\0', sizeof(ai_hints));
        ai_hints.ai_family = PF_INET;
        ai_hints.ai_flags |= AI_CANONNAME;
        if ((strcmp(clp->protocol, "tcp")) == 0) {
                ai_hints.ai_socktype = SOCK_STREAM;
                ai_hints.ai_protocol = IPPROTO_TCP;
        } else if ((strcmp(clp->protocol, "udp")) == 0) {
                ai_hints.ai_socktype = SOCK_DGRAM;
                ai_hints.ai_protocol = IPPROTO_UDP;
        } else {
                printerr(0, "WARNING: unrecognized protocol, '%s', requested "
                         "for connection to server %s for user with uid %d",
                         clp->protocol, clp->servername, uid);
                goto out_fail;
        }

        /* extract the service name from clp->servicename */
        if ((at_sign = strchr(clp->servicename, '@')) == NULL) {
                printerr(0, "WARNING: servicename (%s) not formatted as "
                        "expected with service@host", clp->servicename);
                goto out_fail;
        }
        if ((at_sign - clp->servicename) >= sizeof(service)) {
                printerr(0, "WARNING: service portion of servicename (%s) "
                        "is too long!", clp->servicename);
                goto out_fail;
        }
        /* length was checked above, so the copy and the terminator fit */
        strncpy(service, clp->servicename, at_sign - clp->servicename);
        service[at_sign - clp->servicename] = '\0';

        errcode = getaddrinfo(clp->servername, service, &ai_hints, &a);
        if (errcode) {
                printerr(0, "WARNING: Error from getaddrinfo for server "
                         "'%s': %s", clp->servername, gai_strerror(errcode));
                goto out_fail;
        }

        if (a == NULL) {
                printerr(0, "WARNING: No address information found for "
                         "connection to server %s for user with uid %d",
                         clp->servername, uid);
                goto out_fail;
        }
        /* only the first address returned by getaddrinfo() is tried */
        if (a->ai_protocol == IPPROTO_TCP) {
                if ((rpc_clnt = clnttcp_create(
                                        (struct sockaddr_in *) a->ai_addr,
                                        clp->prog, clp->vers, &sockp,
                                        sendsz, recvsz)) == NULL) {
                        snprintf(rpc_errmsg, sizeof(rpc_errmsg),
                                 "WARNING: can't create tcp rpc_clnt "
                                 "for server %s for user with uid %d",
                                 clp->servername, uid);
                        printerr(0, "%s\n",
                                 clnt_spcreateerror(rpc_errmsg));
                        goto out_fail;
                }
        } else if (a->ai_protocol == IPPROTO_UDP) {
                const struct timeval timeout = {5, 0};
                if ((rpc_clnt = clntudp_bufcreate(
                                        (struct sockaddr_in *) a->ai_addr,
                                        clp->prog, clp->vers, timeout,
                                        &sockp, sendsz, recvsz)) == NULL) {
                        snprintf(rpc_errmsg, sizeof(rpc_errmsg),
                                 "WARNING: can't create udp rpc_clnt "
                                 "for server %s for user with uid %d",
                                 clp->servername, uid);
                        printerr(0, "%s\n",
                                 clnt_spcreateerror(rpc_errmsg));
                        goto out_fail;
                }
        } else {
                /* Shouldn't happen! */
                printerr(0, "ERROR: requested protocol '%s', but "
                         "got addrinfo with protocol %d",
                         clp->protocol, a->ai_protocol);
                goto out_fail;
        }
        /* We're done with this */
        freeaddrinfo(a);
        a = NULL;

        printerr(2, "creating context with server %s\n", clp->servicename);
        auth = authgss_create_default(rpc_clnt, clp->servicename, &sec);
        if (!auth) {
                /* Our caller should print appropriate message */
                printerr(2, "WARNING: Failed to create %s context for "
                            "user with uid %d for server %s\n",
                        (authtype == AUTHTYPE_KRB5 ? "krb5":"spkm3"),
                         uid, clp->servername);
                goto out_fail;
        }

        /* Success !!! */
        rpc_clnt->cl_auth = auth;
        *clnt_return = rpc_clnt;
        *auth_return = auth;
        retval = 0;

        /* common exit: reached directly on success and via out_fail */
  out:
        if (sec.cred != GSS_C_NO_CREDENTIAL)
                gss_release_cred(&min_stat, &sec.cred);
        if (a != NULL) freeaddrinfo(a);
        /* Restore euid to original value */
        if ((save_uid != -1) && (setfsuid(save_uid) != uid)) {
                printerr(0, "WARNING: Failed to restore fsuid"
                            " to uid %d from %d\n", save_uid, uid);
        }
        return retval;

  out_fail:
        /* Only destroy here if failure.  Otherwise, caller is responsible */
        if (rpc_clnt) clnt_destroy(rpc_clnt);

        goto out;
}
626 #endif
627
628 static
629 int do_negotiation(struct lustre_gss_data *lgd,
630                    gss_buffer_desc *gss_token,
631                    struct lustre_gss_init_res *gr,
632                    int timeout)
633 {
634         char *file = "/proc/fs/lustre/sptlrpc/gss/init_channel";
635         struct lgssd_ioctl_param param;
636         struct passwd *pw;
637         int fd, ret;
638         char outbuf[8192];
639         unsigned int *p;
640         int res;
641
642         pw = getpwuid(lgd->lgd_uid);
643         if (!pw) {
644                 printerr(0, "no uid %u in local user database\n",
645                          lgd->lgd_uid);
646                 return -1;
647         }
648
649         param.version = GSSD_INTERFACE_VERSION;
650         param.uuid = lgd->lgd_uuid;
651         param.lustre_svc = lgd->lgd_lustre_svc;
652         param.uid = lgd->lgd_uid;
653         param.gid = pw->pw_gid;
654         param.send_token_size = gss_token->length;
655         param.send_token = (char *) gss_token->value;
656         param.reply_buf_size = sizeof(outbuf);
657         param.reply_buf = outbuf;
658
659         fd = open(file, O_RDWR);
660         if (fd < 0) {
661                 printerr(0, "can't open file %s\n", file);
662                 return -1;
663         }
664
665         ret = write(fd, &param, sizeof(param));
666
667         if (ret != sizeof(param)) {
668                 printerr(0, "lustre ioctl err: %d\n", strerror(errno));
669                 close(fd);
670                 return -1;
671         }
672         if (param.status) {
673                 close(fd);
674                 printerr(0, "status: %d (%s)\n",
675                          param.status, strerror((int)param.status));
676                 if (param.status == -ETIMEDOUT) {
677                         /* kernel return -ETIMEDOUT means the rpc timedout,
678                          * we should notify the caller to reinitiate the
679                          * gss negotiation, by return -ERESTART
680                          */
681                         lgd->lgd_rpc_err = -ERESTART;
682                         lgd->lgd_gss_err = 0;
683                 } else {
684                         lgd->lgd_rpc_err = param.status;
685                         lgd->lgd_gss_err = 0;
686                 }
687                 return -1;
688         }
689         p = (unsigned int *)outbuf;
690         res = *p++;
691         gr->gr_major = *p++;
692         gr->gr_minor = *p++;
693         gr->gr_win = *p++;
694
695         gr->gr_ctx.length = *p++;
696         gr->gr_ctx.value = malloc(gr->gr_ctx.length);
697         memcpy(gr->gr_ctx.value, p, gr->gr_ctx.length);
698         p += (((gr->gr_ctx.length + 3) & ~3) / 4);
699
700         gr->gr_token.length = *p++;
701         gr->gr_token.value = malloc(gr->gr_token.length);
702         memcpy(gr->gr_token.value, p, gr->gr_token.length);
703         p += (((gr->gr_token.length + 3) & ~3) / 4);
704
705         printerr(2, "do_negotiation: receive handle len %d, token len %d\n",
706                  gr->gr_ctx.length, gr->gr_token.length);
707         close(fd);
708         return 0;
709 }
710
711 static
712 int gssd_refresh_lgd(struct lustre_gss_data *lgd)
713 {
714         struct lustre_gss_init_res gr;
715         gss_buffer_desc         *recv_tokenp, send_token;
716         OM_uint32                maj_stat, min_stat, call_stat, ret_flags;
717
718         /* GSS context establishment loop. */
719         memset(&gr, 0, sizeof(gr));
720         recv_tokenp = GSS_C_NO_BUFFER;
721
722         for (;;) {
723                 /* print the token we just received */
724                 if (recv_tokenp != GSS_C_NO_BUFFER) {
725                         printerr(3, "The received token length %d\n",
726                                  recv_tokenp->length);
727                         print_hexl(3, recv_tokenp->value, recv_tokenp->length);
728                 }
729
730                 maj_stat = gss_init_sec_context(&min_stat,
731                                                 lgd->lgd_cred,
732                                                 &lgd->lgd_ctx,
733                                                 lgd->lgd_name,
734                                                 lgd->lgd_mech,
735                                                 lgd->lgd_req_flags,
736                                                 0,              /* time req */
737                                                 NULL,           /* channel */
738                                                 recv_tokenp,
739                                                 NULL,           /* used mech */
740                                                 &send_token,
741                                                 &ret_flags,
742                                                 NULL);          /* time rec */
743
744                 if (recv_tokenp != GSS_C_NO_BUFFER) {
745                         gss_release_buffer(&min_stat, &gr.gr_token);
746                         recv_tokenp = GSS_C_NO_BUFFER;
747                 }
748                 if (maj_stat != GSS_S_COMPLETE &&
749                     maj_stat != GSS_S_CONTINUE_NEEDED) {
750                         pgsserr("gss_init_sec_context", maj_stat, min_stat,
751                                 lgd->lgd_mech);
752                         break;
753                 }
754                 if (send_token.length != 0) {
755                         memset(&gr, 0, sizeof(gr));
756
757                         /* print the token we are about to send */
758                         printerr(3, "token being sent length %d\n",
759                                  send_token.length);
760                         print_hexl(3, send_token.value, send_token.length);
761
762                         call_stat = do_negotiation(lgd, &send_token, &gr, 0);
763                         gss_release_buffer(&min_stat, &send_token);
764
765                         if (call_stat != 0 ||
766                             (gr.gr_major != GSS_S_COMPLETE &&
767                              gr.gr_major != GSS_S_CONTINUE_NEEDED)) {
768                                 printerr(0, "call stat %d, major stat 0x%x\n",
769                                          (int)call_stat, gr.gr_major);
770                                 return -1;
771                         }
772
773                         if (gr.gr_ctx.length != 0) {
774                                 if (lgd->lgd_rmt_ctx.value)
775                                         gss_release_buffer(&min_stat,
776                                                            &lgd->lgd_rmt_ctx);
777                                 lgd->lgd_rmt_ctx = gr.gr_ctx;
778                         }
779                         if (gr.gr_token.length != 0) {
780                                 if (maj_stat != GSS_S_CONTINUE_NEEDED)
781                                         break;
782                                 recv_tokenp = &gr.gr_token;
783                         }
784                 }
785
786                 /* GSS_S_COMPLETE => check gss header verifier,
787                  * usually checked in gss_validate
788                  */
789                 if (maj_stat == GSS_S_COMPLETE) {
790                         lgd->lgd_established = 1;
791                         lgd->lgd_seq_win = gr.gr_win;
792                         break;
793                 }
794         }
795         /* End context negotiation loop. */
796         if (!lgd->lgd_established) {
797                 if (gr.gr_token.length != 0)
798                         gss_release_buffer(&min_stat, &gr.gr_token);
799
800                 printerr(0, "context negotiation failed\n");
801                 return -1;
802         }
803
804         printerr(2, "successfully refreshed lgd\n");
805         return 0;
806 }
807
808 static
809 int gssd_create_lgd(struct clnt_info *clp,
810                     struct lustre_gss_data *lgd,
811                     struct lgssd_upcall_data *updata,
812                     int authtype)
813 {
814         gss_buffer_desc         sname;
815         OM_uint32               maj_stat, min_stat;
816         int                     retval = -1;
817
818         lgd->lgd_established = 0;
819         lgd->lgd_lustre_svc = updata->svc;
820         lgd->lgd_uid = updata->uid;
821         lgd->lgd_uuid = updata->obd;
822
823         switch (authtype) {
824         case AUTHTYPE_KRB5:
825                 lgd->lgd_mech = (gss_OID) &krb5oid;
826                 lgd->lgd_req_flags = GSS_C_MUTUAL_FLAG;
827                 break;
828         case AUTHTYPE_SPKM3:
829                 lgd->lgd_mech = (gss_OID) &spkm3oid;
830                 /* XXX sec.req_flags = GSS_C_ANON_FLAG;
831                  * Need a way to switch....
832                  */
833                 lgd->lgd_req_flags = GSS_C_MUTUAL_FLAG;
834                 break;
835         default:
836                 printerr(0, "Invalid authentication type (%d)\n", authtype);
837                 return -1;
838         }
839
840         lgd->lgd_cred = GSS_C_NO_CREDENTIAL;
841         lgd->lgd_ctx = GSS_C_NO_CONTEXT;
842         lgd->lgd_rmt_ctx = (gss_buffer_desc) GSS_C_EMPTY_BUFFER;
843         lgd->lgd_seq_win = 0;
844
845         sname.value = clp->servicename;
846         sname.length = strlen(clp->servicename);
847
848         maj_stat = gss_import_name(&min_stat, &sname,
849                                    (gss_OID) GSS_C_NT_HOSTBASED_SERVICE,
850                                    &lgd->lgd_name);
851         if (maj_stat != GSS_S_COMPLETE) {
852                 pgsserr(0, maj_stat, min_stat, lgd->lgd_mech);
853                 goto out_fail;
854         }
855
856         retval = gssd_refresh_lgd(lgd);
857
858         if (lgd->lgd_name != GSS_C_NO_NAME)
859                 gss_release_name(&min_stat, &lgd->lgd_name);
860
861         if (lgd->lgd_cred != GSS_C_NO_CREDENTIAL)
862                 gss_release_cred(&min_stat, &lgd->lgd_cred);
863
864   out_fail:
865         return retval;
866 }
867
868 static
869 void gssd_free_lgd(struct lustre_gss_data *lgd)
870 {
871         gss_buffer_t            token = GSS_C_NO_BUFFER;
872         OM_uint32               maj_stat, min_stat;
873
874         if (lgd->lgd_ctx == GSS_C_NO_CONTEXT)
875                 return;
876
877         maj_stat = gss_delete_sec_context(&min_stat, &lgd->lgd_ctx, token);
878 }
879
880 static
881 int construct_service_name(struct clnt_info *clp,
882                            struct lgssd_upcall_data *ud)
883 {
884         const int buflen = 256;
885         char name[buflen];
886
887         if (clp->servicename) {
888                 free(clp->servicename);
889                 clp->servicename = NULL;
890         }
891
892         if (lnet_nid2hostname(ud->nid, name, buflen))
893                 return -1;
894
895         clp->servicename = malloc(32 + strlen(name));
896         if (!clp->servicename) {
897                 printerr(0, "can't alloc memory\n");
898                 return -1;
899         }
900         sprintf(clp->servicename, "%s@%s",
901                 ud->svc == LUSTRE_GSS_SVC_MDS ?
902                 GSSD_SERVICE_MDS : GSSD_SERVICE_OSS,
903                 name);
904         printerr(2, "constructed servicename: %s\n", clp->servicename);
905         return 0;
906 }
907
908 /*
909  * this code uses the userland rpcsec gss library to create a krb5
910  * context on behalf of the kernel
911  */
912 void
913 handle_krb5_upcall(struct clnt_info *clp)
914 {
915         pid_t                   pid;
916         gss_buffer_desc         token = { 0, NULL };
917         struct lgssd_upcall_data updata;
918         struct lustre_gss_data  lgd;
919         char                    **credlist = NULL;
920         char                    **ccname;
921         int                     read_rc;
922
923         printerr(2, "handling krb5 upcall\n");
924
925         memset(&lgd, 0, sizeof(lgd));
926         lgd.lgd_rpc_err = -EPERM; /* default error code */
927
928         read_rc = read(clp->krb5_fd, &updata, sizeof(updata));
929         if (read_rc < 0) {
930                 printerr(0, "WARNING: failed reading from krb5 "
931                             "upcall pipe: %s\n", strerror(errno));
932                 return;
933         } else if (read_rc != sizeof(updata)) {
934                 printerr(0, "upcall data mismatch: length %d, expect %d\n",
935                          read_rc, sizeof(updata));
936
937                 /* the sequence number must be the first field. if read >= 4
938                  * bytes then we know at least sequence is fine, try to send
939                  * error notification nicely.
940                  */
941                 if (read_rc >= 4)
942                         do_error_downcall(clp->krb5_fd, updata.seq, -EPERM, 0);
943                 return;
944         }
945
946         /* FIXME temporary fix, do this before fork.
947          * in case of errors could have memory leak!!!
948          */
949         if (updata.uid == 0) {
950                 if (gssd_get_krb5_machine_cred_list(&credlist)) {
951                         printerr(0, "ERROR: Failed to obtain machine "
952                                     "credentials\n");
953                         do_error_downcall(clp->krb5_fd, updata.seq, -EPERM, 0);
954                         return;
955                 }
956         }
957
958         /* fork child process */
959         pid = fork();
960         if (pid < 0) {
961                 printerr(0, "can't fork: %s\n", strerror(errno));
962                 do_error_downcall(clp->krb5_fd, updata.seq, -EPERM, 0);
963                 return;
964         } else if (pid > 0) {
965                 printerr(2, "forked child process: %d\n", pid);
966                 return;
967         }
968
969         printerr(1, "krb5 upcall: seq %u, uid %u, svc %u, nid 0x%llx, obd %s\n",
970                  updata.seq, updata.uid, updata.svc, updata.nid, updata.obd);
971
972         if (updata.svc != LUSTRE_GSS_SVC_MDS &&
973             updata.svc != LUSTRE_GSS_SVC_OSS) {
974                 printerr(0, "invalid svc %d\n", updata.svc);
975                 lgd.lgd_rpc_err = -EPROTO;
976                 goto out_return_error;
977         }
978         updata.obd[sizeof(updata.obd)-1] = '\0';
979
980         if (construct_service_name(clp, &updata)) {
981                 printerr(0, "failed to construct service name\n");
982                 goto out_return_error;
983         }
984
985         if (updata.uid == 0) {
986                 int success = 0;
987
988                 /*
989                  * Get a list of credential cache names and try each
990                  * of them until one works or we've tried them all
991                  */
992 /*
993                 if (gssd_get_krb5_machine_cred_list(&credlist)) {
994                         printerr(0, "ERROR: Failed to obtain machine "
995                                     "credentials for %s\n", clp->servicename);
996                         goto out_return_error;
997                 }
998 */
999                 for (ccname = credlist; ccname && *ccname; ccname++) {
1000                         gssd_setup_krb5_machine_gss_ccache(*ccname);
1001                         if ((gssd_create_lgd(clp, &lgd, &updata,
1002                                              AUTHTYPE_KRB5)) == 0) {
1003                                 /* Success! */
1004                                 success++;
1005                                 break;
1006                         }
1007                         printerr(2, "WARNING: Failed to create krb5 context "
1008                                     "for user with uid %d with credentials "
1009                                     "cache %s for service %s\n",
1010                                  updata.uid, *ccname, clp->servicename);
1011                 }
1012                 gssd_free_krb5_machine_cred_list(credlist);
1013                 if (!success) {
1014                         printerr(0, "ERROR: Failed to create krb5 context "
1015                                     "for user with uid %d with any "
1016                                     "credentials cache for service %s\n",
1017                                  updata.uid, clp->servicename);
1018                         goto out_return_error;
1019                 }
1020         }
1021         else {
1022                 /* Tell krb5 gss which credentials cache to use */
1023                 gssd_setup_krb5_user_gss_ccache(updata.uid, clp->servicename);
1024
1025                 if ((gssd_create_lgd(clp, &lgd, &updata, AUTHTYPE_KRB5)) != 0) {
1026                         printerr(0, "WARNING: Failed to create krb5 context "
1027                                     "for user with uid %d for service %s\n",
1028                                  updata.uid, clp->servicename);
1029                         goto out_return_error;
1030                 }
1031         }
1032
1033         if (serialize_context_for_kernel(lgd.lgd_ctx, &token, &krb5oid)) {
1034                 printerr(0, "WARNING: Failed to serialize krb5 context for "
1035                             "user with uid %d for service %s\n",
1036                          updata.uid, clp->servicename);
1037                 goto out_return_error;
1038         }
1039
1040         printerr(1, "refreshed: %u@%s for %s\n",
1041                  updata.uid, updata.obd, clp->servicename);
1042         do_downcall(clp->krb5_fd, &updata, &lgd, &token);
1043
1044 out:
1045         if (token.value)
1046                 free(token.value);
1047
1048         gssd_free_lgd(&lgd);
1049         exit(0); /* i'm child process */
1050
1051 out_return_error:
1052         do_error_downcall(clp->krb5_fd, updata.seq,
1053                           lgd.lgd_rpc_err, lgd.lgd_gss_err);
1054         goto out;
1055 }
1056
/*
 * this code uses the userland rpcsec gss library to create an spkm3
 * context on behalf of the kernel
 *
 * The whole body is currently compiled out with "#if 0", so calling this
 * function does nothing.  The disabled code is kept for reference only.
 *
 * NOTE(review): the disabled code calls do_downcall() and
 * do_error_downcall() with different argument lists than the krb5 handler
 * in this file does; those calls presumably need updating before the
 * block can be re-enabled -- confirm against the current prototypes.
 */
void
handle_spkm3_upcall(struct clnt_info *clp)
{
#if 0
        uid_t                   uid;
        CLIENT                  *rpc_clnt = NULL;
        AUTH                    *auth = NULL;
        struct authgss_private_data pd;
        gss_buffer_desc         token;

        printerr(2, "handling spkm3 upcall\n");

        token.length = 0;
        token.value = NULL;

        /* compare as signed so that a -1 error return is caught as well */
        if (read(clp->spkm3_fd, &uid, sizeof(uid)) < (ssize_t)sizeof(uid)) {
                printerr(0, "WARNING: failed reading uid from spkm3 "
                         "upcall pipe: %s\n", strerror(errno));
                goto out;
        }

        if (create_auth_rpc_client(clp, &rpc_clnt, &auth, uid, AUTHTYPE_SPKM3)) {
                printerr(0, "WARNING: Failed to create spkm3 context for "
                            "user with uid %d\n", uid);
                goto out_return_error;
        }

        if (!authgss_get_private_data(auth, &pd)) {
                printerr(0, "WARNING: Failed to obtain authentication "
                            "data for user with uid %d for server %s\n",
                         uid, clp->servername);
                goto out_return_error;
        }

        if (serialize_context_for_kernel(pd.pd_ctx, &token, &spkm3oid)) {
                /* the format needs a %s for the server name argument */
                printerr(0, "WARNING: Failed to serialize spkm3 context for "
                            "user with uid %d for server %s\n",
                         uid, clp->servername);
                goto out_return_error;
        }

        do_downcall(clp->spkm3_fd, uid, &pd, &token);

out:
        if (token.value)
                free(token.value);
        if (auth)
                AUTH_DESTROY(auth);
        if (rpc_clnt)
                clnt_destroy(rpc_clnt);
        return;

out_return_error:
        do_error_downcall(clp->spkm3_fd, uid, -1);
        goto out;
#endif
}