Whamcloud - gitweb
current branches now use lnet from HEAD
[fs/lustre-release.git] / lustre / sec / gss / sec_gss.c
1 /* -*- mode: c; c-basic-offset: 8; indent-tabs-mode: nil; -*-
2  * vim:expandtab:shiftwidth=8:tabstop=8:
3  *
4  * Modifications for Lustre
5  * Copyright 2004, Cluster File Systems, Inc.
6  * All rights reserved
7  * Author: Eric Mei <ericm@clusterfs.com>
8  */
9
10 /*
11  * linux/net/sunrpc/auth_gss.c
12  *
13  * RPCSEC_GSS client authentication.
14  *
15  *  Copyright (c) 2000 The Regents of the University of Michigan.
16  *  All rights reserved.
17  *
18  *  Dug Song       <dugsong@monkey.org>
19  *  Andy Adamson   <andros@umich.edu>
20  *
21  *  Redistribution and use in source and binary forms, with or without
22  *  modification, are permitted provided that the following conditions
23  *  are met:
24  *
25  *  1. Redistributions of source code must retain the above copyright
26  *     notice, this list of conditions and the following disclaimer.
27  *  2. Redistributions in binary form must reproduce the above copyright
28  *     notice, this list of conditions and the following disclaimer in the
29  *     documentation and/or other materials provided with the distribution.
30  *  3. Neither the name of the University nor the names of its
31  *     contributors may be used to endorse or promote products derived
32  *     from this software without specific prior written permission.
33  *
34  *  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
35  *  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
36  *  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37  *  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
38  *  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
39  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
40  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
41  *  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
42  *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
43  *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
44  *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
45  *
46  */
47
48 #ifndef EXPORT_SYMTAB
49 # define EXPORT_SYMTAB
50 #endif
51 #define DEBUG_SUBSYSTEM S_SEC
52 #ifdef __KERNEL__
53 #include <linux/init.h>
54 #include <linux/module.h>
55 #include <linux/slab.h>
56 #include <linux/dcache.h>
57 #include <linux/fs.h>
58 #include <linux/random.h>
59 /* for rpc_pipefs */
60 struct rpc_clnt;
61 #include <linux/sunrpc/rpc_pipe_fs.h>
62 #else
63 #include <liblustre.h>
64 #endif
65
66 #include <libcfs/kp30.h>
67 #include <linux/obd.h>
68 #include <linux/obd_class.h>
69 #include <linux/obd_support.h>
70 #include <linux/lustre_idl.h>
71 #include <linux/lustre_net.h>
72 #include <linux/lustre_import.h>
73 #include <linux/lustre_sec.h>
74
75 #include "gss_err.h"
76 #include "gss_internal.h"
77 #include "gss_api.h"
78
79 #define LUSTRE_PIPEDIR          "/lustre"
80
81 #define GSS_CREDCACHE_EXPIRE    (30 * 60)          /* 30 minute */
82
83 /**********************************************
84  * gss security init/fini helper              *
85  **********************************************/
86
87 static int secinit_compose_request(struct obd_import *imp,
88                                    char *buf, int bufsize,
89                                    int lustre_srv,
90                                    uid_t uid, gid_t gid,
91                                    long token_size,
92                                    char __user *token)
93 {
94         struct ptlrpcs_wire_hdr *hdr;
95         struct lustre_msg       *lmsg;
96         struct mds_req_sec_desc *secdesc;
97         int                      size = sizeof(*secdesc);
98         __u32                    lmsg_size, *p;
99         int                      rc;
100
101         lmsg_size = lustre_msg_size(1, &size);
102
103         if (sizeof(*hdr) + lmsg_size + size_round(token_size) > bufsize) {
104                 CERROR("token size %ld too large\n", token_size);
105                 return -EINVAL;
106         }
107
108         /* security wire hdr */
109         hdr = buf_to_sec_hdr(buf);
110         hdr->flavor  = cpu_to_le32(PTLRPCS_FLVR_GSS_NONE);
111         hdr->msg_len = cpu_to_le32(lmsg_size);
112         hdr->sec_len = cpu_to_le32(8 * 4 + token_size);
113
114         /* lustre message & secdesc */
115         lmsg = buf_to_lustre_msg(buf);
116
117         lustre_init_msg(lmsg, 1, &size, NULL);
118         secdesc = lustre_msg_buf(lmsg, 0, size);
119         secdesc->rsd_uid = secdesc->rsd_fsuid = uid;
120         secdesc->rsd_gid = secdesc->rsd_fsgid = gid;
121         secdesc->rsd_cap = secdesc->rsd_ngroups = 0;
122
123         lmsg->handle   = imp->imp_remote_handle;
124         lmsg->type     = PTL_RPC_MSG_REQUEST;
125         lmsg->opc      = SEC_INIT;
126         lmsg->flags    = 0;
127         lmsg->conn_cnt = imp->imp_conn_cnt;
128
129         p = (__u32 *) (buf + sizeof(*hdr) + lmsg_size);
130
131         /* gss hdr */
132         *p++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);     /* gss version */
133         *p++ = cpu_to_le32(PTLRPCS_FLVR_KRB5I);         /* subflavor */
134         *p++ = cpu_to_le32(PTLRPCS_GSS_PROC_INIT);      /* proc */
135         *p++ = cpu_to_le32(0);                          /* seq */
136         *p++ = cpu_to_le32(PTLRPCS_GSS_SVC_NONE);       /* service */
137         *p++ = cpu_to_le32(0);                          /* context handle */
138
139         /* plus lustre svc type */
140         *p++ = cpu_to_le32(lustre_srv);
141
142         /* now the token part */
143         *p++ = cpu_to_le32((__u32) token_size);
144         LASSERT(((char *)p - buf) + token_size <= bufsize);
145
146         rc = copy_from_user(p, token, token_size);
147         if (rc) {
148                 CERROR("can't copy token\n");
149                 return -EFAULT;
150         }
151
152         rc = size_round(((char *)p - buf) + token_size);
153         return rc;
154 }
155
156 static int secinit_parse_reply(char *repbuf, int replen,
157                                char __user *outbuf, long outlen)
158 {
159         __u32                   *p = (__u32 *)repbuf;
160         struct ptlrpcs_wire_hdr *hdr = (struct ptlrpcs_wire_hdr *) repbuf;
161         __u32                    lmsg_len, sec_len, status;
162         __u32                    major, minor, seq, obj_len, round_len;
163         __u32                    effective = 0;
164
165         if (replen <= (4 + 6) * 4) {
166                 CERROR("reply size %d too small\n", replen);
167                 return -EINVAL;
168         }
169
170         hdr->flavor = le32_to_cpu(hdr->flavor);
171         hdr->msg_len = le32_to_cpu(hdr->msg_len);
172         hdr->sec_len = le32_to_cpu(hdr->sec_len);
173
174         lmsg_len = le32_to_cpu(p[2]);
175         sec_len = le32_to_cpu(p[3]);
176
177         /* sanity checks */
178         if (hdr->flavor != PTLRPCS_FLVR_GSS_NONE) {
179                 CERROR("unexpected reply\n");
180                 return -EINVAL;
181         }
182         if (hdr->msg_len % 8 ||
183             sizeof(*hdr) + hdr->msg_len + hdr->sec_len > replen) {
184                 CERROR("unexpected reply\n");
185                 return -EINVAL;
186         }
187         if (hdr->sec_len > outlen) {
188                 CERROR("outbuf too small\n");
189                 return -EINVAL;
190         }
191
192         p = (__u32 *) buf_to_sec_data(repbuf);
193         effective = 0;
194
195         p += 2; /* skip the leading unused bytes */
196         seq = le32_to_cpu(*p++);
197         major = le32_to_cpu(*p++);
198         minor = le32_to_cpu(*p++);
199         status = 0;
200
201         effective += 4 * 4;
202
203         if (copy_to_user(outbuf, &status, 4))
204                 return -EFAULT;
205         outbuf += 4;
206         if (copy_to_user(outbuf, &major, 4))
207                 return -EFAULT;
208         outbuf += 4;
209         if (copy_to_user(outbuf, &minor, 4))
210                 return -EFAULT;
211         outbuf += 4;
212         if (copy_to_user(outbuf, &seq, 4))
213                 return -EFAULT;
214         outbuf += 4;
215
216         obj_len = le32_to_cpu(*p++);
217         round_len = (obj_len + 3) & ~ 3;
218         if (copy_to_user(outbuf, &obj_len, 4))
219                 return -EFAULT;
220         outbuf += 4;
221         if (copy_to_user(outbuf, (char *)p, round_len))
222                 return -EFAULT;
223         p += round_len / 4;
224         outbuf += round_len;
225         effective += 4 + round_len;
226
227         obj_len = le32_to_cpu(*p++);
228         round_len = (obj_len + 3) & ~ 3;
229         if (copy_to_user(outbuf, &obj_len, 4))
230                 return -EFAULT;
231         outbuf += 4;
232         if (copy_to_user(outbuf, (char *)p, round_len))
233                 return -EFAULT;
234         p += round_len / 4;
235         outbuf += round_len;
236         effective += 4 + round_len;
237
238         return effective;
239 }
240
241 /* XXX move to where lgssd could see */
/*
 * Parameter block exchanged with lgssd by gss_send_secinit_rpc().  It is
 * copied in from userspace, and copied back out with the two "out" fields
 * filled in.  The layout is shared with userspace and must not change.
 */
