Whamcloud - gitweb
land b_hd_pag: rudimentary support for PAG.
[fs/lustre-release.git] / lustre / sec / gss / sec_gss.c
1 /* -*- mode: c; c-basic-offset: 8; indent-tabs-mode: nil; -*-
2  * vim:expandtab:shiftwidth=8:tabstop=8:
3  *
4  * Modifications for Lustre
5  * Copyright 2004, Cluster File Systems, Inc.
6  * All rights reserved
7  * Author: Eric Mei <ericm@clusterfs.com>
8  */
9
10 /*
11  * linux/net/sunrpc/auth_gss.c
12  *
13  * RPCSEC_GSS client authentication.
14  *
15  *  Copyright (c) 2000 The Regents of the University of Michigan.
16  *  All rights reserved.
17  *
18  *  Dug Song       <dugsong@monkey.org>
19  *  Andy Adamson   <andros@umich.edu>
20  *
21  *  Redistribution and use in source and binary forms, with or without
22  *  modification, are permitted provided that the following conditions
23  *  are met:
24  *
25  *  1. Redistributions of source code must retain the above copyright
26  *     notice, this list of conditions and the following disclaimer.
27  *  2. Redistributions in binary form must reproduce the above copyright
28  *     notice, this list of conditions and the following disclaimer in the
29  *     documentation and/or other materials provided with the distribution.
30  *  3. Neither the name of the University nor the names of its
31  *     contributors may be used to endorse or promote products derived
32  *     from this software without specific prior written permission.
33  *
34  *  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
35  *  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
36  *  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37  *  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
38  *  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
39  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
40  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
41  *  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
42  *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
43  *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
44  *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
45  *
46  */
47
48 #ifndef EXPORT_SYMTAB
49 # define EXPORT_SYMTAB
50 #endif
51 #define DEBUG_SUBSYSTEM S_SEC
52 #ifdef __KERNEL__
53 #include <linux/init.h>
54 #include <linux/module.h>
55 #include <linux/slab.h>
56 #include <linux/dcache.h>
57 #include <linux/fs.h>
58 #include <linux/random.h>
59 /* for rpc_pipefs */
60 struct rpc_clnt;
61 #include <linux/sunrpc/rpc_pipe_fs.h>
62 #else
63 #include <liblustre.h>
64 #endif
65
66 #include <libcfs/kp30.h>
67 #include <linux/obd.h>
68 #include <linux/obd_class.h>
69 #include <linux/obd_support.h>
70 #include <linux/lustre_idl.h>
71 #include <linux/lustre_net.h>
72 #include <linux/lustre_import.h>
73 #include <linux/lustre_sec.h>
74
75 #include "gss_err.h"
76 #include "gss_internal.h"
77 #include "gss_api.h"
78
/* directory name of the gss upcall pipe.  NOTE(review): presumably
 * relative to the rpc_pipefs mount (see the rpc_pipe_fs.h include) --
 * confirm against the pipe setup code */
#define LUSTRE_PIPEDIR          "/lustre"

#define GSS_CREDCACHE_EXPIRE    (30 * 60)          /* 30 minutes, in seconds */

/* step (seconds) by which the nested timeouts below are shortened; it
 * is also the unit of their lower bounds */
#define GSS_TIMEOUT_DELTA       (5)
/* how long (seconds, it is multiplied by HZ in gss_cred_refresh()) the
 * thread that queued an upcall waits for lgssd: obd_timeout minus one
 * delta, but never less than two deltas */
#define CRED_REFRESH_UPCALL_TIMEOUT                             \
        ({                                                      \
                int timeout = obd_timeout - GSS_TIMEOUT_DELTA;  \
                                                                \
                if (timeout < GSS_TIMEOUT_DELTA * 2)            \
                        timeout = GSS_TIMEOUT_DELTA * 2;        \
                timeout;                                        \
        })
/* timeout of the SEC_INIT rpc: one delta shorter than the upcall wait
 * (presumably so the rpc gives up before the waiting thread does), but
 * never less than one delta */
#define SECINIT_RPC_TIMEOUT                                     \
        ({                                                      \
                int timeout = CRED_REFRESH_UPCALL_TIMEOUT -     \
                              GSS_TIMEOUT_DELTA;                \
                if (timeout < GSS_TIMEOUT_DELTA)                \
                        timeout = GSS_TIMEOUT_DELTA;            \
                timeout;                                        \
        })
/* timeout of the SEC_FINI rpc: a single delta */
#define SECFINI_RPC_TIMEOUT     (GSS_TIMEOUT_DELTA)
101
102
103 /**********************************************
104  * gss security init/fini helper              *
105  **********************************************/
106
107 static int secinit_compose_request(struct obd_import *imp,
108                                    char *buf, int bufsize,
109                                    int lustre_srv,
110                                    uid_t uid, gid_t gid,
111                                    long token_size,
112                                    char __user *token)
113 {
114         struct ptlrpcs_wire_hdr *hdr;
115         struct lustre_msg       *lmsg;
116         struct mds_req_sec_desc *secdesc;
117         int                      size = sizeof(*secdesc);
118         __u32                    lmsg_size, *p;
119         int                      rc;
120
121         lmsg_size = lustre_msg_size(1, &size);
122
123         if (sizeof(*hdr) + lmsg_size + size_round(token_size) > bufsize) {
124                 CERROR("token size %ld too large\n", token_size);
125                 return -EINVAL;
126         }
127
128         /* security wire hdr */
129         hdr = buf_to_sec_hdr(buf);
130         hdr->flavor  = cpu_to_le32(PTLRPCS_FLVR_GSS_NONE);
131         hdr->msg_len = cpu_to_le32(lmsg_size);
132         hdr->sec_len = cpu_to_le32(8 * 4 + token_size);
133
134         /* lustre message & secdesc */
135         lmsg = buf_to_lustre_msg(buf);
136
137         lustre_init_msg(lmsg, 1, &size, NULL);
138         secdesc = lustre_msg_buf(lmsg, 0, size);
139         secdesc->rsd_uid = secdesc->rsd_fsuid = uid;
140         secdesc->rsd_gid = secdesc->rsd_fsgid = gid;
141         secdesc->rsd_cap = secdesc->rsd_ngroups = 0;
142
143         lmsg->handle   = imp->imp_remote_handle;
144         lmsg->type     = PTL_RPC_MSG_REQUEST;
145         lmsg->opc      = SEC_INIT;
146         lmsg->flags    = 0;
147         lmsg->conn_cnt = imp->imp_conn_cnt;
148
149         p = (__u32 *) (buf + sizeof(*hdr) + lmsg_size);
150
151         /* gss hdr */
152         *p++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);     /* gss version */
153         *p++ = cpu_to_le32(PTLRPCS_FLVR_KRB5I);         /* subflavor */
154         *p++ = cpu_to_le32(PTLRPCS_GSS_PROC_INIT);      /* proc */
155         *p++ = cpu_to_le32(0);                          /* seq */
156         *p++ = cpu_to_le32(PTLRPCS_GSS_SVC_NONE);       /* service */
157         *p++ = cpu_to_le32(0);                          /* context handle */
158
159         /* plus lustre svc type */
160         *p++ = cpu_to_le32(lustre_srv);
161
162         /* now the token part */
163         *p++ = cpu_to_le32((__u32) token_size);
164         LASSERT(((char *)p - buf) + token_size <= bufsize);
165
166         rc = copy_from_user(p, token, token_size);
167         if (rc) {
168                 CERROR("can't copy token\n");
169                 return -EFAULT;
170         }
171
172         rc = size_round(((char *)p - buf) + token_size);
173         return rc;
174 }
175
176 static int secinit_parse_reply(char *repbuf, int replen,
177                                char __user *outbuf, long outlen)
178 {
179         __u32                   *p = (__u32 *)repbuf;
180         struct ptlrpcs_wire_hdr *hdr = (struct ptlrpcs_wire_hdr *) repbuf;
181         __u32                    lmsg_len, sec_len, status;
182         __u32                    major, minor, seq, obj_len, round_len;
183         __u32                    effective = 0;
184
185         if (replen <= (4 + 6) * 4) {
186                 CERROR("reply size %d too small\n", replen);
187                 return -EINVAL;
188         }
189
190         hdr->flavor = le32_to_cpu(hdr->flavor);
191         hdr->msg_len = le32_to_cpu(hdr->msg_len);
192         hdr->sec_len = le32_to_cpu(hdr->sec_len);
193
194         lmsg_len = le32_to_cpu(p[2]);
195         sec_len = le32_to_cpu(p[3]);
196
197         /* sanity checks */
198         if (hdr->flavor != PTLRPCS_FLVR_GSS_NONE) {
199                 CERROR("unexpected reply\n");
200                 return -EINVAL;
201         }
202         if (hdr->msg_len % 8 ||
203             sizeof(*hdr) + hdr->msg_len + hdr->sec_len > replen) {
204                 CERROR("unexpected reply\n");
205                 return -EINVAL;
206         }
207         if (hdr->sec_len > outlen) {
208                 CERROR("outbuf too small\n");
209                 return -EINVAL;
210         }
211
212         p = (__u32 *) buf_to_sec_data(repbuf);
213         effective = 0;
214
215         status = le32_to_cpu(*p++);
216         major = le32_to_cpu(*p++);
217         minor = le32_to_cpu(*p++);
218         seq = le32_to_cpu(*p++);
219         effective += 4 * 4;
220
221         if (copy_to_user(outbuf, &status, 4))
222                 return -EFAULT;
223         outbuf += 4;
224         if (copy_to_user(outbuf, &major, 4))
225                 return -EFAULT;
226         outbuf += 4;
227         if (copy_to_user(outbuf, &minor, 4))
228                 return -EFAULT;
229         outbuf += 4;
230         if (copy_to_user(outbuf, &seq, 4))
231                 return -EFAULT;
232         outbuf += 4;
233
234         obj_len = le32_to_cpu(*p++);
235         round_len = (obj_len + 3) & ~ 3;
236         if (copy_to_user(outbuf, &obj_len, 4))
237                 return -EFAULT;
238         outbuf += 4;
239         if (copy_to_user(outbuf, (char *)p, round_len))
240                 return -EFAULT;
241         p += round_len / 4;
242         outbuf += round_len;
243         effective += 4 + round_len;
244
245         obj_len = le32_to_cpu(*p++);
246         round_len = (obj_len + 3) & ~ 3;
247         if (copy_to_user(outbuf, &obj_len, 4))
248                 return -EFAULT;
249         outbuf += 4;
250         if (copy_to_user(outbuf, (char *)p, round_len))
251                 return -EFAULT;
252         p += round_len / 4;
253         outbuf += round_len;
254         effective += 4 + round_len;
255
256         return effective;
257 }
258
/* XXX move to where lgssd could see */
/*
 * Argument block of a SEC_INIT request from lgssd, handled by
 * gss_send_secinit_rpc(): user space writes the whole structure, the
 * kernel fills in the "out" fields and copies it back.  The write size
 * must equal sizeof(struct lgssd_ioctl_param), so the layout has to stay
 * in sync with lgssd.
 */