struct lgssd_ioctl_param {
        /* in: interface version, checked against GSSD_INTERFACE_VERSION */
        int     version;
        /* in: user pointer to the name used to look up the obd device */
        char   *uuid;
        /* in: lustre service type placed into the request */
        int     lustre_svc;
        /* in: uid stored in the request's security descriptor */
        uid_t   uid;
        /* in: gid stored in the request's security descriptor */
        gid_t   gid;
        /* in: length in bytes of send_token */
        long    send_token_size;
        /* in: user pointer to the token to send */
        char   *send_token;
        /* in: capacity in bytes of reply_buf */
        long    reply_buf_size;
        /* in: user pointer to the buffer that receives the parsed reply */
        char   *reply_buf;
        /* out: 0 on success, otherwise a negative error code */
        long    status;
        /* out: number of bytes stored into reply_buf */
        long    reply_length;
};
255
256 static int gss_send_secinit_rpc(__user char *buffer, unsigned long count)
257 {
258         struct obd_import        *imp;
259         struct ptlrpc_request    *request = NULL;
260         struct lgssd_ioctl_param  param;
261         const int                 reqbuf_size = 1024;
262         const int                 repbuf_size = 1024;
263         char                     *reqbuf, *repbuf;
264         struct obd_device        *obd;
265         char                      obdname[64];
266         long                      lsize;
267         int                       rc, reqlen, replen;
268
269         if (count != sizeof(param)) {
270                 CERROR("ioctl size %lu, expect %d, please check lgssd version\n",
271                         count, sizeof(param));
272                 RETURN(-EINVAL);
273         }
274         if (copy_from_user(&param, buffer, sizeof(param))) {
275                 CERROR("failed copy data from lgssd\n");
276                 RETURN(-EFAULT);
277         }
278
279         if (param.version != GSSD_INTERFACE_VERSION) {
280                 CERROR("gssd interface version %d (expect %d)\n",
281                         param.version, GSSD_INTERFACE_VERSION);
282                 RETURN(-EINVAL);
283         }
284
285         /* take name */
286         if (strncpy_from_user(obdname, param.uuid,
287                               sizeof(obdname)) <= 0) {
288                 CERROR("Invalid obdname pointer\n");
289                 RETURN(-EFAULT);
290         }
291
292         obd = class_name2obd(obdname);
293         if (!obd) {
294                 CERROR("no such obd %s\n", obdname);
295                 RETURN(-EINVAL);
296         }
297
298         imp = class_import_get(obd->u.cli.cl_import);
299
300         OBD_ALLOC(reqbuf, reqbuf_size);
301         OBD_ALLOC(repbuf, reqbuf_size);
302
303         if (!reqbuf || !repbuf) {
304                 CERROR("Can't alloc buffer: %p/%p\n", reqbuf, repbuf);
305                 param.status = -ENOMEM;
306                 goto out_copy;
307         }
308
309         /* get token */
310         reqlen = secinit_compose_request(imp, reqbuf, reqbuf_size,
311                                          param.lustre_svc,
312                                          param.uid, param.gid,
313                                          param.send_token_size,
314                                          param.send_token);
315         if (reqlen < 0) {
316                 param.status = reqlen;
317                 goto out_copy;
318         }
319
320         request = ptl_do_rawrpc(imp, reqbuf, reqbuf_size, reqlen,
321                                 repbuf, repbuf_size, &replen,
322                                 SECINIT_RPC_TIMEOUT, &rc);
323         if (request == NULL || rc) {
324                 param.status = rc;
325                 goto out_copy;
326         }
327
328         if (replen > param.reply_buf_size) {
329                 CERROR("output buffer size %ld too small, need %d\n",
330                         param.reply_buf_size, replen);
331                 param.status = -EINVAL;
332                 goto out_copy;
333         }
334
335         lsize = secinit_parse_reply(repbuf, replen,
336                                     param.reply_buf, param.reply_buf_size);
337         if (lsize < 0) {
338                 param.status = (int) lsize;
339                 goto out_copy;
340         }
341
342         param.status = 0;
343         param.reply_length = lsize;
344
345 out_copy:
346         if (copy_to_user(buffer, &param, sizeof(param)))
347                 rc = -EFAULT;
348         else
349                 rc = 0;
350
351         class_import_put(imp);
352         if (request == NULL) {
353                 if (repbuf)
354                         OBD_FREE(repbuf, repbuf_size);
355                 if (reqbuf)
356                         OBD_FREE(reqbuf, reqbuf_size);
357         } else {
358                 rawrpc_req_finished(request);
359         }
360         RETURN(rc);
361 }
362
363 /**********************************************
364  * structure definitions                      *
365  **********************************************/
366 struct gss_sec {
367         struct ptlrpc_sec       gs_base;
368         struct gss_api_mech    *gs_mech;
369         spinlock_t              gs_lock;
370         struct list_head        gs_upcalls;
371         char                   *gs_pipepath;
372         struct dentry          *gs_depipe;
373 };
374
375 struct gss_upcall_msg_data {
376         __u64                           gum_pag;
377         __u32                           gum_uid;
378         __u32                           gum_svc;
379         __u32                           gum_nal;
380         __u32                           gum_netid;
381         __u64                           gum_nid;
382 };
383
384 struct gss_upcall_msg {
385         struct rpc_pipe_msg             gum_base;
386         atomic_t                        gum_refcount;
387         struct list_head                gum_list;
388         struct gss_sec                 *gum_gsec;
389         wait_queue_head_t               gum_waitq;
390         char                            gum_obdname[64];
391         struct gss_upcall_msg_data      gum_data;
392 };
393
394 #ifdef __KERNEL__
/* Taken for read when a cred's gc_ctx pointer is examined or referenced
 * (gss_cred_get_ctx) and for write when it is replaced (gss_cred_set_ctx). */
static rwlock_t gss_ctx_lock = RW_LOCK_UNLOCKED;
396 /**********************************************
397  * rpc_pipe upcall helpers                    *
398  **********************************************/
399 static
400 void gss_release_msg(struct gss_upcall_msg *gmsg)
401 {
402         ENTRY;
403         LASSERT(atomic_read(&gmsg->gum_refcount) > 0);
404
405         if (!atomic_dec_and_test(&gmsg->gum_refcount)) {
406                 CDEBUG(D_SEC, "gmsg %p ref %d\n", gmsg,
407                        atomic_read(&gmsg->gum_refcount));
408                 EXIT;
409                 return;
410         }
411         LASSERT(list_empty(&gmsg->gum_list));
412 #if 0
413         LASSERT(list_empty(&gmsg->gum_base.list));
414 #else
415         /* XXX */
416         if (!list_empty(&gmsg->gum_base.list)) {
417                 int error = gmsg->gum_base.errno;
418                 
419                 CWARN("msg %p: list: %p/%p/%p, copied %d, err %d, wq %d\n",
420                       gmsg, &gmsg->gum_base.list, gmsg->gum_base.list.prev,
421                       gmsg->gum_base.list.next, gmsg->gum_base.copied, error,
422                       list_empty(&gmsg->gum_waitq.task_list));
423                 LBUG();
424         }
425 #endif
426         OBD_FREE(gmsg, sizeof(*gmsg));
427         EXIT;
428 }
429
430 static void
431 gss_unhash_msg_nolock(struct gss_upcall_msg *gmsg)
432 {
433         LASSERT_SPIN_LOCKED(&gmsg->gum_gsec->gs_lock);
434
435         if (list_empty(&gmsg->gum_list))
436                 return;
437
438         list_del_init(&gmsg->gum_list);
439         wake_up(&gmsg->gum_waitq);
440         LASSERT(atomic_read(&gmsg->gum_refcount) > 1);
441         atomic_dec(&gmsg->gum_refcount);
442 }
443
444 static void
445 gss_unhash_msg(struct gss_upcall_msg *gmsg)
446 {
447         struct gss_sec *gsec = gmsg->gum_gsec;
448
449         spin_lock(&gsec->gs_lock);
450         gss_unhash_msg_nolock(gmsg);
451         spin_unlock(&gsec->gs_lock);
452 }
453
454 static
455 struct gss_upcall_msg * gss_find_upcall(struct gss_sec *gsec,
456                                         char *obdname,
457                                         struct gss_upcall_msg_data *gmd)
458 {
459         struct gss_upcall_msg *gmsg;
460         ENTRY;
461
462         LASSERT_SPIN_LOCKED(&gsec->gs_lock);
463
464         list_for_each_entry(gmsg, &gsec->gs_upcalls, gum_list) {
465                 if (memcmp(&gmsg->gum_data, gmd, sizeof(*gmd)))
466                         continue;
467                 if (strcmp(gmsg->gum_obdname, obdname))
468                         continue;
469                 LASSERT(atomic_read(&gmsg->gum_refcount) > 0);
470                 atomic_inc(&gmsg->gum_refcount);
471                 CDEBUG(D_SEC, "found gmsg at %p: obdname %s, uid %d, ref %d\n",
472                        gmsg, obdname, gmd->gum_uid,
473                        atomic_read(&gmsg->gum_refcount));
474                 RETURN(gmsg);
475         }
476         RETURN(NULL);
477 }
478
479 static void gss_init_upcall_msg(struct gss_upcall_msg *gmsg,
480                                 struct gss_sec *gsec, char *obdname,
481                                 struct gss_upcall_msg_data *gmd)
482 {
483         struct rpc_pipe_msg *rpcmsg;
484         ENTRY;
485
486         /* 2 refs: 1 for hash, 1 for current user */
487         init_waitqueue_head(&gmsg->gum_waitq);
488         list_add(&gmsg->gum_list, &gsec->gs_upcalls);
489         atomic_set(&gmsg->gum_refcount, 2);
490         gmsg->gum_gsec = gsec;
491         strncpy(gmsg->gum_obdname, obdname, sizeof(gmsg->gum_obdname));
492         memcpy(&gmsg->gum_data, gmd, sizeof(*gmd));
493
494         rpcmsg = &gmsg->gum_base;
495         INIT_LIST_HEAD(&rpcmsg->list);
496         rpcmsg->data = &gmsg->gum_data;
497         rpcmsg->len = sizeof(gmsg->gum_data);
498         rpcmsg->copied = 0;
499         rpcmsg->errno = 0;
500         EXIT;
501 }
502 #endif /* __KERNEL__ */
503
504 /* this seems to be used only from userspace code */
505 #ifndef __KERNEL__
506 /********************************************
507  * gss cred manipulation helpers            *
508  ********************************************/
509 static
510 int gss_cred_is_uptodate_ctx(struct ptlrpc_cred *cred)
511 {
512         struct gss_cred *gcred = container_of(cred, struct gss_cred, gc_base);
513         int res = 0;
514
515         read_lock(&gss_ctx_lock);
516         if (((cred->pc_flags & PTLRPC_CRED_FLAGS_MASK) ==
517              PTLRPC_CRED_UPTODATE) &&
518             gcred->gc_ctx)
519                 res = 1;
520         read_unlock(&gss_ctx_lock);
521         return res;
522 }
523 #endif
524
525 static inline
526 struct gss_cl_ctx *gss_get_ctx(struct gss_cl_ctx *ctx)
527 {
528         atomic_inc(&ctx->gc_refcount);
529         return ctx;
530 }
531
532 static
533 void gss_destroy_ctx(struct gss_cl_ctx *ctx)
534 {
535         ENTRY;
536
537         CDEBUG(D_SEC, "destroy cl_ctx %p\n", ctx);
538         if (ctx->gc_gss_ctx)
539                 kgss_delete_sec_context(&ctx->gc_gss_ctx);
540
541         if (ctx->gc_wire_ctx.len > 0) {
542                 OBD_FREE(ctx->gc_wire_ctx.data, ctx->gc_wire_ctx.len);
543                 ctx->gc_wire_ctx.len = 0;
544         }
545
546         OBD_FREE(ctx, sizeof(*ctx));
547 }
548
549 static
550 void gss_put_ctx(struct gss_cl_ctx *ctx)
551 {
552         if (atomic_dec_and_test(&ctx->gc_refcount))
553                 gss_destroy_ctx(ctx);
554 }
555
556 static
557 struct gss_cl_ctx *gss_cred_get_ctx(struct ptlrpc_cred *cred)
558 {
559         struct gss_cred *gcred = container_of(cred, struct gss_cred, gc_base);
560         struct gss_cl_ctx *ctx = NULL;
561
562         read_lock(&gss_ctx_lock);
563         if (gcred->gc_ctx)
564                 ctx = gss_get_ctx(gcred->gc_ctx);
565         read_unlock(&gss_ctx_lock);
566         return ctx;
567 }
568
569 static
570 void gss_cred_set_ctx(struct ptlrpc_cred *cred, struct gss_cl_ctx *ctx)
571 {
572         struct gss_cred *gcred = container_of(cred, struct gss_cred, gc_base);
573         struct gss_cl_ctx *old;
574         __u64 ctx_expiry;
575         ENTRY;
576
577         if (kgss_inquire_context(ctx->gc_gss_ctx, &ctx_expiry)) {
578                 CERROR("unable to get expire time\n");
579                 ctx_expiry = 1; /* make it expired now */
580         }
581         cred->pc_expire = gss_roundup_expire_time(ctx_expiry);
582
583         write_lock(&gss_ctx_lock);
584         old = gcred->gc_ctx;
585         gcred->gc_ctx = ctx;
586         set_bit(PTLRPC_CRED_UPTODATE_BIT, &cred->pc_flags);
587         write_unlock(&gss_ctx_lock);
588         if (old)
589                 gss_put_ctx(old);
590
591         CDEBUG(D_SEC, "client refreshed gss cred %p(uid %u)\n",
592                cred, cred->pc_uid);
593         EXIT;
594 }
595
596 static int
597 simple_get_bytes(char **buf, __u32 *buflen, void *res, __u32 reslen)
598 {
599         if (*buflen < reslen) {
600                 CERROR("buflen %u < %u\n", *buflen, reslen);
601                 return -EINVAL;
602         }
603
604         memcpy(res, *buf, reslen);
605         *buf += reslen;
606         *buflen -= reslen;
607         return 0;
608 }
609
610 /* data passed down:
611  *  - uid
612  *  - timeout
613  *  - gc_win / error
614  *  - wire_ctx (rawobj)
615  *  - mech_ctx? (rawobj)
616  */
617 static
618 int gss_parse_init_downcall(struct gss_api_mech *gm, rawobj_t *buf,
619                             struct gss_cl_ctx **gc,
620                             struct gss_upcall_msg_data *gmd, int *gss_err)
621 {
622         char *p = (char *)buf->data;
623         struct gss_cl_ctx *ctx;
624         __u32 len = buf->len;
625         unsigned int timeout;
626         rawobj_t tmp_buf;
627         int err = -EPERM;
628         ENTRY;
629
630         *gc = NULL;
631         *gss_err = 0;
632
633         OBD_ALLOC(ctx, sizeof(*ctx));
634         if (!ctx)
635                 RETURN(-ENOMEM);
636
637         ctx->gc_proc = RPC_GSS_PROC_DATA;
638         ctx->gc_seq = 0;
639         spin_lock_init(&ctx->gc_seq_lock);
640         atomic_set(&ctx->gc_refcount,1);
641
642         if (simple_get_bytes(&p, &len, &gmd->gum_pag, sizeof(gmd->gum_pag)))
643                 goto err_free_ctx;
644         if (simple_get_bytes(&p, &len, &gmd->gum_uid, sizeof(gmd->gum_uid)))
645                 goto err_free_ctx;
646         if (simple_get_bytes(&p, &len, &gmd->gum_svc, sizeof(gmd->gum_svc)))
647                 goto err_free_ctx;
648         if (simple_get_bytes(&p, &len, &gmd->gum_nal, sizeof(gmd->gum_nal)))
649                 goto err_free_ctx;
650         if (simple_get_bytes(&p, &len, &gmd->gum_netid, sizeof(gmd->gum_netid)))
651                 goto err_free_ctx;
652         if (simple_get_bytes(&p, &len, &gmd->gum_nid, sizeof(gmd->gum_nid)))
653                 goto err_free_ctx;
654         /* FIXME: discarded timeout for now */
655         if (simple_get_bytes(&p, &len, &timeout, sizeof(timeout)))
656                 goto err_free_ctx;
657         if (simple_get_bytes(&p, &len, &ctx->gc_win, sizeof(ctx->gc_win)))
658                 goto err_free_ctx;
659
660         /* lgssd signals an error by passing ctx->gc_win = 0: */
661         if (!ctx->gc_win) {
662                 /* in which case the next 2 int are:
663                  * - rpc error
664                  * - gss error
665                  */
666                 if (simple_get_bytes(&p, &len, &err, sizeof(err))) {
667                         err = -EPERM;
668                         goto err_free_ctx;
669                 }
670                 if (simple_get_bytes(&p, &len, gss_err, sizeof(*gss_err))) {
671                         err = -EPERM;
672                         goto err_free_ctx;
673                 }
674                 if (err == 0 && *gss_err == 0) {
675                         CERROR("no error passed from downcall\n");
676                         err = -EPERM;
677                 }
678                 goto err_free_ctx;
679         }
680
681         if (rawobj_extract_local(&tmp_buf, (__u32 **) ((void *)&p), &len))
682                 goto err_free_ctx;
683         if (rawobj_dup(&ctx->gc_wire_ctx, &tmp_buf)) {
684                 err = -ENOMEM;
685                 goto err_free_ctx;
686         }
687         if (rawobj_extract_local(&tmp_buf, (__u32 **) ((void *)&p), &len))
688                 goto err_free_wire_ctx;
689         if (len) {
690                 CERROR("unexpected trailing %u bytes\n", len);
691                 goto err_free_wire_ctx;
692         }
693         if (kgss_import_sec_context(&tmp_buf, gm, &ctx->gc_gss_ctx))
694                 goto err_free_wire_ctx;
695
696         *gc = ctx;
697         RETURN(0);
698
699 err_free_wire_ctx:
700         if (ctx->gc_wire_ctx.data)
701                 OBD_FREE(ctx->gc_wire_ctx.data, ctx->gc_wire_ctx.len);
702 err_free_ctx:
703         OBD_FREE(ctx, sizeof(*ctx));
704         CDEBUG(D_SEC, "err_code %d, gss code %d\n", err, *gss_err);
705         return err;
706 }
707
708 /***************************************
709  * cred APIs                           *
710  ***************************************/
711 #ifdef __KERNEL__
/*
 * Kernel-side credential refresh: obtain a gss context for @cred by making
 * an upcall through the rpc pipe and waiting for it to be answered.
 *
 * If an identical upcall (same obd name and same upcall data) is already
 * queued, no new one is issued; this thread takes a reference on the
 * existing message and waits on it instead.
 *
 * Return values visible in this function:
 *   0           - cred flags were already set, the upcall message was
 *                 unhashed, or this thread only waited on somebody else's
 *                 upcall
 *   -EINVAL     - import has no connection, or the obd type is neither
 *                 MDC nor OSC
 *   -ENOMEM     - upcall message could not be allocated
 *   -EINTR      - own upcall interrupted by a signal
 *   -ETIMEDOUT  - own upcall not answered in time
 *   other       - error returned by rpc_queue_upcall()
 */