struct lgssd_ioctl_param {
        int             version;        /* in: GSSD_INTERFACE_VERSION */
        char           *uuid;           /* in: client obd name */
        int             lustre_svc;     /* in: lustre svc type */
        uid_t           uid;            /* in   */
        gid_t           gid;            /* in   */
        long            send_token_size;/* in: token length (bytes) */
        char           *send_token;     /* in: token (user buffer) */
        long            reply_buf_size; /* in: size of reply_buf */
        char           *reply_buf;      /* in: user buffer for reply */
        long            status;         /* out: 0 or error code */
        long            reply_length;   /* out: bytes in reply_buf */
};
273
274 static int gss_send_secinit_rpc(__user char *buffer, unsigned long count)
275 {
276         struct obd_import        *imp;
277         struct ptlrpc_request    *request = NULL;
278         struct lgssd_ioctl_param  param;
279         const int                 reqbuf_size = 1024;
280         const int                 repbuf_size = 1024;
281         char                     *reqbuf, *repbuf;
282         struct obd_device        *obd;
283         char                      obdname[64];
284         long                      lsize;
285         int                       rc, reqlen, replen;
286
287         if (count != sizeof(param)) {
288                 CERROR("ioctl size %lu, expect %d, please check lgssd version\n",
289                         count, sizeof(param));
290                 RETURN(-EINVAL);
291         }
292         if (copy_from_user(&param, buffer, sizeof(param))) {
293                 CERROR("failed copy data from lgssd\n");
294                 RETURN(-EFAULT);
295         }
296
297         if (param.version != GSSD_INTERFACE_VERSION) {
298                 CERROR("gssd interface version %d (expect %d)\n",
299                         param.version, GSSD_INTERFACE_VERSION);
300                 RETURN(-EINVAL);
301         }
302
303         /* take name */
304         if (strncpy_from_user(obdname, param.uuid,
305                               sizeof(obdname)) <= 0) {
306                 CERROR("Invalid obdname pointer\n");
307                 RETURN(-EFAULT);
308         }
309
310         obd = class_name2obd(obdname);
311         if (!obd) {
312                 CERROR("no such obd %s\n", obdname);
313                 RETURN(-EINVAL);
314         }
315
316         imp = class_import_get(obd->u.cli.cl_import);
317
318         OBD_ALLOC(reqbuf, reqbuf_size);
319         OBD_ALLOC(repbuf, reqbuf_size);
320
321         if (!reqbuf || !repbuf) {
322                 CERROR("Can't alloc buffer: %p/%p\n", reqbuf, repbuf);
323                 param.status = -ENOMEM;
324                 goto out_copy;
325         }
326
327         /* get token */
328         reqlen = secinit_compose_request(imp, reqbuf, reqbuf_size,
329                                          param.lustre_svc,
330                                          param.uid, param.gid,
331                                          param.send_token_size,
332                                          param.send_token);
333         if (reqlen < 0) {
334                 param.status = reqlen;
335                 goto out_copy;
336         }
337
338         request = ptl_do_rawrpc(imp, reqbuf, reqbuf_size, reqlen,
339                                 repbuf, repbuf_size, &replen,
340                                 SECINIT_RPC_TIMEOUT, &rc);
341         if (request == NULL || rc) {
342                 param.status = rc;
343                 goto out_copy;
344         }
345
346         if (replen > param.reply_buf_size) {
347                 CERROR("output buffer size %ld too small, need %d\n",
348                         param.reply_buf_size, replen);
349                 param.status = -EINVAL;
350                 goto out_copy;
351         }
352
353         lsize = secinit_parse_reply(repbuf, replen,
354                                     param.reply_buf, param.reply_buf_size);
355         if (lsize < 0) {
356                 param.status = (int) lsize;
357                 goto out_copy;
358         }
359
360         param.status = 0;
361         param.reply_length = lsize;
362
363 out_copy:
364         if (copy_to_user(buffer, &param, sizeof(param)))
365                 rc = -EFAULT;
366         else
367                 rc = 0;
368
369         class_import_put(imp);
370         if (request == NULL) {
371                 if (repbuf)
372                         OBD_FREE(repbuf, repbuf_size);
373                 if (reqbuf)
374                         OBD_FREE(reqbuf, reqbuf_size);
375         } else {
376                 rawrpc_req_finished(request);
377         }
378         RETURN(rc);
379 }
380
381 /**********************************************
382  * structure definitions                      *
383  **********************************************/
/*
 * gss flavor of a ptlrpc_sec.  gs_base is embedded so that a generic
 * ptlrpc_sec pointer can be turned back into a gss_sec with
 * container_of() (as gss_cred_refresh() does).
 */
struct gss_sec {
        struct ptlrpc_sec       gs_base;        /* generic part */
        struct gss_api_mech    *gs_mech;        /* gss mechanism */
#ifdef __KERNEL__
        spinlock_t              gs_lock;        /* protects gs_upcalls */
        struct list_head        gs_upcalls;     /* pending gss_upcall_msg */
        char                   *gs_pipepath;    /* presumably path of the upcall pipe -- confirm */
        struct dentry          *gs_depipe;      /* upcall pipe, see rpc_queue_upcall() */
#endif
};
394
395 #ifdef __KERNEL__
396
/* guards gss_cred::gc_ctx: read-locked in gss_cred_get_ctx(),
 * write-locked in gss_cred_set_ctx() */
static rwlock_t gss_ctx_lock = RW_LOCK_UNLOCKED;

/*
 * Key identifying an upcall; it is also the payload handed to lgssd
 * (gum_base.data points at it, see gss_init_upcall_msg()).  Pending
 * upcalls are matched by memcmp() on this structure plus the obd name
 * in gss_find_upcall(), so it must not contain padding bytes.  Its
 * layout is presumably mirrored by lgssd -- keep the two in sync.
 */
struct gss_upcall_msg_data {
        __u64                           gum_pag;        /* cred pag */
        __u32                           gum_uid;        /* cred uid */
        __u32                           gum_svc;        /* LUSTRE_GSS_SVC_MDS/OSS */
        __u32                           gum_nal;        /* peer NI number */
        __u32                           gum_netid;      /* set to 0 by gss_cred_refresh() */
        __u64                           gum_nid;        /* peer nid */
};

/*
 * A pending upcall to lgssd.  Reference counted: while it is linked on
 * gum_gsec->gs_upcalls the list owns one reference, and every thread
 * using it holds another (see gss_init_upcall_msg()).
 */
struct gss_upcall_msg {
        struct rpc_pipe_msg             gum_base;       /* queued on the pipe */
        atomic_t                        gum_refcount;
        struct list_head                gum_list;       /* on gs_upcalls */
        struct gss_sec                 *gum_gsec;       /* owning gss_sec */
        wait_queue_head_t               gum_waitq;      /* woken when unhashed */
        char                            gum_obdname[64];
        struct gss_upcall_msg_data      gum_data;       /* key and payload */
};
417
418 /**********************************************
419  * rpc_pipe upcall helpers                    *
420  **********************************************/
421 static
422 void gss_release_msg(struct gss_upcall_msg *gmsg)
423 {
424         ENTRY;
425         LASSERT(atomic_read(&gmsg->gum_refcount) > 0);
426
427         if (!atomic_dec_and_test(&gmsg->gum_refcount)) {
428                 CDEBUG(D_SEC, "gmsg %p ref %d\n", gmsg,
429                        atomic_read(&gmsg->gum_refcount));
430                 EXIT;
431                 return;
432         }
433         LASSERT(list_empty(&gmsg->gum_list));
434 #if 0
435         LASSERT(list_empty(&gmsg->gum_base.list));
436 #else
437         /* XXX */
438         if (!list_empty(&gmsg->gum_base.list)) {
439                 CWARN("msg %p: list: %p/%p/%p, copied %d, err %d, wq %d\n",
440                       gmsg, &gmsg->gum_base.list,
441                       gmsg->gum_base.list.prev, gmsg->gum_base.list.next,
442                       gmsg->gum_base.copied, gmsg->gum_base.errno,
443                       list_empty(&gmsg->gum_waitq.task_list));
444                 LBUG();
445         }
446 #endif
447         OBD_FREE(gmsg, sizeof(*gmsg));
448         EXIT;
449 }
450
451 static void
452 gss_unhash_msg_nolock(struct gss_upcall_msg *gmsg)
453 {
454 #if defined(CONFIG_SMP)
455         LASSERT(spin_is_locked(&gmsg->gum_gsec->gs_lock));
456 #endif
457
458         if (list_empty(&gmsg->gum_list))
459                 return;
460
461         list_del_init(&gmsg->gum_list);
462         wake_up(&gmsg->gum_waitq);
463         LASSERT(atomic_read(&gmsg->gum_refcount) > 1);
464         atomic_dec(&gmsg->gum_refcount);
465 }
466
467 static void
468 gss_unhash_msg(struct gss_upcall_msg *gmsg)
469 {
470         struct gss_sec *gsec = gmsg->gum_gsec;
471
472         spin_lock(&gsec->gs_lock);
473         gss_unhash_msg_nolock(gmsg);
474         spin_unlock(&gsec->gs_lock);
475 }
476
477 static
478 struct gss_upcall_msg * gss_find_upcall(struct gss_sec *gsec,
479                                         char *obdname,
480                                         struct gss_upcall_msg_data *gmd)
481 {
482         struct gss_upcall_msg *gmsg;
483         ENTRY;
484
485 #if defined(CONFIG_SMP)
486         LASSERT(spin_is_locked(&gsec->gs_lock));
487 #endif
488
489         list_for_each_entry(gmsg, &gsec->gs_upcalls, gum_list) {
490                 if (memcmp(&gmsg->gum_data, gmd, sizeof(*gmd)))
491                         continue;
492                 if (strcmp(gmsg->gum_obdname, obdname))
493                         continue;
494                 LASSERT(atomic_read(&gmsg->gum_refcount) > 0);
495                 atomic_inc(&gmsg->gum_refcount);
496                 CDEBUG(D_SEC, "found gmsg at %p: obdname %s, uid %d, ref %d\n",
497                        gmsg, obdname, gmd->gum_uid,
498                        atomic_read(&gmsg->gum_refcount));
499                 RETURN(gmsg);
500         }
501         RETURN(NULL);
502 }
503
504 static void gss_init_upcall_msg(struct gss_upcall_msg *gmsg,
505                                 struct gss_sec *gsec, char *obdname,
506                                 struct gss_upcall_msg_data *gmd)
507 {
508         struct rpc_pipe_msg *rpcmsg;
509         ENTRY;
510
511         /* 2 refs: 1 for hash, 1 for current user */
512         init_waitqueue_head(&gmsg->gum_waitq);
513         list_add(&gmsg->gum_list, &gsec->gs_upcalls);
514         atomic_set(&gmsg->gum_refcount, 2);
515         gmsg->gum_gsec = gsec;
516         strncpy(gmsg->gum_obdname, obdname, sizeof(gmsg->gum_obdname));
517         memcpy(&gmsg->gum_data, gmd, sizeof(*gmd));
518
519         rpcmsg = &gmsg->gum_base;
520         INIT_LIST_HEAD(&rpcmsg->list);
521         rpcmsg->data = &gmsg->gum_data;
522         rpcmsg->len = sizeof(gmsg->gum_data);
523         rpcmsg->copied = 0;
524         rpcmsg->errno = 0;
525         EXIT;
526 }
527 #endif /* __KERNEL__ */
528
529 /********************************************
530  * gss cred manipulation helpers            *
531  ********************************************/
/*
 * Disabled (not compiled): reported whether @cred is UPTODATE with no
 * other state flag set and has a context attached, checked under
 * gss_ctx_lock.  Kept for reference only.
 */