static int gss_cred_refresh(struct ptlrpc_cred *cred)
{
        struct obd_import          *import;
        struct gss_sec             *gsec;
        struct gss_upcall_msg      *gss_msg, *gss_new;
        struct gss_upcall_msg_data  gmd;
        struct dentry              *dentry;
        char                       *obdname, *obdtype;
        wait_queue_t                wait;
        int                         res;
        ENTRY;

        might_sleep();

        /* any flags means it has been handled, do nothing */
        if (cred->pc_flags & PTLRPC_CRED_FLAGS_MASK)
                RETURN(0);

        LASSERT(cred->pc_sec);
        LASSERT(cred->pc_sec->ps_import);
        LASSERT(cred->pc_sec->ps_import->imp_obd);

        import = cred->pc_sec->ps_import;
        if (!import->imp_connection) {
                CERROR("import has no connection set\n");
                RETURN(-EINVAL);
        }

        /* fill in every field of the upcall data: it is compared with
         * memcmp() in gss_find_upcall() */
        gmd.gum_pag = cred->pc_pag;
        gmd.gum_uid = cred->pc_uid;
        gmd.gum_nal = import->imp_connection->c_peer.peer_ni->pni_number;
        gmd.gum_netid = 0;
        gmd.gum_nid = import->imp_connection->c_peer.peer_id.nid;

        /* the target service is derived from the obd type name */
        obdtype = import->imp_obd->obd_type->typ_name;
        if (!strcmp(obdtype, OBD_MDC_DEVICENAME))
                gmd.gum_svc = LUSTRE_GSS_SVC_MDS;
        else if (!strcmp(obdtype, OBD_OSC_DEVICENAME))
                gmd.gum_svc = LUSTRE_GSS_SVC_OSS;
        else {
                CERROR("gss on %s?\n", obdtype);
                RETURN(-EINVAL);
        }

        gsec = container_of(cred->pc_sec, struct gss_sec, gs_base);
        obdname = import->imp_obd->obd_name;
        dentry = gsec->gs_depipe;
        gss_new = NULL;
        res = 0;

        CDEBUG(D_SEC, "Initiate gss context %p(%u@%s)\n",
               container_of(cred, struct gss_cred, gc_base),
               cred->pc_uid, import->imp_target_uuid.uuid);

again:
        spin_lock(&gsec->gs_lock);
        gss_msg = gss_find_upcall(gsec, obdname, &gmd);
        if (gss_msg) {
                /* an identical upcall is already queued: discard the
                 * message we may have allocated on a previous pass and
                 * wait on the existing one (gs_lock is still held) */
                if (gss_new) {
                        OBD_FREE(gss_new, sizeof(*gss_new));
                        gss_new = NULL;
                }
                GOTO(waiting, res);
        }

        if (!gss_new) {
                /* allocate with the lock dropped, then search again
                 * because the list may have changed in the meantime */
                spin_unlock(&gsec->gs_lock);
                OBD_ALLOC(gss_new, sizeof(*gss_new));
                if (!gss_new)
                        RETURN(-ENOMEM);
                goto again;
        }
        /* so far we've created gss_new */
        gss_init_upcall_msg(gss_new, gsec, obdname, &gmd);

        /* we've created the upcall msg, nobody else should touch the
         * flags of this cred, unless it is set as dead/expired by the
         * administrator via lctl etc.
         */
        if (cred->pc_flags & PTLRPC_CRED_FLAGS_MASK) {
                CWARN("cred %p("LPU64"/%u) was set flags %lx unexpectedly\n",
                      cred, cred->pc_pag, cred->pc_uid, cred->pc_flags);
                cred->pc_flags |= PTLRPC_CRED_DEAD | PTLRPC_CRED_ERROR;
                /* drop the list's reference, then our own */
                gss_unhash_msg_nolock(gss_new);
                spin_unlock(&gsec->gs_lock);
                gss_release_msg(gss_new);
                RETURN(0);
        }

        /* need to make upcall now */
        spin_unlock(&gsec->gs_lock);
        res = rpc_queue_upcall(dentry->d_inode, &gss_new->gum_base);
        if (res) {
                CERROR("rpc_queue_upcall failed: %d\n", res);
                gss_unhash_msg(gss_new);
                gss_release_msg(gss_new);
                cred->pc_flags |= PTLRPC_CRED_DEAD | PTLRPC_CRED_ERROR;
                RETURN(res);
        }
        gss_msg = gss_new;
        spin_lock(&gsec->gs_lock);

waiting:
        /* gs_lock is held here on every path */
        /* upcall might finish quickly */
        if (list_empty(&gss_msg->gum_list)) {
                spin_unlock(&gsec->gs_lock);
                res = 0;
                goto out;
        }

        /* join the wait queue before dropping the lock: the wake-up is
         * issued from gss_unhash_msg_nolock(), which runs under the same
         * lock */
        init_waitqueue_entry(&wait, current);
        set_current_state(TASK_INTERRUPTIBLE);
        add_wait_queue(&gss_msg->gum_waitq, &wait);
        spin_unlock(&gsec->gs_lock);

        /* only the thread that issued the upcall waits with a timeout;
         * a thread sharing somebody else's upcall sleeps until woken */
        if (gss_new)
                res = schedule_timeout(CRED_REFRESH_UPCALL_TIMEOUT * HZ);
        else {
                schedule();
                res = 0;
        }

        remove_wait_queue(&gss_msg->gum_waitq, &wait);

        /* - the one who refreshes the cred for us is also responsible for
         *   setting the status of the cred, so we can simply return.
         * - if the cred flags have been set, we don't need to do that
         *   again either, no matter whether a signal is pending, the
         *   wait timed out, etc.
         */
        if (!gss_new || cred->pc_flags & PTLRPC_CRED_FLAGS_MASK)
                goto out;

        if (signal_pending(current)) {
                CERROR("%s: cred %p: interrupted upcall\n",
                       current->comm, cred);
                cred->pc_flags |= PTLRPC_CRED_DEAD | PTLRPC_CRED_ERROR;
                res = -EINTR;
        } else if (res == 0) {
                /* schedule_timeout() returned 0 */
                CERROR("cred %p: upcall timedout\n", cred);
                set_bit(PTLRPC_CRED_DEAD_BIT, &cred->pc_flags);
                res = -ETIMEDOUT;
        } else
                res = 0;

out:
        /* drop the reference taken by gss_find_upcall() or the "current
         * user" reference set up by gss_init_upcall_msg() */
        gss_release_msg(gss_msg);

        RETURN(res);
}
861 #else /* !__KERNEL__ */
862 extern int lgss_handle_krb5_upcall(uid_t uid, __u32 dest_ip,
863                                    char *obd_name, char *buf, int bufsize,
864                                    int (*callback)(char*, unsigned long));
865
866 static int gss_cred_refresh(struct ptlrpc_cred *cred)
867 {
868         char                    buf[4096];
869         rawobj_t                obj;
870         struct obd_import      *imp;
871         struct gss_sec         *gsec;
872         struct gss_api_mech    *mech;
873         struct gss_cl_ctx      *ctx = NULL;
874         ptl_nid_t               peer_nid;
875         __u32                   dest_ip;
876         __u32                   subflavor;
877         int                     rc, gss_err;
878         struct gss_upcall_msg_data gmd = { 0 };
879
880         LASSERT(cred);
881         LASSERT(cred->pc_sec);
882         LASSERT(cred->pc_sec->ps_import);
883         LASSERT(cred->pc_sec->ps_import->imp_obd);
884
885         if (ptlrpcs_cred_is_uptodate(cred))
886                 RETURN(0);
887
888         imp = cred->pc_sec->ps_import;
889         peer_nid = imp->imp_connection->c_peer.peer_id.nid;
890         dest_ip = (__u32) (peer_nid & 0xFFFFFFFF);
891         subflavor = cred->pc_sec->ps_flavor;
892
893         if (subflavor != PTLRPCS_SUBFLVR_KRB5I) {
894                 CERROR("unknown subflavor %u\n", subflavor);
895                 GOTO(err_out, rc = -EINVAL);
896         }
897
898         rc = lgss_handle_krb5_upcall(cred->pc_uid, dest_ip,
899                                      imp->imp_obd->obd_name,
900                                      buf, sizeof(buf),
901                                      gss_send_secinit_rpc);
902         LASSERT(rc != 0);
903         if (rc < 0)
904                 goto err_out;
905
906         obj.data = buf;
907         obj.len = rc;
908
909         gsec = container_of(cred->pc_sec, struct gss_sec, gs_base);
910         mech = gsec->gs_mech;
911         LASSERT(mech);
912
913         rc = gss_parse_init_downcall(mech, &obj, &ctx, &gmd,
914                                      &gss_err);
915         if (rc || gss_err) {
916                 CERROR("parse init downcall: rpc %d, gss 0x%x\n", rc, gss_err);
917                 if (rc != -ERESTART || gss_err != 0)
918                         set_bit(PTLRPC_CRED_ERROR_BIT, &cred->pc_flags);
919                 if (rc == 0)
920                         rc = -EPERM;
921                 goto err_out;
922         }
923
924         LASSERT(ctx);
925         gss_cred_set_ctx(cred, ctx);
926         LASSERT(gss_cred_is_uptodate_ctx(cred));
927
928         return 0;
929 err_out:
930         set_bit(PTLRPC_CRED_DEAD_BIT, &cred->pc_flags);
931         return rc;
932 }
933 #endif
934
935 static int gss_cred_match(struct ptlrpc_cred *cred,
936                           struct vfs_cred *vcred)
937 {
938         RETURN(cred->pc_pag == vcred->vc_pag);
939 }
940
941 static int gss_cred_sign(struct ptlrpc_cred *cred,
942                          struct ptlrpc_request *req)
943 {
944         struct gss_cred         *gcred;
945         struct gss_cl_ctx       *ctx;
946         rawobj_t                 lmsg, mic;
947         __u32                   *vp, *vpsave, vlen, seclen;
948         __u32                    seqnum, major, rc = 0;
949         ENTRY;
950
951         LASSERT(req->rq_reqbuf);
952         LASSERT(req->rq_cred == cred);
953
954         gcred = container_of(cred, struct gss_cred, gc_base);
955         ctx = gss_cred_get_ctx(cred);
956         if (!ctx) {
957                 CERROR("cred %p("LPU64"/%u) invalidated?\n",
958                         cred, cred->pc_pag, cred->pc_uid);
959                 RETURN(-EPERM);
960         }
961
962         lmsg.len = req->rq_reqlen;
963         lmsg.data = (__u8 *) req->rq_reqmsg;
964
965         vp = (__u32 *) (lmsg.data + lmsg.len);
966         vlen = req->rq_reqbuf_len - sizeof(struct ptlrpcs_wire_hdr) -
967                lmsg.len;
968         seclen = vlen;
969
970         if (vlen < 6 * 4 + size_round4(ctx->gc_wire_ctx.len)) {
971                 CERROR("vlen %d, need %d\n",
972                         vlen, 6 * 4 + size_round4(ctx->gc_wire_ctx.len));
973                 rc = -EIO;
974                 goto out;
975         }
976
977         spin_lock(&ctx->gc_seq_lock);
978         seqnum = ctx->gc_seq++;
979         spin_unlock(&ctx->gc_seq_lock);
980
981         *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);    /* version */
982         *vp++ = cpu_to_le32(PTLRPCS_FLVR_KRB5I);        /* subflavor */
983         *vp++ = cpu_to_le32(ctx->gc_proc);              /* proc */
984         *vp++ = cpu_to_le32(seqnum);                    /* seq */
985         *vp++ = cpu_to_le32(PTLRPCS_GSS_SVC_INTEGRITY); /* service */
986         vlen -= 5 * 4;
987
988         if (rawobj_serialize(&ctx->gc_wire_ctx, &vp, &vlen)) {
989                 rc = -EIO;
990                 goto out;
991         }
992         CDEBUG(D_SEC, "encoded wire_ctx length %d\n", ctx->gc_wire_ctx.len);
993
994         vpsave = vp++;  /* reserve for size */
995         vlen -= 4;
996
997         mic.len = vlen;
998         mic.data = (unsigned char *)vp;
999
1000         CDEBUG(D_SEC, "reqbuf at %p, lmsg at %p, len %d, mic at %p, len %d\n",
1001                req->rq_reqbuf, lmsg.data, lmsg.len, mic.data, mic.len);
1002         major = kgss_get_mic(ctx->gc_gss_ctx, GSS_C_QOP_DEFAULT, &lmsg, &mic);
1003         if (major) {
1004                 CERROR("cred %p: req %p compute mic error, major %x\n",
1005                        cred, req, major);
1006                 rc = -EACCES;
1007                 goto out;
1008         }
1009
1010         *vpsave = cpu_to_le32(mic.len);
1011         
1012         seclen = seclen - vlen + mic.len;
1013         buf_to_sec_hdr(req->rq_reqbuf)->sec_len = cpu_to_le32(seclen);
1014         req->rq_reqdata_len += size_round(seclen);
1015         CDEBUG(D_SEC, "msg size %d, checksum size %d, total sec size %d\n",
1016                lmsg.len, mic.len, seclen);
1017 out:
1018         gss_put_ctx(ctx);
1019         RETURN(rc);
1020 }
1021
1022 static int gss_cred_verify(struct ptlrpc_cred *cred,
1023                            struct ptlrpc_request *req)
1024 {
1025         struct gss_cred        *gcred;
1026         struct gss_cl_ctx      *ctx;
1027         struct ptlrpcs_wire_hdr *sec_hdr;
1028         rawobj_t                lmsg, mic;
1029         __u32                   *vp, vlen, subflavor, proc, seq, svc;
1030         __u32                   major, minor, rc;
1031         ENTRY;
1032
1033         LASSERT(req->rq_repbuf);
1034         LASSERT(req->rq_cred == cred);
1035
1036         sec_hdr = buf_to_sec_hdr(req->rq_repbuf);
1037         vp = (__u32 *) (req->rq_repbuf + sizeof(*sec_hdr) + sec_hdr->msg_len);
1038         vlen = sec_hdr->sec_len;
1039
1040         if (vlen < 7 * 4) {
1041                 CERROR("reply sec size %u too small\n", vlen);
1042                 RETURN(-EPROTO);
1043         }
1044
1045         if (*vp++ != cpu_to_le32(PTLRPC_SEC_GSS_VERSION)) {
1046                 CERROR("reply have different gss version\n");
1047                 RETURN(-EPROTO);
1048         }
1049         subflavor = le32_to_cpu(*vp++);
1050         proc = le32_to_cpu(*vp++);
1051         vlen -= 3 * 4;
1052
1053         switch (proc) {
1054         case PTLRPCS_GSS_PROC_DATA:
1055                 seq = le32_to_cpu(*vp++);
1056                 svc = le32_to_cpu(*vp++);
1057                 if (svc != PTLRPCS_GSS_SVC_INTEGRITY) {
1058                         CERROR("Unknown svc %d\n", svc);
1059                         RETURN(-EPROTO);
1060                 }
1061                 if (*vp++ != 0) {
1062                         CERROR("Unexpected ctx handle\n");
1063                         RETURN(-EPROTO);
1064                 }
1065                 mic.len = le32_to_cpu(*vp++);
1066                 vlen -= 4 * 4;
1067                 if (vlen < mic.len) {
1068                         CERROR("vlen %d, mic.len %d\n", vlen, mic.len);
1069                         RETURN(-EINVAL);
1070                 }
1071                 mic.data = (unsigned char *)vp;
1072
1073                 gcred = container_of(cred, struct gss_cred, gc_base);
1074                 ctx = gss_cred_get_ctx(cred);
1075                 LASSERT(ctx);
1076
1077                 lmsg.len = sec_hdr->msg_len;
1078                 lmsg.data = (__u8 *) buf_to_lustre_msg(req->rq_repbuf);
1079
1080                 major = kgss_verify_mic(ctx->gc_gss_ctx, &lmsg, &mic, NULL);
1081                 if (major != GSS_S_COMPLETE) {
1082                         CERROR("cred %p: req %p verify mic error: major %x\n",
1083                                cred, req, major);
1084
1085                         if (major == GSS_S_CREDENTIALS_EXPIRED ||
1086                             major == GSS_S_CONTEXT_EXPIRED) {
1087                                 ptlrpcs_cred_expire(cred);
1088                                 req->rq_ptlrpcs_restart = 1;
1089                                 rc = 0;
1090                         } else
1091                                 rc = -EINVAL;
1092
1093                         GOTO(proc_data_out, rc);
1094                 }
1095
1096                 req->rq_repmsg = (struct lustre_msg *) lmsg.data;
1097                 req->rq_replen = lmsg.len;
1098
1099                 /* here we could check the seq number is the same one
1100                  * we sent to server. but portals has prevent us from
1101                  * replay attack, so maybe we don't need check it again.
1102                  */
1103                 rc = 0;
1104 proc_data_out:
1105                 gss_put_ctx(ctx);
1106                 break;
1107         case PTLRPCS_GSS_PROC_ERR:
1108                 major = le32_to_cpu(*vp++);
1109                 minor = le32_to_cpu(*vp++);
1110                 /* server return NO_CONTEXT might be caused by context expire
1111                  * or server reboot/failover. we refresh the cred transparently
1112                  * to upper layer.
1113                  * In some cases, our gss handle is possible to be incidentally
1114                  * identical to another handle since the handle itself is not
1115                  * fully random. In krb5 case, the GSS_S_BAD_SIG will be
1116                  * returned, maybe other gss error for other mechanism. Here we
1117                  * only consider krb5 mech (FIXME) and try to establish new
1118                  * context.
1119                  */
1120                 if (major == GSS_S_NO_CONTEXT ||
1121                     major == GSS_S_BAD_SIG) {
1122                         CWARN("req %p: server report cred %p %s\n",
1123                                req, cred, (major == GSS_S_NO_CONTEXT) ?
1124                                            "NO_CONTEXT" : "BAD_SIG");
1125
1126                         ptlrpcs_cred_expire(cred);
1127                         req->rq_ptlrpcs_restart = 1;
1128                         rc = 0;
1129                 } else {
1130                         CERROR("req %p: unrecognized gss error (%x/%x)\n",
1131                                 req, major, minor);
1132                         rc = -EACCES;
1133                 }
1134                 break;
1135         default:
1136                 CERROR("unknown gss proc %d\n", proc);
1137                 rc = -EPROTO;
1138         }
1139
1140         RETURN(rc);
1141 }
1142
1143 static int gss_cred_seal(struct ptlrpc_cred *cred,
1144                          struct ptlrpc_request *req)
1145 {
1146         struct gss_cred         *gcred;
1147         struct gss_cl_ctx       *ctx;
1148         struct ptlrpcs_wire_hdr *sec_hdr;
1149         rawobj_buf_t             msg_buf;
1150         rawobj_t                 cipher_buf;
1151         __u32                   *vp, *vpsave, vlen, seclen;
1152         __u32                    major, seqnum, rc = 0;
1153         ENTRY;
1154
1155         LASSERT(req->rq_reqbuf);
1156         LASSERT(req->rq_cred == cred);
1157
1158         gcred = container_of(cred, struct gss_cred, gc_base);
1159         ctx = gss_cred_get_ctx(cred);
1160         if (!ctx) {
1161                 CERROR("cred %p("LPU64"/%u) invalidated?\n",
1162                         cred, cred->pc_pag, cred->pc_uid);
1163                 RETURN(-EPERM);
1164         }
1165
1166         vp = (__u32 *) (req->rq_reqbuf + sizeof(*sec_hdr));
1167         vlen = req->rq_reqbuf_len - sizeof(*sec_hdr);
1168         seclen = vlen;
1169
1170         if (vlen < 6 * 4 + size_round4(ctx->gc_wire_ctx.len)) {
1171                 CERROR("vlen %d, need %d\n",
1172                         vlen, 6 * 4 + size_round4(ctx->gc_wire_ctx.len));
1173                 rc = -EIO;
1174                 goto out;
1175         }
1176
1177         spin_lock(&ctx->gc_seq_lock);
1178         seqnum = ctx->gc_seq++;
1179         spin_unlock(&ctx->gc_seq_lock);
1180
1181         *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);    /* version */
1182         *vp++ = cpu_to_le32(PTLRPCS_FLVR_KRB5P);        /* subflavor */
1183         *vp++ = cpu_to_le32(ctx->gc_proc);              /* proc */
1184         *vp++ = cpu_to_le32(seqnum);                    /* seq */
1185         *vp++ = cpu_to_le32(PTLRPCS_GSS_SVC_PRIVACY);   /* service */
1186         vlen -= 5 * 4;
1187
1188         if (rawobj_serialize(&ctx->gc_wire_ctx, &vp, &vlen)) {
1189                 rc = -EIO;
1190                 goto out;
1191         }
1192         CDEBUG(D_SEC, "encoded wire_ctx length %d\n", ctx->gc_wire_ctx.len);
1193
1194         vpsave = vp++;  /* reserve for size */
1195         vlen -= 4;
1196
1197         msg_buf.buf = (__u8 *) req->rq_reqmsg - GSS_PRIVBUF_PREFIX_LEN;
1198         msg_buf.buflen = req->rq_reqlen + GSS_PRIVBUF_PREFIX_LEN +
1199                                           GSS_PRIVBUF_SUFFIX_LEN;
1200         msg_buf.dataoff = GSS_PRIVBUF_PREFIX_LEN;
1201         msg_buf.datalen = req->rq_reqlen;
1202
1203         cipher_buf.data = (__u8 *) vp;
1204         cipher_buf.len = vlen;
1205
1206         major = kgss_wrap(ctx->gc_gss_ctx, GSS_C_QOP_DEFAULT,
1207                           &msg_buf, &cipher_buf);
1208         if (major) {
1209                 CERROR("cred %p: error wrap: major 0x%x\n", cred, major);
1210                 GOTO(out, rc = -EINVAL);
1211         }
1212
1213         *vpsave = cpu_to_le32(cipher_buf.len);
1214
1215         seclen = seclen - vlen + cipher_buf.len;
1216         sec_hdr = buf_to_sec_hdr(req->rq_reqbuf);
1217         sec_hdr->sec_len = cpu_to_le32(seclen);
1218         req->rq_reqdata_len += size_round(seclen);
1219
1220         CDEBUG(D_SEC, "msg size %d, total sec size %d\n",
1221                req->rq_reqlen, seclen);
1222 out:
1223         gss_put_ctx(ctx);
1224         RETURN(rc);
1225 }
1226
1227 static int gss_cred_unseal(struct ptlrpc_cred *cred,
1228                            struct ptlrpc_request *req)
1229 {
1230         struct gss_cred        *gcred;
1231         struct gss_cl_ctx      *ctx;
1232         struct ptlrpcs_wire_hdr *sec_hdr;
1233         rawobj_t                cipher_text, plain_text;
1234         __u32                   *vp, vlen, subflavor, proc, seq, svc;
1235         __u32                   major, rc;
1236         ENTRY;
1237
1238         LASSERT(req->rq_repbuf);
1239         LASSERT(req->rq_cred == cred);
1240
1241         sec_hdr = buf_to_sec_hdr(req->rq_repbuf);
1242         if (sec_hdr->msg_len != 0) {
1243                 CERROR("unexpected msg_len %u\n", sec_hdr->msg_len);
1244                 RETURN(-EPROTO);
1245         }
1246
1247         vp = (__u32 *) (req->rq_repbuf + sizeof(*sec_hdr));
1248         vlen = sec_hdr->sec_len;
1249
1250         if (vlen < 7 * 4) {
1251                 CERROR("reply sec size %u too small\n", vlen);
1252                 RETURN(-EPROTO);
1253         }
1254
1255         if (*vp++ != cpu_to_le32(PTLRPC_SEC_GSS_VERSION)) {
1256                 CERROR("reply have different gss version\n");
1257                 RETURN(-EPROTO);
1258         }
1259         subflavor = le32_to_cpu(*vp++);
1260         proc = le32_to_cpu(*vp++);
1261         seq = le32_to_cpu(*vp++);
1262         svc = le32_to_cpu(*vp++);
1263         vlen -= 5 * 4;
1264
1265         switch (proc) {
1266         case PTLRPCS_GSS_PROC_DATA:
1267                 if (svc != PTLRPCS_GSS_SVC_PRIVACY) {
1268                         CERROR("Unknown svc %d\n", svc);
1269                         RETURN(-EPROTO);
1270                 }
1271                 if (*vp++ != 0) {
1272                         CERROR("Unexpected ctx handle\n");
1273                         RETURN(-EPROTO);
1274                 }
1275                 vlen -= 4;
1276
1277                 cipher_text.len = le32_to_cpu(*vp++);
1278                 cipher_text.data = (__u8 *) vp;
1279                 vlen -= 4;
1280
1281                 if (vlen < cipher_text.len) {
1282                         CERROR("cipher text to be %u while buf only %u\n",
1283                                 cipher_text.len, vlen);
1284                         RETURN(-EPROTO);
1285                 }
1286
1287                 plain_text = cipher_text;
1288
1289                 gcred = container_of(cred, struct gss_cred, gc_base);
1290                 ctx = gss_cred_get_ctx(cred);
1291                 LASSERT(ctx);
1292
1293                 major = kgss_unwrap(ctx->gc_gss_ctx, GSS_C_QOP_DEFAULT,
1294                                     &cipher_text, &plain_text);
1295                 if (major) {
1296                         CERROR("cred %p: error unwrap: major 0x%x\n",
1297                                cred, major);
1298
1299                         if (major == GSS_S_CREDENTIALS_EXPIRED ||
1300                             major == GSS_S_CONTEXT_EXPIRED) {
1301                                 ptlrpcs_cred_expire(cred);
1302                                 req->rq_ptlrpcs_restart = 1;
1303                                 rc = 0;
1304                         } else
1305                                 rc = -EINVAL;
1306
1307                         GOTO(proc_out, rc);
1308                 }
1309
1310                 req->rq_repmsg = (struct lustre_msg *) vp;
1311                 req->rq_replen = plain_text.len;
1312
1313                 rc = 0;
1314 proc_out:
1315                 gss_put_ctx(ctx);
1316                 break;
1317         default:
1318                 CERROR("unknown gss proc %d\n", proc);
1319                 rc = -EPROTO;
1320         }
1321
1322         RETURN(rc);
1323 }
1324
1325 static void destroy_gss_context(struct ptlrpc_cred *cred)
1326 {
1327         struct ptlrpcs_wire_hdr *hdr;
1328         struct lustre_msg       *lmsg;
1329         struct gss_cred         *gcred;
1330         struct ptlrpc_request    req;
1331         struct obd_import       *imp;
1332         __u32                   *vp, lmsg_size;
1333         struct ptlrpc_request   *raw_req = NULL;
1334         const int                repbuf_len = 256;
1335         char                    *repbuf;
1336         int                      replen, rc;
1337         ENTRY;
1338
1339         imp = cred->pc_sec->ps_import;
1340         LASSERT(imp);
1341
1342         if (test_bit(PTLRPC_CRED_ERROR_BIT, &cred->pc_flags) ||
1343             !test_bit(PTLRPC_CRED_UPTODATE_BIT, &cred->pc_flags)) {
1344                 CDEBUG(D_SEC, "Destroy dead cred %p(%u@%s)\n",
1345                        cred, cred->pc_uid, imp->imp_target_uuid.uuid);
1346                 EXIT;
1347                 return;
1348         }
1349
1350         might_sleep();
1351
1352         /* cred's refcount is 0, steal one */
1353         atomic_inc(&cred->pc_refcount);
1354
1355         gcred = container_of(cred, struct gss_cred, gc_base);
1356         gcred->gc_ctx->gc_proc = PTLRPCS_GSS_PROC_DESTROY;
1357
1358         CDEBUG(D_SEC, "client destroy gss cred %p(%u@%s)\n",
1359                gcred, cred->pc_uid, imp->imp_target_uuid.uuid);
1360
1361         lmsg_size = lustre_msg_size(0, NULL);
1362         req.rq_req_secflvr = cred->pc_sec->ps_flavor;
1363         req.rq_cred = cred;
1364         req.rq_reqbuf_len = sizeof(*hdr) + lmsg_size +
1365                             ptlrpcs_est_req_payload(&req, lmsg_size);
1366
1367         OBD_ALLOC(req.rq_reqbuf, req.rq_reqbuf_len);
1368         if (!req.rq_reqbuf) {
1369                 CERROR("Fail to alloc reqbuf, cancel anyway\n");
1370                 atomic_dec(&cred->pc_refcount);
1371                 EXIT;
1372                 return;
1373         }
1374
1375         /* wire hdr */
1376         hdr = buf_to_sec_hdr(req.rq_reqbuf);
1377         hdr->flavor  = cpu_to_le32(PTLRPCS_FLVR_GSS_AUTH);
1378         hdr->msg_len = cpu_to_le32(lmsg_size);
1379         hdr->sec_len = cpu_to_le32(0);
1380
1381         /* lustre message */
1382         lmsg = buf_to_lustre_msg(req.rq_reqbuf);
1383         lustre_init_msg(lmsg, 0, NULL, NULL);
1384         lmsg->handle   = imp->imp_remote_handle;
1385         lmsg->type     = PTL_RPC_MSG_REQUEST;
1386         lmsg->opc      = SEC_FINI;
1387         lmsg->flags    = 0;
1388         lmsg->conn_cnt = imp->imp_conn_cnt;
1389         /* add this for randomize */
1390         get_random_bytes(&lmsg->last_xid, sizeof(lmsg->last_xid));
1391         get_random_bytes(&lmsg->transno, sizeof(lmsg->transno));
1392
1393         vp = (__u32 *) req.rq_reqbuf;
1394
1395         req.rq_cred = cred;
1396         req.rq_reqmsg = buf_to_lustre_msg(req.rq_reqbuf);
1397         req.rq_reqlen = lmsg_size;
1398         req.rq_reqdata_len = sizeof(*hdr) + lmsg_size;
1399
1400         if (gss_cred_sign(cred, &req)) {
1401                 CERROR("failed to sign, cancel anyway\n");
1402                 atomic_dec(&cred->pc_refcount);
1403                 goto exit;
1404         }
1405         atomic_dec(&cred->pc_refcount);
1406
1407         OBD_ALLOC(repbuf, repbuf_len);
1408         if (!repbuf)
1409                 goto exit;
1410
1411         raw_req = ptl_do_rawrpc(imp, req.rq_reqbuf, req.rq_reqbuf_len,
1412                                 req.rq_reqdata_len, repbuf, repbuf_len, &replen,
1413                                 SECFINI_RPC_TIMEOUT, &rc);
1414         if (!raw_req)
1415                 OBD_FREE(repbuf, repbuf_len);
1416
1417 exit:
1418         if (raw_req == NULL)
1419                 OBD_FREE(req.rq_reqbuf, req.rq_reqbuf_len);
1420         else
1421                 rawrpc_req_finished(raw_req);
1422         EXIT;
1423 }
1424
1425 static void gss_cred_destroy(struct ptlrpc_cred *cred)
1426 {
1427         struct gss_cred *gcred;
1428         ENTRY;
1429
1430         LASSERT(cred);
1431         LASSERT(!atomic_read(&cred->pc_refcount));
1432
1433         gcred = container_of(cred, struct gss_cred, gc_base);
1434         if (gcred->gc_ctx) {
1435                 destroy_gss_context(cred);
1436                 gss_put_ctx(gcred->gc_ctx);
1437         }
1438
1439         CDEBUG(D_SEC, "sec.gss %p: destroy cred %p\n", cred->pc_sec, gcred);
1440
1441         OBD_FREE(gcred, sizeof(*gcred));
1442         EXIT;
1443 }
1444
1445 static struct ptlrpc_credops gss_credops = {
1446         .refresh        = gss_cred_refresh,
1447         .match          = gss_cred_match,
1448         .sign           = gss_cred_sign,
1449         .verify         = gss_cred_verify,
1450         .seal           = gss_cred_seal,
1451         .unseal         = gss_cred_unseal,
1452         .destroy        = gss_cred_destroy,
1453 };
1454
1455 #ifdef __KERNEL__
1456 /*******************************************
1457  * rpc_pipe APIs                           *
1458  *******************************************/
1459 static ssize_t
1460 gss_pipe_upcall(struct file *filp, struct rpc_pipe_msg *msg,
1461                 char *dst, size_t buflen)
1462 {
1463         char *data = (char *)msg->data + msg->copied;
1464         ssize_t mlen = msg->len;
1465         ssize_t left;
1466         ENTRY;
1467
1468         if (mlen > buflen)
1469                 mlen = buflen;
1470         left = copy_to_user(dst, data, mlen);
1471         if (left < 0) {
1472                 msg->errno = left;
1473                 RETURN(left);
1474         }
1475         mlen -= left;
1476         msg->copied += mlen;
1477         msg->errno = 0;
1478         RETURN(mlen);
1479 }
1480
1481 static ssize_t
1482 gss_pipe_downcall(struct file *filp, const char *src, size_t mlen)
1483 {
1484         char *buf;
1485         const int bufsize = 1024;
1486         rawobj_t obj;
1487         struct inode *inode = filp->f_dentry->d_inode;
1488         struct rpc_inode *rpci = RPC_I(inode);
1489         struct obd_import *import;
1490         struct ptlrpc_sec *sec;
1491         struct gss_sec *gsec;
1492         char *obdname;
1493         struct gss_api_mech *mech;
1494         struct vfs_cred vcred = { 0 };
1495         struct ptlrpc_cred *cred;
1496         struct gss_upcall_msg *gss_msg;
1497         struct gss_upcall_msg_data gmd = { 0 };
1498         struct gss_cl_ctx *ctx = NULL;
1499         ssize_t left;
1500         int err, gss_err;
1501         ENTRY;
1502
1503         if (mlen > bufsize) {
1504                 CERROR("mlen %ld > bufsize %d\n", (long)mlen, bufsize);
1505                 RETURN(-ENOSPC);
1506         }
1507
1508         OBD_ALLOC(buf, bufsize);
1509         if (!buf) {
1510                 CERROR("alloc mem failed\n");
1511                 RETURN(-ENOMEM);
1512         }
1513
1514         left = copy_from_user(buf, src, mlen);
1515         if (left)
1516                 GOTO(err_free, err = -EFAULT);
1517
1518         obj.data = (unsigned char *)buf;
1519         obj.len = mlen;
1520
1521         LASSERT(rpci->private);
1522         gsec = (struct gss_sec *)rpci->private;
1523         sec = &gsec->gs_base;
1524         LASSERT(sec->ps_import);
1525         import = class_import_get(sec->ps_import);
1526         LASSERT(import->imp_obd);
1527         obdname = import->imp_obd->obd_name;
1528         mech = gsec->gs_mech;
1529
1530         err = gss_parse_init_downcall(mech, &obj, &ctx, &gmd, &gss_err);
1531         if (err)
1532                 CERROR("parse init downcall err %d\n", err);
1533
1534         vcred.vc_pag = gmd.gum_pag;
1535         vcred.vc_uid = gmd.gum_uid;
1536
1537         cred = ptlrpcs_cred_lookup(sec, &vcred);
1538         if (!cred) {
1539                 CWARN("didn't find cred for uid %u\n", vcred.vc_uid);
1540                 GOTO(err, err = -EINVAL);
1541         }
1542
1543         if (err || gss_err) {
1544                 set_bit(PTLRPC_CRED_DEAD_BIT, &cred->pc_flags);
1545                 if (err != -ERESTART || gss_err != 0)
1546                         set_bit(PTLRPC_CRED_ERROR_BIT, &cred->pc_flags);
1547                 CERROR("cred %p: rpc err %d, gss err 0x%x, fatal %d\n",
1548                        cred, err, gss_err,
1549                        (test_bit(PTLRPC_CRED_ERROR_BIT, &cred->pc_flags) != 0));
1550         } else {
1551                 CDEBUG(D_SEC, "get initial ctx:\n");
1552                 gss_cred_set_ctx(cred, ctx);
1553         }
1554
1555         spin_lock(&gsec->gs_lock);
1556         gss_msg = gss_find_upcall(gsec, obdname, &gmd);
1557         if (gss_msg) {
1558                 gss_unhash_msg_nolock(gss_msg);
1559                 spin_unlock(&gsec->gs_lock);
1560                 gss_release_msg(gss_msg);
1561         } else
1562                 spin_unlock(&gsec->gs_lock);
1563
1564         ptlrpcs_cred_put(cred, 1);
1565         class_import_put(import);
1566         OBD_FREE(buf, bufsize);
1567         RETURN(mlen);
1568 err:
1569         if (ctx)
1570                 gss_destroy_ctx(ctx);
1571         class_import_put(import);
1572 err_free:
1573         OBD_FREE(buf, bufsize);
1574         CDEBUG(D_SEC, "gss_pipe_downcall returning %d\n", err);
1575         RETURN(err);
1576 }
1577
1578 static
1579 void gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
1580 {
1581         struct gss_upcall_msg *gmsg;
1582         static unsigned long ratelimit;
1583         ENTRY;
1584
1585         LASSERT(list_empty(&msg->list));
1586
1587         if (msg->errno >= 0) {
1588                 EXIT;
1589                 return;
1590         }
1591
1592         gmsg = container_of(msg, struct gss_upcall_msg, gum_base);
1593         CDEBUG(D_SEC, "destroy gmsg %p\n", gmsg);
1594         LASSERT(atomic_read(&gmsg->gum_refcount) > 0);
1595         atomic_inc(&gmsg->gum_refcount);
1596         gss_unhash_msg(gmsg);
1597         if (msg->errno == -ETIMEDOUT || msg->errno == -EPIPE) {
1598                 unsigned long now = get_seconds();
1599                 if (time_after(now, ratelimit)) {
1600                         CWARN("sec.gss upcall timed out.\n"
1601                               "Please check user daemon is running!\n");
1602                         ratelimit = now + 15;
1603                 }
1604         }
1605         gss_release_msg(gmsg);
1606         EXIT;
1607 }
1608
1609 static
1610 void gss_pipe_release(struct inode *inode)
1611 {
1612         struct rpc_inode *rpci = RPC_I(inode);
1613         struct ptlrpc_sec *sec;
1614         struct gss_sec *gsec;
1615         ENTRY;
1616
1617         gsec = (struct gss_sec *)rpci->private;
1618         sec = &gsec->gs_base;
1619         spin_lock(&gsec->gs_lock);
1620         while (!list_empty(&gsec->gs_upcalls)) {
1621                 struct gss_upcall_msg *gmsg;
1622
1623                 gmsg = list_entry(gsec->gs_upcalls.next,
1624                                   struct gss_upcall_msg, gum_list);
1625                 LASSERT(list_empty(&gmsg->gum_base.list));
1626                 gmsg->gum_base.errno = -EPIPE;
1627                 atomic_inc(&gmsg->gum_refcount);
1628                 gss_unhash_msg_nolock(gmsg);
1629                 gss_release_msg(gmsg);
1630         }
1631         spin_unlock(&gsec->gs_lock);
1632         EXIT;
1633 }
1634
1635 static struct rpc_pipe_ops gss_upcall_ops = {
1636         .upcall         = gss_pipe_upcall,
1637         .downcall       = gss_pipe_downcall,
1638         .destroy_msg    = gss_pipe_destroy_msg,
1639         .release_pipe   = gss_pipe_release,
1640 };
1641 #endif /* __KERNEL__ */
1642
1643 /*********************************************
1644  * GSS security APIs                         *
1645  *********************************************/
1646
1647 static
1648 struct ptlrpc_sec* gss_create_sec(__u32 flavor,
1649                                   const char *pipe_dir,
1650                                   void *pipe_data)
1651 {
1652         struct gss_sec *gsec;
1653         struct ptlrpc_sec *sec;
1654         uid_t save_uid;
1655
1656 #ifdef __KERNEL__
1657         char *pos;
1658         int   pipepath_len;
1659 #endif
1660         ENTRY;
1661
1662         LASSERT(SEC_FLAVOR_MAJOR(flavor) == PTLRPCS_FLVR_MAJOR_GSS);
1663
1664         OBD_ALLOC(gsec, sizeof(*gsec));
1665         if (!gsec) {
1666                 CERROR("can't alloc gsec\n");
1667                 RETURN(NULL);
1668         }
1669
1670         gsec->gs_mech = kgss_subflavor_to_mech(SEC_FLAVOR_SUB(flavor));
1671         if (!gsec->gs_mech) {
1672                 CERROR("subflavor 0x%x not found\n", flavor);
1673                 goto err_free;
1674         }
1675
1676         /* initialize gss sec */
1677 #ifdef __KERNEL__
1678         INIT_LIST_HEAD(&gsec->gs_upcalls);
1679         spin_lock_init(&gsec->gs_lock);
1680
1681         pipepath_len = strlen(LUSTRE_PIPEDIR) + strlen(pipe_dir) +
1682                        strlen(gsec->gs_mech->gm_name) + 3;
1683         OBD_ALLOC(gsec->gs_pipepath, pipepath_len);
1684         if (!gsec->gs_pipepath)
1685                 goto err_mech_put;
1686
1687         /* pipe rpc require root permission */
1688         save_uid = current->fsuid;
1689         current->fsuid = 0;
1690
1691         sprintf(gsec->gs_pipepath, LUSTRE_PIPEDIR"/%s", pipe_dir);
1692         if (IS_ERR(rpc_mkdir(gsec->gs_pipepath, NULL))) {
1693                 CERROR("can't make pipedir %s\n", gsec->gs_pipepath);
1694                 goto err_free_path;
1695         }
1696
1697         sprintf(gsec->gs_pipepath, LUSTRE_PIPEDIR"/%s/%s", pipe_dir,
1698                 gsec->gs_mech->gm_name); 
1699         gsec->gs_depipe = rpc_mkpipe(gsec->gs_pipepath, gsec,
1700                                      &gss_upcall_ops, RPC_PIPE_WAIT_FOR_OPEN);
1701         if (IS_ERR(gsec->gs_depipe)) {
1702                 CERROR("failed to make rpc_pipe %s: %ld\n",
1703                         gsec->gs_pipepath, PTR_ERR(gsec->gs_depipe));
1704                 goto err_rmdir;
1705         }
1706         CDEBUG(D_SEC, "gss sec %p, pipe path %s\n", gsec, gsec->gs_pipepath);
1707 #endif
1708
1709         sec = &gsec->gs_base;
1710         sec->ps_expire = GSS_CREDCACHE_EXPIRE;
1711         sec->ps_nextgc = get_seconds() + sec->ps_expire;
1712         sec->ps_flags = 0;
1713
1714         current->fsuid = save_uid;
1715
1716         CDEBUG(D_SEC, "Create sec.gss %p\n", gsec);
1717         RETURN(sec);
1718
1719 #ifdef __KERNEL__
1720 err_rmdir:
1721         pos = strrchr(gsec->gs_pipepath, '/');
1722         LASSERT(pos);
1723         *pos = 0;
1724         rpc_rmdir(gsec->gs_pipepath);
1725 err_free_path:
1726         current->fsuid = save_uid;
1727         OBD_FREE(gsec->gs_pipepath, pipepath_len);
1728 err_mech_put:
1729 #endif
1730         kgss_mech_put(gsec->gs_mech);
1731 err_free:
1732         OBD_FREE(gsec, sizeof(*gsec));
1733         RETURN(NULL);
1734 }
1735
1736 static
1737 void gss_destroy_sec(struct ptlrpc_sec *sec)
1738 {
1739         struct gss_sec *gsec;
1740 #ifdef __KERNEL__
1741         char *pos;
1742         int   pipepath_len;
1743 #endif
1744         ENTRY;
1745
1746         gsec = container_of(sec, struct gss_sec, gs_base);
1747         CDEBUG(D_SEC, "Destroy sec.gss %p\n", gsec);
1748
1749         LASSERT(gsec->gs_mech);
1750         LASSERT(!atomic_read(&sec->ps_refcount));
1751         LASSERT(!atomic_read(&sec->ps_credcount));
1752 #ifdef __KERNEL__
1753         pipepath_len = strlen(gsec->gs_pipepath) + 1;
1754         rpc_unlink(gsec->gs_pipepath);
1755         pos = strrchr(gsec->gs_pipepath, '/');
1756         LASSERT(pos);
1757         *pos = 0;
1758         rpc_rmdir(gsec->gs_pipepath);
1759         OBD_FREE(gsec->gs_pipepath, pipepath_len);
1760 #endif
1761
1762         kgss_mech_put(gsec->gs_mech);
1763         OBD_FREE(gsec, sizeof(*gsec));
1764         EXIT;
1765 }
1766
1767 static
1768 struct ptlrpc_cred * gss_create_cred(struct ptlrpc_sec *sec,
1769                                      struct vfs_cred *vcred)
1770 {
1771         struct gss_cred *gcred;
1772         struct ptlrpc_cred *cred;
1773         ENTRY;
1774
1775         OBD_ALLOC(gcred, sizeof(*gcred));
1776         if (!gcred)
1777                 RETURN(NULL);
1778
1779         cred = &gcred->gc_base;
1780         INIT_LIST_HEAD(&cred->pc_hash);
1781         atomic_set(&cred->pc_refcount, 0);
1782         cred->pc_sec = sec;
1783         cred->pc_ops = &gss_credops;
1784         cred->pc_expire = 0;
1785         cred->pc_flags = 0;
1786         cred->pc_pag = vcred->vc_pag;
1787         cred->pc_uid = vcred->vc_uid;
1788         CDEBUG(D_SEC, "create a gss cred at %p("LPU64"/%u)\n",
1789                cred, vcred->vc_pag, vcred->vc_uid);
1790
1791         RETURN(cred);
1792 }
1793
1794 static int gss_estimate_payload(struct ptlrpc_sec *sec,
1795                                 struct ptlrpc_request *req,
1796                                 int msgsize)
1797 {
1798         switch (SEC_FLAVOR_SVC(req->rq_req_secflvr)) {
1799         case PTLRPCS_SVC_AUTH:
1800                 return GSS_MAX_AUTH_PAYLOAD;
1801         case PTLRPCS_SVC_PRIV:
1802                 return size_round16(GSS_MAX_AUTH_PAYLOAD + msgsize +
1803                                     GSS_PRIVBUF_PREFIX_LEN +
1804                                     GSS_PRIVBUF_SUFFIX_LEN);
1805         default:
1806                 LBUG();
1807                 return 0;
1808         }
1809 }
1810
1811 static int gss_alloc_reqbuf(struct ptlrpc_sec *sec,
1812                             struct ptlrpc_request *req,
1813                             int lmsg_size)
1814 {
1815         int msg_payload, sec_payload;
1816         int privacy, rc;
1817         ENTRY;
1818
1819         /* In PRIVACY mode, lustre message is always 0 (already encoded into
1820          * security payload).
1821          */
1822         privacy = (SEC_FLAVOR_SVC(req->rq_req_secflvr) == PTLRPCS_SVC_PRIV);
1823         msg_payload = privacy ? 0 : lmsg_size;
1824         sec_payload = gss_estimate_payload(sec, req, lmsg_size);
1825
1826         rc = sec_alloc_reqbuf(sec, req, msg_payload, sec_payload);
1827         if (rc)
1828                 return rc;
1829
1830         if (privacy) {
1831                 int buflen = lmsg_size + GSS_PRIVBUF_PREFIX_LEN +
1832                              GSS_PRIVBUF_SUFFIX_LEN;
1833                 char *buf;
1834
1835                 OBD_ALLOC(buf, buflen);
1836                 if (!buf) {
1837                         CERROR("Fail to alloc %d\n", buflen);
1838                         sec_free_reqbuf(sec, req);
1839                         RETURN(-ENOMEM);
1840                 }
1841                 req->rq_reqmsg = (struct lustre_msg *)
1842                                         (buf + GSS_PRIVBUF_PREFIX_LEN);
1843         }
1844
1845         RETURN(0);
1846 }
1847
1848 static void gss_free_reqbuf(struct ptlrpc_sec *sec,
1849                             struct ptlrpc_request *req)
1850 {
1851         char *buf;
1852         int privacy;
1853         ENTRY;
1854
1855         LASSERT(req->rq_reqmsg);
1856         LASSERT(req->rq_reqlen);
1857
1858         privacy = SEC_FLAVOR_SVC(req->rq_req_secflvr) == PTLRPCS_SVC_PRIV;
1859         if (privacy) {
1860                 buf = (char *) req->rq_reqmsg - GSS_PRIVBUF_PREFIX_LEN;
1861                 LASSERT(buf < req->rq_reqbuf ||
1862                         buf >= req->rq_reqbuf + req->rq_reqbuf_len);
1863                 OBD_FREE(buf, req->rq_reqlen + GSS_PRIVBUF_PREFIX_LEN +
1864                               GSS_PRIVBUF_SUFFIX_LEN);
1865                 req->rq_reqmsg = NULL;
1866         }
1867
1868         sec_free_reqbuf(sec, req);
1869 }
1870
1871 static struct ptlrpc_secops gss_secops = {
1872         .create_sec             = gss_create_sec,
1873         .destroy_sec            = gss_destroy_sec,
1874         .create_cred            = gss_create_cred,
1875         .est_req_payload        = gss_estimate_payload,
1876         .est_rep_payload        = gss_estimate_payload,
1877         .alloc_reqbuf           = gss_alloc_reqbuf,
1878         .free_reqbuf            = gss_free_reqbuf,
1879 };
1880
1881 static struct ptlrpc_sec_type gss_type = {
1882         .pst_owner      = THIS_MODULE,
1883         .pst_name       = "sec.gss",
1884         .pst_inst       = ATOMIC_INIT(0),
1885         .pst_flavor     = PTLRPCS_FLVR_MAJOR_GSS,
1886         .pst_ops        = &gss_secops,
1887 };
1888
/*
 * Secinit downcall hook, defined outside of this file.  This module
 * points it at gss_send_secinit_rpc on init and resets it to NULL on exit.
 */
extern int (*lustre_secinit_downcall_handler)(char *buffer,
                                              unsigned long count);
1891
1892 int __init ptlrpcs_gss_init(void)
1893 {
1894         int rc;
1895
1896         rc = ptlrpcs_register(&gss_type);
1897         if (rc)
1898                 return rc;
1899
1900 #ifdef __KERNEL__
1901         gss_svc_init();
1902
1903         rc = PTR_ERR(rpc_mkdir(LUSTRE_PIPEDIR, NULL));
1904         if (IS_ERR((void *)rc) && rc != -EEXIST) {
1905                 CERROR("fail to make rpcpipedir for lustre\n");
1906                 gss_svc_exit();
1907                 ptlrpcs_unregister(&gss_type);
1908                 return -1;
1909         }
1910         rc = 0;
1911 #else
1912 #endif
1913         rc = init_kerberos_module();
1914         if (rc) {
1915                 ptlrpcs_unregister(&gss_type);
1916         }
1917
1918         lustre_secinit_downcall_handler = gss_send_secinit_rpc;
1919
1920         return rc;
1921 }
1922
#ifdef __KERNEL__
/*
 * Module cleanup (kernel builds only): undo ptlrpcs_gss_init().
 */
static void __exit ptlrpcs_gss_exit(void)
{
        /* unhook the secinit downcall handler before anything else */
        lustre_secinit_downcall_handler = NULL;

        /* kerberos mechanism module */
        cleanup_kerberos_module();

        /* rpc pipe directory */
        rpc_rmdir(LUSTRE_PIPEDIR);

        /* gss svc side */
        gss_svc_exit();

        /* finally drop the sec type registration */
        ptlrpcs_unregister(&gss_type);
}
#endif
1934
1935 MODULE_AUTHOR("Cluster File Systems, Inc. <info@clusterfs.com>");
1936 MODULE_DESCRIPTION("GSS Security module for Lustre");
1937 MODULE_LICENSE("GPL");
1938
1939 module_init(ptlrpcs_gss_init);
1940 module_exit(ptlrpcs_gss_exit);