#if 0
static
int gss_cred_is_uptodate_ctx(struct ptlrpc_cred *cred)
{
        struct gss_cred *gcred = container_of(cred, struct gss_cred, gc_base);
        int res = 0;

        read_lock(&gss_ctx_lock);
        if (((cred->pc_flags & PTLRPC_CRED_FLAGS_MASK) ==
             PTLRPC_CRED_UPTODATE) &&
            gcred->gc_ctx)
                res = 1;
        read_unlock(&gss_ctx_lock);
        return res;
}
#endif
548
549 static inline
550 struct gss_cl_ctx * gss_get_ctx(struct gss_cl_ctx *ctx)
551 {
552         atomic_inc(&ctx->gc_refcount);
553         return ctx;
554 }
555
556 static
557 void gss_destroy_ctx(struct gss_cl_ctx *ctx)
558 {
559         ENTRY;
560
561         CDEBUG(D_SEC, "destroy cl_ctx %p\n", ctx);
562         if (ctx->gc_gss_ctx)
563                 kgss_delete_sec_context(&ctx->gc_gss_ctx);
564
565         if (ctx->gc_wire_ctx.len > 0) {
566                 OBD_FREE(ctx->gc_wire_ctx.data, ctx->gc_wire_ctx.len);
567                 ctx->gc_wire_ctx.len = 0;
568         }
569
570         OBD_FREE(ctx, sizeof(*ctx));
571 }
572
573 static
574 void gss_put_ctx(struct gss_cl_ctx *ctx)
575 {
576         if (atomic_dec_and_test(&ctx->gc_refcount))
577                 gss_destroy_ctx(ctx);
578 }
579
580 static
581 struct gss_cl_ctx *gss_cred_get_ctx(struct ptlrpc_cred *cred)
582 {
583         struct gss_cred *gcred = container_of(cred, struct gss_cred, gc_base);
584         struct gss_cl_ctx *ctx = NULL;
585
586         read_lock(&gss_ctx_lock);
587         if (gcred->gc_ctx)
588                 ctx = gss_get_ctx(gcred->gc_ctx);
589         read_unlock(&gss_ctx_lock);
590         return ctx;
591 }
592
593 static
594 void gss_cred_set_ctx(struct ptlrpc_cred *cred, struct gss_cl_ctx *ctx)
595 {
596         struct gss_cred *gcred = container_of(cred, struct gss_cred, gc_base);
597         struct gss_cl_ctx *old;
598         __u64 ctx_expiry;
599         ENTRY;
600
601         if (kgss_inquire_context(ctx->gc_gss_ctx, &ctx_expiry)) {
602                 CERROR("unable to get expire time\n");
603                 ctx_expiry = 1; /* make it expired now */
604         }
605         cred->pc_expire = (unsigned long) ctx_expiry;
606
607         write_lock(&gss_ctx_lock);
608         old = gcred->gc_ctx;
609         gcred->gc_ctx = ctx;
610         set_bit(PTLRPC_CRED_UPTODATE_BIT, &cred->pc_flags);
611         write_unlock(&gss_ctx_lock);
612         if (old)
613                 gss_put_ctx(old);
614
615         CDEBUG(D_SEC, "client refreshed gss cred %p(uid %u)\n",
616                cred, cred->pc_uid);
617         EXIT;
618 }
619
620 static int
621 simple_get_bytes(char **buf, __u32 *buflen, void *res, __u32 reslen)
622 {
623         if (*buflen < reslen) {
624                 CERROR("buflen %u < %u\n", *buflen, reslen);
625                 return -EINVAL;
626         }
627
628         memcpy(res, *buf, reslen);
629         *buf += reslen;
630         *buflen -= reslen;
631         return 0;
632 }
633
/*
 * Parse a context-init downcall written by lgssd.
 *
 * Data passed down, in order.  Fixed-size fields are copied as raw bytes
 * with no byte-order conversion.
 *  - the upcall key: pag, uid, svc, nal, netid, nid (stored into *@gmd,
 *    same field order as struct gss_upcall_msg_data)
 *  - timeout (currently read and discarded)
 *  - gc_win; 0 means lgssd reports a failure, and then two ints follow:
 *    the rpc error and the gss error
 *  - otherwise: the wire context handle (rawobj), then the mechanism
 *    context (rawobj), which must end the buffer exactly
 *
 * @gm:      mechanism used to import the context
 * @buf:     downcall payload
 * @gc:      set to the new context (refcount 1) on success, else NULL
 * @gmd:     receives the upcall key (also on later failures, as far as
 *           it could be read)
 * @gss_err: set to the gss error passed down by lgssd, else 0
 *
 * Returns 0 on success, -ENOMEM, -EPERM for a malformed downcall, or the
 * rpc error value supplied by lgssd (presumably a negative errno).
 * NOTE: if lgssd passes rpc error 0 with a non-zero gss error, 0 is
 * returned with *@gc == NULL.
 */
static
int gss_parse_init_downcall(struct gss_api_mech *gm, rawobj_t *buf,
                            struct gss_cl_ctx **gc,
                            struct gss_upcall_msg_data *gmd, int *gss_err)
{
        char *p = (char *)buf->data;
        struct gss_cl_ctx *ctx;
        __u32 len = buf->len;
        unsigned int timeout;
        rawobj_t tmp_buf;
        int err = -EPERM;
        ENTRY;

        *gc = NULL;
        *gss_err = 0;

        OBD_ALLOC(ctx, sizeof(*ctx));
        if (!ctx)
                RETURN(-ENOMEM);

        ctx->gc_proc = RPC_GSS_PROC_DATA;
        ctx->gc_seq = 0;
        spin_lock_init(&ctx->gc_seq_lock);
        atomic_set(&ctx->gc_refcount,1);

        /* upcall key, in struct gss_upcall_msg_data field order */
        if (simple_get_bytes(&p, &len, &gmd->gum_pag, sizeof(gmd->gum_pag)))
                goto err_free_ctx;
        if (simple_get_bytes(&p, &len, &gmd->gum_uid, sizeof(gmd->gum_uid)))
                goto err_free_ctx;
        if (simple_get_bytes(&p, &len, &gmd->gum_svc, sizeof(gmd->gum_svc)))
                goto err_free_ctx;
        if (simple_get_bytes(&p, &len, &gmd->gum_nal, sizeof(gmd->gum_nal)))
                goto err_free_ctx;
        if (simple_get_bytes(&p, &len, &gmd->gum_netid, sizeof(gmd->gum_netid)))
                goto err_free_ctx;
        if (simple_get_bytes(&p, &len, &gmd->gum_nid, sizeof(gmd->gum_nid)))
                goto err_free_ctx;
        /* FIXME: discarded timeout for now */
        if (simple_get_bytes(&p, &len, &timeout, sizeof(timeout)))
                goto err_free_ctx;
        if (simple_get_bytes(&p, &len, &ctx->gc_win, sizeof(ctx->gc_win)))
                goto err_free_ctx;

        /* lgssd signals an error by passing ctx->gc_win = 0: */
        if (!ctx->gc_win) {
                /* in which case the next 2 ints are:
                 * - rpc error
                 * - gss error
                 */
                if (simple_get_bytes(&p, &len, &err, sizeof(err))) {
                        err = -EPERM;
                        goto err_free_ctx;
                }
                if (simple_get_bytes(&p, &len, gss_err, sizeof(*gss_err))) {
                        err = -EPERM;
                        goto err_free_ctx;
                }
                if (err == 0 && *gss_err == 0) {
                        CERROR("no error passed from downcall\n");
                        err = -EPERM;
                }
                goto err_free_ctx;
        }

        /* wire context handle; kept as a private copy since tmp_buf
         * presumably points into @buf */
        if (rawobj_extract_local(&tmp_buf, (__u32 **) ((void *)&p), &len))
                goto err_free_ctx;
        if (rawobj_dup(&ctx->gc_wire_ctx, &tmp_buf)) {
                err = -ENOMEM;
                goto err_free_ctx;
        }
        /* mechanism context: must be the last item in the buffer */
        if (rawobj_extract_local(&tmp_buf, (__u32 **) ((void *)&p), &len))
                goto err_free_wire_ctx;
        if (len) {
                CERROR("unexpected trailing %u bytes\n", len);
                goto err_free_wire_ctx;
        }
        if (kgss_import_sec_context(&tmp_buf, gm, &ctx->gc_gss_ctx))
                goto err_free_wire_ctx;

        *gc = ctx;
        RETURN(0);

err_free_wire_ctx:
        if (ctx->gc_wire_ctx.data)
                OBD_FREE(ctx->gc_wire_ctx.data, ctx->gc_wire_ctx.len);
err_free_ctx:
        OBD_FREE(ctx, sizeof(*ctx));
        CDEBUG(D_SEC, "err_code %d, gss code %d\n", err, *gss_err);
        return err;
}
731
732 /***************************************
733  * cred APIs                           *
734  ***************************************/
735 #ifdef __KERNEL__
736 static int gss_cred_refresh(struct ptlrpc_cred *cred)
737 {
738         struct obd_import          *import;
739         struct gss_sec             *gsec;
740         struct gss_upcall_msg      *gss_msg, *gss_new;
741         struct gss_upcall_msg_data  gmd;
742         struct dentry              *dentry;
743         char                       *obdname, *obdtype;
744         wait_queue_t                wait;
745         int                         res;
746         ENTRY;
747
748         might_sleep();
749
750         /* any flags means it has been handled, do nothing */
751         if (cred->pc_flags & PTLRPC_CRED_FLAGS_MASK)
752                 RETURN(0);
753
754         LASSERT(cred->pc_sec);
755         LASSERT(cred->pc_sec->ps_import);
756         LASSERT(cred->pc_sec->ps_import->imp_obd);
757
758         import = cred->pc_sec->ps_import;
759         if (!import->imp_connection) {
760                 CERROR("import has no connection set\n");
761                 RETURN(-EINVAL);
762         }
763
764         gmd.gum_pag = cred->pc_pag;
765         gmd.gum_uid = cred->pc_uid;
766         gmd.gum_nal = import->imp_connection->c_peer.peer_ni->pni_number;
767         gmd.gum_netid = 0;
768         gmd.gum_nid = import->imp_connection->c_peer.peer_id.nid;
769
770         obdtype = import->imp_obd->obd_type->typ_name;
771         if (!strcmp(obdtype, "mdc"))
772                 gmd.gum_svc = LUSTRE_GSS_SVC_MDS;
773         else if (!strcmp(obdtype, "osc"))
774                 gmd.gum_svc = LUSTRE_GSS_SVC_OSS;
775         else {
776                 CERROR("gss on %s?\n", obdtype);
777                 RETURN(-EINVAL);
778         }
779
780         gsec = container_of(cred->pc_sec, struct gss_sec, gs_base);
781         obdname = import->imp_obd->obd_name;
782         dentry = gsec->gs_depipe;
783         gss_new = NULL;
784         res = 0;
785
786         CDEBUG(D_SEC, "Initiate gss context %p(%u@%s)\n",
787                container_of(cred, struct gss_cred, gc_base),
788                cred->pc_uid, import->imp_target_uuid.uuid);
789
790 again:
791         spin_lock(&gsec->gs_lock);
792         gss_msg = gss_find_upcall(gsec, obdname, &gmd);
793         if (gss_msg) {
794                 if (gss_new) {
795                         OBD_FREE(gss_new, sizeof(*gss_new));
796                         gss_new = NULL;
797                 }
798                 GOTO(waiting, res);
799         }
800
801         if (!gss_new) {
802                 spin_unlock(&gsec->gs_lock);
803                 OBD_ALLOC(gss_new, sizeof(*gss_new));
804                 if (!gss_new)
805                         RETURN(-ENOMEM);
806                 goto again;
807         }
808         /* so far we'v created gss_new */
809         gss_init_upcall_msg(gss_new, gsec, obdname, &gmd);
810
811         /* we'v created upcall msg, nobody else should touch the
812          * flag of this cred, unless be set as dead/expire by
813          * administrator via lctl etc.
814          */
815         if (cred->pc_flags & PTLRPC_CRED_FLAGS_MASK) {
816                 CWARN("cred %p("LPU64"/%u) was set flags %lx unexpectedly\n",
817                       cred, cred->pc_pag, cred->pc_uid, cred->pc_flags);
818                 cred->pc_flags |= PTLRPC_CRED_DEAD | PTLRPC_CRED_ERROR;
819                 gss_unhash_msg_nolock(gss_new);
820                 spin_unlock(&gsec->gs_lock);
821                 gss_release_msg(gss_new);
822                 RETURN(0);
823         }
824
825         /* need to make upcall now */
826         spin_unlock(&gsec->gs_lock);
827         res = rpc_queue_upcall(dentry->d_inode, &gss_new->gum_base);
828         if (res) {
829                 CERROR("rpc_queue_upcall failed: %d\n", res);
830                 gss_unhash_msg(gss_new);
831                 gss_release_msg(gss_new);
832                 cred->pc_flags |= PTLRPC_CRED_DEAD | PTLRPC_CRED_ERROR;
833                 RETURN(res);
834         }
835         gss_msg = gss_new;
836         spin_lock(&gsec->gs_lock);
837
838 waiting:
839         /* upcall might finish quickly */
840         if (list_empty(&gss_msg->gum_list)) {
841                 spin_unlock(&gsec->gs_lock);
842                 res = 0;
843                 goto out;
844         }
845
846         init_waitqueue_entry(&wait, current);
847         set_current_state(TASK_INTERRUPTIBLE);
848         add_wait_queue(&gss_msg->gum_waitq, &wait);
849         spin_unlock(&gsec->gs_lock);
850
851         if (gss_new)
852                 res = schedule_timeout(CRED_REFRESH_UPCALL_TIMEOUT * HZ);
853         else {
854                 schedule();
855                 res = 0;
856         }
857
858         remove_wait_queue(&gss_msg->gum_waitq, &wait);
859
860         /* - the one who refresh the cred for us should also be responsible
861          *   to set the status of cred, we can simply return.
862          * - if cred flags has been set, we also don't need to do that again,
863          *   no matter signal pending or timeout etc.
864          */
865         if (!gss_new || cred->pc_flags & PTLRPC_CRED_FLAGS_MASK)
866                 goto out;
867
868         if (signal_pending(current)) {
869                 CERROR("%s: cred %p: interrupted upcall\n",
870                        current->comm, cred);
871                 cred->pc_flags |= PTLRPC_CRED_DEAD | PTLRPC_CRED_ERROR;
872                 res = -EINTR;
873         } else if (res == 0) {
874                 CERROR("cred %p: upcall timedout\n", cred);
875                 set_bit(PTLRPC_CRED_DEAD_BIT, &cred->pc_flags);
876                 res = -ETIMEDOUT;
877         } else
878                 res = 0;
879
880 out:
881         gss_release_msg(gss_msg);
882
883         RETURN(res);
884 }
885 #else /* !__KERNEL__ */
886 extern int lgss_handle_krb5_upcall(uid_t uid, __u32 dest_ip,
887                                    char *obd_name,
888                                    char *buf, int bufsize,
889                                    int (*callback)(char*, unsigned long));
890
891 static int gss_cred_refresh(struct ptlrpc_cred *cred)
892 {
893         char                    buf[4096];
894         rawobj_t                obj;
895         struct obd_import      *imp;
896         struct gss_sec         *gsec;
897         struct gss_api_mech    *mech;
898         struct gss_cl_ctx      *ctx = NULL;
899         struct vfs_cred         vcred = { 0 };
900         ptl_nid_t               peer_nid;
901         __u32                   dest_ip;
902         __u32                   subflavor;
903         int                     rc, gss_err;
904
905         LASSERT(cred);
906         LASSERT(cred->pc_sec);
907         LASSERT(cred->pc_sec->ps_import);
908         LASSERT(cred->pc_sec->ps_import->imp_obd);
909
910         if (ptlrpcs_cred_is_uptodate(cred))
911                 RETURN(0);
912
913         imp = cred->pc_sec->ps_import;
914         peer_nid = imp->imp_connection->c_peer.peer_id.nid;
915         dest_ip = (__u32) (peer_nid & 0xFFFFFFFF);
916         subflavor = cred->pc_sec->ps_flavor.subflavor;
917
918         if (subflavor != PTLRPC_SEC_GSS_KRB5I) {
919                 CERROR("unknown subflavor %u\n", subflavor);
920                 GOTO(err_out, rc = -EINVAL);
921         }
922
923         rc = lgss_handle_krb5_upcall(cred->pc_uid, dest_ip,
924                                      imp->imp_obd->obd_name,
925                                      buf, sizeof(buf),
926                                      gss_send_secinit_rpc);
927         LASSERT(rc != 0);
928         if (rc < 0)
929                 goto err_out;
930
931         obj.data = buf;
932         obj.len = rc;
933
934         gsec = container_of(cred->pc_sec, struct gss_sec, gs_base);
935         mech = gsec->gs_mech;
936         LASSERT(mech);
937         rc = gss_parse_init_downcall(mech, &obj, &ctx, &vcred, &dest_ip,
938                                      &gss_err);
939         if (rc || gss_err) {
940                 CERROR("parse init downcall: rpc %d, gss 0x%x\n", rc, gss_err);
941                 if (rc != -ERESTART || gss_err != 0)
942                         set_bit(PTLRPC_CRED_ERROR_BIT, &cred->pc_flags);
943                 if (rc == 0)
944                         rc = -EPERM;
945                 goto err_out;
946         }
947
948         LASSERT(ctx);
949         gss_cred_set_ctx(cred, ctx);
950         LASSERT(gss_cred_is_uptodate_ctx(cred));
951
952         return 0;
953 err_out:
954         set_bit(PTLRPC_CRED_DEAD_BIT, &cred->pc_flags);
955         return rc;
956 }
957 #endif
958
959 static int gss_cred_match(struct ptlrpc_cred *cred,
960                           struct vfs_cred *vcred)
961 {
962         RETURN(cred->pc_pag == vcred->vc_pag);
963 }
964
965 static int gss_cred_sign(struct ptlrpc_cred *cred,
966                          struct ptlrpc_request *req)
967 {
968         struct gss_cred         *gcred;
969         struct gss_cl_ctx       *ctx;
970         rawobj_t                 lmsg, mic;
971         __u32                   *vp, *vpsave, vlen, seclen;
972         __u32                    seqnum, major, rc = 0;
973         ENTRY;
974
975         LASSERT(req->rq_reqbuf);
976         LASSERT(req->rq_cred == cred);
977
978         gcred = container_of(cred, struct gss_cred, gc_base);
979         ctx = gss_cred_get_ctx(cred);
980         if (!ctx) {
981                 CERROR("cred %p("LPU64"/%u) invalidated?\n",
982                         cred, cred->pc_pag, cred->pc_uid);
983                 RETURN(-EPERM);
984         }
985
986         lmsg.len = req->rq_reqlen;
987         lmsg.data = (__u8 *) req->rq_reqmsg;
988
989         vp = (__u32 *) (lmsg.data + lmsg.len);
990         vlen = req->rq_reqbuf_len - sizeof(struct ptlrpcs_wire_hdr) -
991                lmsg.len;
992         seclen = vlen;
993
994         if (vlen < 6 * 4 + size_round4(ctx->gc_wire_ctx.len)) {
995                 CERROR("vlen %d, need %d\n",
996                         vlen, 6 * 4 + size_round4(ctx->gc_wire_ctx.len));
997                 rc = -EIO;
998                 goto out;
999         }
1000
1001         spin_lock(&ctx->gc_seq_lock);
1002         seqnum = ctx->gc_seq++;
1003         spin_unlock(&ctx->gc_seq_lock);
1004
1005         *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);    /* version */
1006         *vp++ = cpu_to_le32(PTLRPCS_FLVR_KRB5I);        /* subflavor */
1007         *vp++ = cpu_to_le32(ctx->gc_proc);              /* proc */
1008         *vp++ = cpu_to_le32(seqnum);                    /* seq */
1009         *vp++ = cpu_to_le32(PTLRPCS_GSS_SVC_INTEGRITY); /* service */
1010         vlen -= 5 * 4;
1011
1012         if (rawobj_serialize(&ctx->gc_wire_ctx, &vp, &vlen)) {
1013                 rc = -EIO;
1014                 goto out;
1015         }
1016         CDEBUG(D_SEC, "encoded wire_ctx length %d\n", ctx->gc_wire_ctx.len);
1017
1018         vpsave = vp++;  /* reserve for size */
1019         vlen -= 4;
1020
1021         mic.len = vlen;
1022         mic.data = (unsigned char *)vp;
1023
1024         CDEBUG(D_SEC, "reqbuf at %p, lmsg at %p, len %d, mic at %p, len %d\n",
1025                req->rq_reqbuf, lmsg.data, lmsg.len, mic.data, mic.len);
1026         major = kgss_get_mic(ctx->gc_gss_ctx, GSS_C_QOP_DEFAULT, &lmsg, &mic);
1027         if (major) {
1028                 CERROR("cred %p: req %p compute mic error, major %x\n",
1029                        cred, req, major);
1030                 rc = -EACCES;
1031                 goto out;
1032         }
1033
1034         *vpsave = cpu_to_le32(mic.len);
1035         
1036         seclen = seclen - vlen + mic.len;
1037         buf_to_sec_hdr(req->rq_reqbuf)->sec_len = cpu_to_le32(seclen);
1038         req->rq_reqdata_len += size_round(seclen);
1039         CDEBUG(D_SEC, "msg size %d, checksum size %d, total sec size %d\n",
1040                lmsg.len, mic.len, seclen);
1041 out:
1042         gss_put_ctx(ctx);
1043         RETURN(rc);
1044 }
1045
1046 static int gss_cred_verify(struct ptlrpc_cred *cred,
1047                            struct ptlrpc_request *req)
1048 {
1049         struct gss_cred        *gcred;
1050         struct gss_cl_ctx      *ctx;
1051         struct ptlrpcs_wire_hdr *sec_hdr;
1052         rawobj_t                lmsg, mic;
1053         __u32                   *vp, vlen, subflavor, proc, seq, svc;
1054         __u32                   major, minor, rc;
1055         ENTRY;
1056
1057         LASSERT(req->rq_repbuf);
1058         LASSERT(req->rq_cred == cred);
1059
1060         sec_hdr = buf_to_sec_hdr(req->rq_repbuf);
1061         vp = (__u32 *) (req->rq_repbuf + sizeof(*sec_hdr) + sec_hdr->msg_len);
1062         vlen = sec_hdr->sec_len;
1063
1064         if (vlen < 7 * 4) {
1065                 CERROR("reply sec size %u too small\n", vlen);
1066                 RETURN(-EPROTO);
1067         }
1068
1069         if (*vp++ != cpu_to_le32(PTLRPC_SEC_GSS_VERSION)) {
1070                 CERROR("reply have different gss version\n");
1071                 RETURN(-EPROTO);
1072         }
1073         subflavor = le32_to_cpu(*vp++);
1074         proc = le32_to_cpu(*vp++);
1075         vlen -= 3 * 4;
1076
1077         switch (proc) {
1078         case PTLRPCS_GSS_PROC_DATA:
1079                 seq = le32_to_cpu(*vp++);
1080                 svc = le32_to_cpu(*vp++);
1081                 if (svc != PTLRPCS_GSS_SVC_INTEGRITY) {
1082                         CERROR("Unknown svc %d\n", svc);
1083                         RETURN(-EPROTO);
1084                 }
1085                 if (*vp++ != 0) {
1086                         CERROR("Unexpected ctx handle\n");
1087                         RETURN(-EPROTO);
1088                 }
1089                 mic.len = le32_to_cpu(*vp++);
1090                 vlen -= 4 * 4;
1091                 if (vlen < mic.len) {
1092                         CERROR("vlen %d, mic.len %d\n", vlen, mic.len);
1093                         RETURN(-EINVAL);
1094                 }
1095                 mic.data = (unsigned char *)vp;
1096
1097                 gcred = container_of(cred, struct gss_cred, gc_base);
1098                 ctx = gss_cred_get_ctx(cred);
1099                 LASSERT(ctx);
1100
1101                 lmsg.len = sec_hdr->msg_len;
1102                 lmsg.data = (__u8 *) buf_to_lustre_msg(req->rq_repbuf);
1103
1104                 major = kgss_verify_mic(ctx->gc_gss_ctx, &lmsg, &mic, NULL);
1105                 if (major != GSS_S_COMPLETE) {
1106                         CERROR("cred %p: req %p verify mic error: major %x\n",
1107                                cred, req, major);
1108
1109                         if (major == GSS_S_CREDENTIALS_EXPIRED ||
1110                             major == GSS_S_CONTEXT_EXPIRED) {
1111                                 ptlrpcs_cred_expire(cred);
1112                                 req->rq_ptlrpcs_restart = 1;
1113                                 rc = 0;
1114                         } else
1115                                 rc = -EINVAL;
1116
1117                         GOTO(proc_data_out, rc);
1118                 }
1119
1120                 req->rq_repmsg = (struct lustre_msg *) lmsg.data;
1121                 req->rq_replen = lmsg.len;
1122
1123                 /* here we could check the seq number is the same one
1124                  * we sent to server. but portals has prevent us from
1125                  * replay attack, so maybe we don't need check it again.
1126                  */
1127                 rc = 0;
1128 proc_data_out:
1129                 gss_put_ctx(ctx);
1130                 break;
1131         case PTLRPCS_GSS_PROC_ERR:
1132                 major = le32_to_cpu(*vp++);
1133                 minor = le32_to_cpu(*vp++);
1134                 /* server return NO_CONTEXT might be caused by context expire
1135                  * or server reboot/failover. we refresh the cred transparently
1136                  * to upper layer.
1137                  * In some cases, our gss handle is possible to be incidentally
1138                  * identical to another handle since the handle itself is not
1139                  * fully random. In krb5 case, the GSS_S_BAD_SIG will be
1140                  * returned, maybe other gss error for other mechanism. Here we
1141                  * only consider krb5 mech (FIXME) and try to establish new
1142                  * context.
1143                  */
1144                 if (major == GSS_S_NO_CONTEXT ||
1145                     major == GSS_S_BAD_SIG) {
1146                         CWARN("req %p: server report cred %p %s, expired?\n",
1147                                req, cred, (major == GSS_S_NO_CONTEXT) ?
1148                                            "NO_CONTEXT" : "BAD_SIG");
1149
1150                         ptlrpcs_cred_expire(cred);
1151                         req->rq_ptlrpcs_restart = 1;
1152                         rc = 0;
1153                 } else {
1154                         CERROR("req %p: unrecognized gss error (%x/%x)\n",
1155                                 req, major, minor);
1156                         rc = -EACCES;
1157                 }
1158                 break;
1159         default:
1160                 CERROR("unknown gss proc %d\n", proc);
1161                 rc = -EPROTO;
1162         }
1163
1164         RETURN(rc);
1165 }
1166
1167 static int gss_cred_seal(struct ptlrpc_cred *cred,
1168                          struct ptlrpc_request *req)
1169 {
1170         struct gss_cred         *gcred;
1171         struct gss_cl_ctx       *ctx;
1172         struct ptlrpcs_wire_hdr *sec_hdr;
1173         rawobj_buf_t             msg_buf;
1174         rawobj_t                 cipher_buf;
1175         __u32                   *vp, *vpsave, vlen, seclen;
1176         __u32                    major, seqnum, rc = 0;
1177         ENTRY;
1178
1179         LASSERT(req->rq_reqbuf);
1180         LASSERT(req->rq_cred == cred);
1181
1182         gcred = container_of(cred, struct gss_cred, gc_base);
1183         ctx = gss_cred_get_ctx(cred);
1184         if (!ctx) {
1185                 CERROR("cred %p("LPU64"/%u) invalidated?\n",
1186                         cred, cred->pc_pag, cred->pc_uid);
1187                 RETURN(-EPERM);
1188         }
1189
1190         vp = (__u32 *) (req->rq_reqbuf + sizeof(*sec_hdr));
1191         vlen = req->rq_reqbuf_len - sizeof(*sec_hdr);
1192         seclen = vlen;
1193
1194         if (vlen < 6 * 4 + size_round4(ctx->gc_wire_ctx.len)) {
1195                 CERROR("vlen %d, need %d\n",
1196                         vlen, 6 * 4 + size_round4(ctx->gc_wire_ctx.len));
1197                 rc = -EIO;
1198                 goto out;
1199         }
1200
1201         spin_lock(&ctx->gc_seq_lock);
1202         seqnum = ctx->gc_seq++;
1203         spin_unlock(&ctx->gc_seq_lock);
1204
1205         *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);    /* version */
1206         *vp++ = cpu_to_le32(PTLRPCS_FLVR_KRB5P);        /* subflavor */
1207         *vp++ = cpu_to_le32(ctx->gc_proc);              /* proc */
1208         *vp++ = cpu_to_le32(seqnum);                    /* seq */
1209         *vp++ = cpu_to_le32(PTLRPCS_GSS_SVC_PRIVACY);   /* service */
1210         vlen -= 5 * 4;
1211
1212         if (rawobj_serialize(&ctx->gc_wire_ctx, &vp, &vlen)) {
1213                 rc = -EIO;
1214                 goto out;
1215         }
1216         CDEBUG(D_SEC, "encoded wire_ctx length %d\n", ctx->gc_wire_ctx.len);
1217
1218         vpsave = vp++;  /* reserve for size */
1219         vlen -= 4;
1220
1221         msg_buf.buf = (__u8 *) req->rq_reqmsg - GSS_PRIVBUF_PREFIX_LEN;
1222         msg_buf.buflen = req->rq_reqlen + GSS_PRIVBUF_PREFIX_LEN +
1223                                           GSS_PRIVBUF_SUFFIX_LEN;
1224         msg_buf.dataoff = GSS_PRIVBUF_PREFIX_LEN;
1225         msg_buf.datalen = req->rq_reqlen;
1226
1227         cipher_buf.data = (__u8 *) vp;
1228         cipher_buf.len = vlen;
1229
1230         major = kgss_wrap(ctx->gc_gss_ctx, GSS_C_QOP_DEFAULT,
1231                           &msg_buf, &cipher_buf);
1232         if (major) {
1233                 CERROR("cred %p: error wrap: major 0x%x\n", cred, major);
1234                 GOTO(out, rc = -EINVAL);
1235         }
1236
1237         *vpsave = cpu_to_le32(cipher_buf.len);
1238
1239         seclen = seclen - vlen + cipher_buf.len;
1240         sec_hdr = buf_to_sec_hdr(req->rq_reqbuf);
1241         sec_hdr->sec_len = cpu_to_le32(seclen);
1242         req->rq_reqdata_len += size_round(seclen);
1243
1244         CDEBUG(D_SEC, "msg size %d, total sec size %d\n",
1245                req->rq_reqlen, seclen);
1246 out:
1247         gss_put_ctx(ctx);
1248         RETURN(rc);
1249 }
1250
1251 static int gss_cred_unseal(struct ptlrpc_cred *cred,
1252                            struct ptlrpc_request *req)
1253 {
1254         struct gss_cred        *gcred;
1255         struct gss_cl_ctx      *ctx;
1256         struct ptlrpcs_wire_hdr *sec_hdr;
1257         rawobj_t                cipher_text, plain_text;
1258         __u32                   *vp, vlen, subflavor, proc, seq, svc;
1259         __u32                   major, rc;
1260         ENTRY;
1261
1262         LASSERT(req->rq_repbuf);
1263         LASSERT(req->rq_cred == cred);
1264
1265         sec_hdr = buf_to_sec_hdr(req->rq_repbuf);
1266         if (sec_hdr->msg_len != 0) {
1267                 CERROR("unexpected msg_len %u\n", sec_hdr->msg_len);
1268                 RETURN(-EPROTO);
1269         }
1270
1271         vp = (__u32 *) (req->rq_repbuf + sizeof(*sec_hdr));
1272         vlen = sec_hdr->sec_len;
1273
1274         if (vlen < 7 * 4) {
1275                 CERROR("reply sec size %u too small\n", vlen);
1276                 RETURN(-EPROTO);
1277         }
1278
1279         if (*vp++ != cpu_to_le32(PTLRPC_SEC_GSS_VERSION)) {
1280                 CERROR("reply have different gss version\n");
1281                 RETURN(-EPROTO);
1282         }
1283         subflavor = le32_to_cpu(*vp++);
1284         proc = le32_to_cpu(*vp++);
1285         seq = le32_to_cpu(*vp++);
1286         svc = le32_to_cpu(*vp++);
1287         vlen -= 5 * 4;
1288
1289         switch (proc) {
1290         case PTLRPCS_GSS_PROC_DATA:
1291                 if (svc != PTLRPCS_GSS_SVC_PRIVACY) {
1292                         CERROR("Unknown svc %d\n", svc);
1293                         RETURN(-EPROTO);
1294                 }
1295                 if (*vp++ != 0) {
1296                         CERROR("Unexpected ctx handle\n");
1297                         RETURN(-EPROTO);
1298                 }
1299                 vlen -= 4;
1300
1301                 cipher_text.len = le32_to_cpu(*vp++);
1302                 cipher_text.data = (__u8 *) vp;
1303                 vlen -= 4;
1304
1305                 if (vlen < cipher_text.len) {
1306                         CERROR("cipher text to be %u while buf only %u\n",
1307                                 cipher_text.len, vlen);
1308                         RETURN(-EPROTO);
1309                 }
1310
1311                 plain_text = cipher_text;
1312
1313                 gcred = container_of(cred, struct gss_cred, gc_base);
1314                 ctx = gss_cred_get_ctx(cred);
1315                 LASSERT(ctx);
1316
1317                 major = kgss_unwrap(ctx->gc_gss_ctx, GSS_C_QOP_DEFAULT,
1318                                     &cipher_text, &plain_text);
1319                 if (major) {
1320                         CERROR("cred %p: error unwrap: major 0x%x\n",
1321                                cred, major);
1322
1323                         if (major == GSS_S_CREDENTIALS_EXPIRED ||
1324                             major == GSS_S_CONTEXT_EXPIRED) {
1325                                 ptlrpcs_cred_expire(cred);
1326                                 req->rq_ptlrpcs_restart = 1;
1327                                 rc = 0;
1328                         } else
1329                                 rc = -EINVAL;
1330
1331                         GOTO(proc_out, rc);
1332                 }
1333
1334                 req->rq_repmsg = (struct lustre_msg *) vp;
1335                 req->rq_replen = plain_text.len;
1336
1337                 rc = 0;
1338 proc_out:
1339                 gss_put_ctx(ctx);
1340                 break;
1341         default:
1342                 CERROR("unknown gss proc %d\n", proc);
1343                 rc = -EPROTO;
1344         }
1345
1346         RETURN(rc);
1347 }
1348
1349 static void destroy_gss_context(struct ptlrpc_cred *cred)
1350 {
1351         struct ptlrpcs_wire_hdr *hdr;
1352         struct lustre_msg       *lmsg;
1353         struct gss_cred         *gcred;
1354         struct ptlrpc_request    req;
1355         struct obd_import       *imp;
1356         __u32                   *vp, lmsg_size;
1357         struct ptlrpc_request   *raw_req = NULL;
1358         const int                repbuf_len = 256;
1359         char                    *repbuf;
1360         int                      replen, rc;
1361         ENTRY;
1362
1363         imp = cred->pc_sec->ps_import;
1364         LASSERT(imp);
1365
1366         if (test_bit(PTLRPC_CRED_ERROR_BIT, &cred->pc_flags) ||
1367             !test_bit(PTLRPC_CRED_UPTODATE_BIT, &cred->pc_flags)) {
1368                 CDEBUG(D_SEC, "Destroy dead cred %p(%u@%s)\n",
1369                        cred, cred->pc_uid, imp->imp_target_uuid.uuid);
1370                 EXIT;
1371                 return;
1372         }
1373
1374         might_sleep();
1375
1376         /* cred's refcount is 0, steal one */
1377         atomic_inc(&cred->pc_refcount);
1378
1379         gcred = container_of(cred, struct gss_cred, gc_base);
1380         gcred->gc_ctx->gc_proc = PTLRPCS_GSS_PROC_DESTROY;
1381
1382         CDEBUG(D_SEC, "client destroy gss cred %p(%u@%s)\n",
1383                gcred, cred->pc_uid, imp->imp_target_uuid.uuid);
1384
1385         lmsg_size = lustre_msg_size(0, NULL);
1386         req.rq_req_secflvr = cred->pc_sec->ps_flavor;
1387         req.rq_cred = cred;
1388         req.rq_reqbuf_len = sizeof(*hdr) + lmsg_size +
1389                             ptlrpcs_est_req_payload(&req, lmsg_size);
1390
1391         OBD_ALLOC(req.rq_reqbuf, req.rq_reqbuf_len);
1392         if (!req.rq_reqbuf) {
1393                 CERROR("Fail to alloc reqbuf, cancel anyway\n");
1394                 atomic_dec(&cred->pc_refcount);
1395                 EXIT;
1396                 return;
1397         }
1398
1399         /* wire hdr */
1400         hdr = buf_to_sec_hdr(req.rq_reqbuf);
1401         hdr->flavor  = cpu_to_le32(PTLRPCS_FLVR_GSS_AUTH);
1402         hdr->msg_len = cpu_to_le32(lmsg_size);
1403         hdr->sec_len = cpu_to_le32(0);
1404
1405         /* lustre message */
1406         lmsg = buf_to_lustre_msg(req.rq_reqbuf);
1407         lustre_init_msg(lmsg, 0, NULL, NULL);
1408         lmsg->handle   = imp->imp_remote_handle;
1409         lmsg->type     = PTL_RPC_MSG_REQUEST;
1410         lmsg->opc      = SEC_FINI;
1411         lmsg->flags    = 0;
1412         lmsg->conn_cnt = imp->imp_conn_cnt;
1413         /* add this for randomize */
1414         get_random_bytes(&lmsg->last_xid, sizeof(lmsg->last_xid));
1415         get_random_bytes(&lmsg->transno, sizeof(lmsg->transno));
1416
1417         vp = (__u32 *) req.rq_reqbuf;
1418
1419         req.rq_cred = cred;
1420         req.rq_reqmsg = buf_to_lustre_msg(req.rq_reqbuf);
1421         req.rq_reqlen = lmsg_size;
1422         req.rq_reqdata_len = sizeof(*hdr) + lmsg_size;
1423
1424         if (gss_cred_sign(cred, &req)) {
1425                 CERROR("failed to sign, cancel anyway\n");
1426                 atomic_dec(&cred->pc_refcount);
1427                 goto exit;
1428         }
1429         atomic_dec(&cred->pc_refcount);
1430
1431         OBD_ALLOC(repbuf, repbuf_len);
1432         if (!repbuf)
1433                 goto exit;
1434
1435         raw_req = ptl_do_rawrpc(imp, req.rq_reqbuf, req.rq_reqbuf_len,
1436                                 req.rq_reqdata_len, repbuf, repbuf_len, &replen,
1437                                 SECFINI_RPC_TIMEOUT, &rc);
1438         if (!raw_req)
1439                 OBD_FREE(repbuf, repbuf_len);
1440
1441 exit:
1442         if (raw_req == NULL)
1443                 OBD_FREE(req.rq_reqbuf, req.rq_reqbuf_len);
1444         else
1445                 rawrpc_req_finished(raw_req);
1446         EXIT;
1447 }
1448
1449 static void gss_cred_destroy(struct ptlrpc_cred *cred)
1450 {
1451         struct gss_cred *gcred;
1452         ENTRY;
1453
1454         LASSERT(cred);
1455         LASSERT(!atomic_read(&cred->pc_refcount));
1456
1457         gcred = container_of(cred, struct gss_cred, gc_base);
1458         if (gcred->gc_ctx) {
1459                 destroy_gss_context(cred);
1460                 gss_put_ctx(gcred->gc_ctx);
1461         }
1462
1463         CDEBUG(D_SEC, "sec.gss %p: destroy cred %p\n", cred->pc_sec, gcred);
1464
1465         OBD_FREE(gcred, sizeof(*gcred));
1466         EXIT;
1467 }
1468
1469 static struct ptlrpc_credops gss_credops = {
1470         .refresh        = gss_cred_refresh,
1471         .match          = gss_cred_match,
1472         .sign           = gss_cred_sign,
1473         .verify         = gss_cred_verify,
1474         .seal           = gss_cred_seal,
1475         .unseal         = gss_cred_unseal,
1476         .destroy        = gss_cred_destroy,
1477 };
1478
1479 #ifdef __KERNEL__
1480 /*******************************************
1481  * rpc_pipe APIs                           *
1482  *******************************************/
1483 static ssize_t
1484 gss_pipe_upcall(struct file *filp, struct rpc_pipe_msg *msg,
1485                 char *dst, size_t buflen)
1486 {
1487         char *data = (char *)msg->data + msg->copied;
1488         ssize_t mlen = msg->len;
1489         ssize_t left;
1490         ENTRY;
1491
1492         if (mlen > buflen)
1493                 mlen = buflen;
1494         left = copy_to_user(dst, data, mlen);
1495         if (left < 0) {
1496                 msg->errno = left;
1497                 RETURN(left);
1498         }
1499         mlen -= left;
1500         msg->copied += mlen;
1501         msg->errno = 0;
1502         RETURN(mlen);
1503 }
1504
1505 static ssize_t
1506 gss_pipe_downcall(struct file *filp, const char *src, size_t mlen)
1507 {
1508         char *buf;
1509         const int bufsize = 1024;
1510         rawobj_t obj;
1511         struct inode *inode = filp->f_dentry->d_inode;
1512         struct rpc_inode *rpci = RPC_I(inode);
1513         struct obd_import *import;
1514         struct ptlrpc_sec *sec;
1515         struct gss_sec *gsec;
1516         char *obdname;
1517         struct gss_api_mech *mech;
1518         struct vfs_cred vcred = { 0 };
1519         struct ptlrpc_cred *cred;
1520         struct gss_upcall_msg *gss_msg;
1521         struct gss_upcall_msg_data gmd = { 0 };
1522         struct gss_cl_ctx *ctx = NULL;
1523         ssize_t left;
1524         int err, gss_err;
1525         ENTRY;
1526
1527         if (mlen > bufsize) {
1528                 CERROR("mlen %ld > bufsize %d\n", (long)mlen, bufsize);
1529                 RETURN(-ENOSPC);
1530         }
1531
1532         OBD_ALLOC(buf, bufsize);
1533         if (!buf) {
1534                 CERROR("alloc mem failed\n");
1535                 RETURN(-ENOMEM);
1536         }
1537
1538         left = copy_from_user(buf, src, mlen);
1539         if (left)
1540                 GOTO(err_free, err = -EFAULT);
1541
1542         obj.data = (unsigned char *)buf;
1543         obj.len = mlen;
1544
1545         LASSERT(rpci->private);
1546         gsec = (struct gss_sec *)rpci->private;
1547         sec = &gsec->gs_base;
1548         LASSERT(sec->ps_import);
1549         import = class_import_get(sec->ps_import);
1550         LASSERT(import->imp_obd);
1551         obdname = import->imp_obd->obd_name;
1552         mech = gsec->gs_mech;
1553
1554         err = gss_parse_init_downcall(mech, &obj, &ctx, &gmd, &gss_err);
1555         if (err)
1556                 CERROR("parse init downcall err %d\n", err);
1557
1558         vcred.vc_pag = gmd.gum_pag;
1559         vcred.vc_uid = gmd.gum_uid;
1560
1561         cred = ptlrpcs_cred_lookup(sec, &vcred);
1562         if (!cred) {
1563                 CWARN("didn't find cred for uid %u\n", vcred.vc_uid);
1564                 GOTO(err, err = -EINVAL);
1565         }
1566
1567         if (err || gss_err) {
1568                 set_bit(PTLRPC_CRED_DEAD_BIT, &cred->pc_flags);
1569                 if (err != -ERESTART || gss_err != 0)
1570                         set_bit(PTLRPC_CRED_ERROR_BIT, &cred->pc_flags);
1571                 CERROR("cred %p: rpc err %d, gss err 0x%x, fatal %d\n",
1572                        cred, err, gss_err,
1573                        (test_bit(PTLRPC_CRED_ERROR_BIT, &cred->pc_flags) != 0));
1574         } else {
1575                 CDEBUG(D_SEC, "get initial ctx:\n");
1576                 gss_cred_set_ctx(cred, ctx);
1577         }
1578
1579         spin_lock(&gsec->gs_lock);
1580         gss_msg = gss_find_upcall(gsec, obdname, &gmd);
1581         if (gss_msg) {
1582                 gss_unhash_msg_nolock(gss_msg);
1583                 spin_unlock(&gsec->gs_lock);
1584                 gss_release_msg(gss_msg);
1585         } else
1586                 spin_unlock(&gsec->gs_lock);
1587
1588         ptlrpcs_cred_put(cred, 1);
1589         class_import_put(import);
1590         OBD_FREE(buf, bufsize);
1591         RETURN(mlen);
1592 err:
1593         if (ctx)
1594                 gss_destroy_ctx(ctx);
1595         class_import_put(import);
1596 err_free:
1597         OBD_FREE(buf, bufsize);
1598         CDEBUG(D_SEC, "gss_pipe_downcall returning %d\n", err);
1599         RETURN(err);
1600 }
1601
1602 static
1603 void gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
1604 {
1605         struct gss_upcall_msg *gmsg;
1606         static unsigned long ratelimit;
1607         ENTRY;
1608
1609         LASSERT(list_empty(&msg->list));
1610
1611         if (msg->errno >= 0) {
1612                 EXIT;
1613                 return;
1614         }
1615
1616         gmsg = container_of(msg, struct gss_upcall_msg, gum_base);
1617         CDEBUG(D_SEC, "destroy gmsg %p\n", gmsg);
1618         LASSERT(atomic_read(&gmsg->gum_refcount) > 0);
1619         atomic_inc(&gmsg->gum_refcount);
1620         gss_unhash_msg(gmsg);
1621         if (msg->errno == -ETIMEDOUT || msg->errno == -EPIPE) {
1622                 unsigned long now = get_seconds();
1623                 if (time_after(now, ratelimit)) {
1624                         CWARN("sec.gss upcall timed out.\n"
1625                               "Please check user daemon is running!\n");
1626                         ratelimit = now + 15;
1627                 }
1628         }
1629         gss_release_msg(gmsg);
1630         EXIT;
1631 }
1632
1633 static
1634 void gss_pipe_release(struct inode *inode)
1635 {
1636         struct rpc_inode *rpci = RPC_I(inode);
1637         struct ptlrpc_sec *sec;
1638         struct gss_sec *gsec;
1639         ENTRY;
1640
1641         gsec = (struct gss_sec *)rpci->private;
1642         sec = &gsec->gs_base;
1643         spin_lock(&gsec->gs_lock);
1644         while (!list_empty(&gsec->gs_upcalls)) {
1645                 struct gss_upcall_msg *gmsg;
1646
1647                 gmsg = list_entry(gsec->gs_upcalls.next,
1648                                   struct gss_upcall_msg, gum_list);
1649                 LASSERT(list_empty(&gmsg->gum_base.list));
1650                 gmsg->gum_base.errno = -EPIPE;
1651                 atomic_inc(&gmsg->gum_refcount);
1652                 gss_unhash_msg_nolock(gmsg);
1653                 gss_release_msg(gmsg);
1654         }
1655         spin_unlock(&gsec->gs_lock);
1656         EXIT;
1657 }
1658
1659 static struct rpc_pipe_ops gss_upcall_ops = {
1660         .upcall         = gss_pipe_upcall,
1661         .downcall       = gss_pipe_downcall,
1662         .destroy_msg    = gss_pipe_destroy_msg,
1663         .release_pipe   = gss_pipe_release,
1664 };
1665 #endif /* __KERNEL__ */
1666
1667 /*********************************************
1668  * GSS security APIs                         *
1669  *********************************************/
1670
1671 static
1672 struct ptlrpc_sec* gss_create_sec(__u32 flavor,
1673                                   const char *pipe_dir,
1674                                   void *pipe_data)
1675 {
1676         struct gss_sec *gsec;
1677         struct ptlrpc_sec *sec;
1678 #ifdef __KERNEL__
1679         char *pos;
1680         int   pipepath_len;
1681         uid_t save_uid;
1682 #endif
1683         ENTRY;
1684
1685         LASSERT(SEC_FLAVOR_MAJOR(flavor) == PTLRPCS_FLVR_MAJOR_GSS);
1686
1687         OBD_ALLOC(gsec, sizeof(*gsec));
1688         if (!gsec) {
1689                 CERROR("can't alloc gsec\n");
1690                 RETURN(NULL);
1691         }
1692
1693         gsec->gs_mech = kgss_subflavor_to_mech(SEC_FLAVOR_SUB(flavor));
1694         if (!gsec->gs_mech) {
1695                 CERROR("subflavor 0x%x not found\n", flavor);
1696                 goto err_free;
1697         }
1698
1699         /* initialize gss sec */
1700 #ifdef __KERNEL__
1701         INIT_LIST_HEAD(&gsec->gs_upcalls);
1702         spin_lock_init(&gsec->gs_lock);
1703
1704         pipepath_len = strlen(LUSTRE_PIPEDIR) + strlen(pipe_dir) +
1705                        strlen(gsec->gs_mech->gm_name) + 3;
1706         OBD_ALLOC(gsec->gs_pipepath, pipepath_len);
1707         if (!gsec->gs_pipepath)
1708                 goto err_mech_put;
1709
1710         /* pipe rpc require root permission */
1711         save_uid = current->fsuid;
1712         current->fsuid = 0;
1713
1714         sprintf(gsec->gs_pipepath, LUSTRE_PIPEDIR"/%s", pipe_dir);
1715         if (IS_ERR(rpc_mkdir(gsec->gs_pipepath, NULL))) {
1716                 CERROR("can't make pipedir %s\n", gsec->gs_pipepath);
1717                 goto err_free_path;
1718         }
1719
1720         sprintf(gsec->gs_pipepath, LUSTRE_PIPEDIR"/%s/%s", pipe_dir,
1721                 gsec->gs_mech->gm_name); 
1722         gsec->gs_depipe = rpc_mkpipe(gsec->gs_pipepath, gsec,
1723                                      &gss_upcall_ops, RPC_PIPE_WAIT_FOR_OPEN);
1724         if (IS_ERR(gsec->gs_depipe)) {
1725                 CERROR("failed to make rpc_pipe %s: %ld\n",
1726                         gsec->gs_pipepath, PTR_ERR(gsec->gs_depipe));
1727                 goto err_rmdir;
1728         }
1729         CDEBUG(D_SEC, "gss sec %p, pipe path %s\n", gsec, gsec->gs_pipepath);
1730 #endif
1731
1732         sec = &gsec->gs_base;
1733         sec->ps_expire = GSS_CREDCACHE_EXPIRE;
1734         sec->ps_nextgc = get_seconds() + sec->ps_expire;
1735         sec->ps_flags = 0;
1736
1737         current->fsuid = save_uid;
1738
1739         CDEBUG(D_SEC, "Create sec.gss %p\n", gsec);
1740         RETURN(sec);
1741
1742 #ifdef __KERNEL__
1743 err_rmdir:
1744         pos = strrchr(gsec->gs_pipepath, '/');
1745         LASSERT(pos);
1746         *pos = 0;
1747         rpc_rmdir(gsec->gs_pipepath);
1748 err_free_path:
1749         current->fsuid = save_uid;
1750         OBD_FREE(gsec->gs_pipepath, pipepath_len);
1751 err_mech_put:
1752 #endif
1753         kgss_mech_put(gsec->gs_mech);
1754 err_free:
1755         OBD_FREE(gsec, sizeof(*gsec));
1756         RETURN(NULL);
1757 }
1758
1759 static
1760 void gss_destroy_sec(struct ptlrpc_sec *sec)
1761 {
1762         struct gss_sec *gsec;
1763 #ifdef __KERNEL__
1764         char *pos;
1765         int   pipepath_len;
1766 #endif
1767         ENTRY;
1768
1769         gsec = container_of(sec, struct gss_sec, gs_base);
1770         CDEBUG(D_SEC, "Destroy sec.gss %p\n", gsec);
1771
1772         LASSERT(gsec->gs_mech);
1773         LASSERT(!atomic_read(&sec->ps_refcount));
1774         LASSERT(!atomic_read(&sec->ps_credcount));
1775 #ifdef __KERNEL__
1776         pipepath_len = strlen(gsec->gs_pipepath) + 1;
1777         rpc_unlink(gsec->gs_pipepath);
1778         pos = strrchr(gsec->gs_pipepath, '/');
1779         LASSERT(pos);
1780         *pos = 0;
1781         rpc_rmdir(gsec->gs_pipepath);
1782         OBD_FREE(gsec->gs_pipepath, pipepath_len);
1783 #endif
1784
1785         kgss_mech_put(gsec->gs_mech);
1786         OBD_FREE(gsec, sizeof(*gsec));
1787         EXIT;
1788 }
1789
1790 static
1791 struct ptlrpc_cred * gss_create_cred(struct ptlrpc_sec *sec,
1792                                      struct vfs_cred *vcred)
1793 {
1794         struct gss_cred *gcred;
1795         struct ptlrpc_cred *cred;
1796         ENTRY;
1797
1798         OBD_ALLOC(gcred, sizeof(*gcred));
1799         if (!gcred)
1800                 RETURN(NULL);
1801
1802         cred = &gcred->gc_base;
1803         INIT_LIST_HEAD(&cred->pc_hash);
1804         atomic_set(&cred->pc_refcount, 0);
1805         cred->pc_sec = sec;
1806         cred->pc_ops = &gss_credops;
1807         cred->pc_expire = 0;
1808         cred->pc_flags = 0;
1809         cred->pc_pag = vcred->vc_pag;
1810         cred->pc_uid = vcred->vc_uid;
1811         CDEBUG(D_SEC, "create a gss cred at %p("LPU64"/%u)\n",
1812                cred, vcred->vc_pag, vcred->vc_uid);
1813
1814         RETURN(cred);
1815 }
1816
1817 static int gss_estimate_payload(struct ptlrpc_sec *sec,
1818                                 struct ptlrpc_request *req,
1819                                 int msgsize)
1820 {
1821         switch (SEC_FLAVOR_SVC(req->rq_req_secflvr)) {
1822         case PTLRPCS_SVC_AUTH:
1823                 return GSS_MAX_AUTH_PAYLOAD;
1824         case PTLRPCS_SVC_PRIV:
1825                 return size_round16(GSS_MAX_AUTH_PAYLOAD + msgsize +
1826                                     GSS_PRIVBUF_PREFIX_LEN +
1827                                     GSS_PRIVBUF_SUFFIX_LEN);
1828         default:
1829                 LBUG();
1830                 return 0;
1831         }
1832 }
1833
1834 static int gss_alloc_reqbuf(struct ptlrpc_sec *sec,
1835                             struct ptlrpc_request *req,
1836                             int lmsg_size)
1837 {
1838         int msg_payload, sec_payload;
1839         int privacy, rc;
1840         ENTRY;
1841
1842         /* In PRIVACY mode, lustre message is always 0 (already encoded into
1843          * security payload).
1844          */
1845         privacy = (SEC_FLAVOR_SVC(req->rq_req_secflvr) == PTLRPCS_SVC_PRIV);
1846         msg_payload = privacy ? 0 : lmsg_size;
1847         sec_payload = gss_estimate_payload(sec, req, lmsg_size);
1848
1849         rc = sec_alloc_reqbuf(sec, req, msg_payload, sec_payload);
1850         if (rc)
1851                 return rc;
1852
1853         if (privacy) {
1854                 int buflen = lmsg_size + GSS_PRIVBUF_PREFIX_LEN +
1855                              GSS_PRIVBUF_SUFFIX_LEN;
1856                 char *buf;
1857
1858                 OBD_ALLOC(buf, buflen);
1859                 if (!buf) {
1860                         CERROR("Fail to alloc %d\n", buflen);
1861                         sec_free_reqbuf(sec, req);
1862                         RETURN(-ENOMEM);
1863                 }
1864                 req->rq_reqmsg = (struct lustre_msg *)
1865                                         (buf + GSS_PRIVBUF_PREFIX_LEN);
1866         }
1867
1868         RETURN(0);
1869 }
1870
1871 static void gss_free_reqbuf(struct ptlrpc_sec *sec,
1872                             struct ptlrpc_request *req)
1873 {
1874         char *buf;
1875         int privacy;
1876         ENTRY;
1877
1878         LASSERT(req->rq_reqmsg);
1879         LASSERT(req->rq_reqlen);
1880
1881         privacy = SEC_FLAVOR_SVC(req->rq_req_secflvr) == PTLRPCS_SVC_PRIV;
1882         if (privacy) {
1883                 buf = (char *) req->rq_reqmsg - GSS_PRIVBUF_PREFIX_LEN;
1884                 LASSERT(buf < req->rq_reqbuf ||
1885                         buf >= req->rq_reqbuf + req->rq_reqbuf_len);
1886                 OBD_FREE(buf, req->rq_reqlen + GSS_PRIVBUF_PREFIX_LEN +
1887                               GSS_PRIVBUF_SUFFIX_LEN);
1888                 req->rq_reqmsg = NULL;
1889         }
1890
1891         sec_free_reqbuf(sec, req);
1892 }
1893
1894 static struct ptlrpc_secops gss_secops = {
1895         .create_sec             = gss_create_sec,
1896         .destroy_sec            = gss_destroy_sec,
1897         .create_cred            = gss_create_cred,
1898         .est_req_payload        = gss_estimate_payload,
1899         .est_rep_payload        = gss_estimate_payload,
1900         .alloc_reqbuf           = gss_alloc_reqbuf,
1901         .free_reqbuf            = gss_free_reqbuf,
1902 };
1903
1904 static struct ptlrpc_sec_type gss_type = {
1905         .pst_owner      = THIS_MODULE,
1906         .pst_name       = "sec.gss",
1907         .pst_inst       = ATOMIC_INIT(0),
1908         .pst_flavor     = PTLRPCS_FLVR_MAJOR_GSS,
1909         .pst_ops        = &gss_secops,
1910 };
1911
1912 extern int
1913 (*lustre_secinit_downcall_handler)(char *buffer, unsigned long count);
1914
1915 int __init ptlrpcs_gss_init(void)
1916 {
1917         int rc;
1918
1919         rc = ptlrpcs_register(&gss_type);
1920         if (rc)
1921                 return rc;
1922
1923 #ifdef __KERNEL__
1924         gss_svc_init();
1925
1926         rc = PTR_ERR(rpc_mkdir(LUSTRE_PIPEDIR, NULL));
1927         if (IS_ERR((void *)rc) && rc != -EEXIST) {
1928                 CERROR("fail to make rpcpipedir for lustre\n");
1929                 gss_svc_exit();
1930                 ptlrpcs_unregister(&gss_type);
1931                 return -1;
1932         }
1933         rc = 0;
1934 #else
1935 #endif
1936         rc = init_kerberos_module();
1937         if (rc) {
1938                 ptlrpcs_unregister(&gss_type);
1939         }
1940
1941         lustre_secinit_downcall_handler = gss_send_secinit_rpc;
1942
1943         return rc;
1944 }
1945
#ifdef __KERNEL__
/*
 * Module exit: undo ptlrpcs_gss_init() in reverse order.
 * The secinit downcall handler is cleared first, presumably so no new
 * secinit downcall reaches this module during teardown -- confirm.
 */
static void __exit ptlrpcs_gss_exit(void)
{
        lustre_secinit_downcall_handler = NULL;

        cleanup_kerberos_module();
        rpc_rmdir(LUSTRE_PIPEDIR);
        gss_svc_exit();
        ptlrpcs_unregister(&gss_type);
}
#endif
1957
/* Kernel module metadata and entry points. */
MODULE_AUTHOR("Cluster File Systems, Inc. <info@clusterfs.com>");
MODULE_DESCRIPTION("GSS Security module for Lustre");
MODULE_LICENSE("GPL");

/* NOTE(review): ptlrpcs_gss_exit is only defined under __KERNEL__; this
 * presumably relies on module_exit() expanding to nothing in non-kernel
 * builds -- confirm. */
module_init(ptlrpcs_gss_init);
module_exit(ptlrpcs_gss_exit);