Whamcloud - gitweb
c901197240d5154b147e1425ada80542c352f18e
[fs/lustre-release.git] / lustre / sec / gss / sec_gss.c
1 /* -*- mode: c; c-basic-offset: 8; indent-tabs-mode: nil; -*-
2  * vim:expandtab:shiftwidth=8:tabstop=8:
3  *
4  * Modifications for Lustre
5  * Copyright 2004, Cluster File Systems, Inc.
6  * All rights reserved
7  * Author: Eric Mei <ericm@clusterfs.com>
8  */
9
10 /*
11  * linux/net/sunrpc/auth_gss.c
12  *
13  * RPCSEC_GSS client authentication.
14  *
15  *  Copyright (c) 2000 The Regents of the University of Michigan.
16  *  All rights reserved.
17  *
18  *  Dug Song       <dugsong@monkey.org>
19  *  Andy Adamson   <andros@umich.edu>
20  *
21  *  Redistribution and use in source and binary forms, with or without
22  *  modification, are permitted provided that the following conditions
23  *  are met:
24  *
25  *  1. Redistributions of source code must retain the above copyright
26  *     notice, this list of conditions and the following disclaimer.
27  *  2. Redistributions in binary form must reproduce the above copyright
28  *     notice, this list of conditions and the following disclaimer in the
29  *     documentation and/or other materials provided with the distribution.
30  *  3. Neither the name of the University nor the names of its
31  *     contributors may be used to endorse or promote products derived
32  *     from this software without specific prior written permission.
33  *
34  *  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
35  *  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
36  *  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37  *  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
38  *  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
39  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
40  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
41  *  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
42  *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
43  *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
44  *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
45  *
46  */
47
48 #ifndef EXPORT_SYMTAB
49 # define EXPORT_SYMTAB
50 #endif
51 #define DEBUG_SUBSYSTEM S_SEC
52 #ifdef __KERNEL__
53 #include <linux/init.h>
54 #include <linux/module.h>
55 #include <linux/slab.h>
56 #include <linux/dcache.h>
57 #include <linux/fs.h>
58 #include <linux/random.h>
59 /* for rpc_pipefs */
60 struct rpc_clnt;
61 #include <linux/sunrpc/rpc_pipe_fs.h>
62 #else
63 #include <liblustre.h>
64 #endif
65
66 #include <libcfs/kp30.h>
67 #include <linux/obd.h>
68 #include <linux/obd_class.h>
69 #include <linux/obd_support.h>
70 #include <linux/lustre_idl.h>
71 #include <linux/lustre_net.h>
72 #include <linux/lustre_import.h>
73 #include <linux/lustre_sec.h>
74
75 #include "gss_err.h"
76 #include "gss_internal.h"
77 #include "gss_api.h"
78
/* Directory name for the gss upcall pipes (presumably under rpc_pipefs;
 * its users are outside this view). */
#define LUSTRE_PIPEDIR          "/lustre"

/* Credential cache lifetime: 30 minutes, expressed in seconds. */
#define GSS_CREDCACHE_EXPIRE    (30 * 60)

/* Step (seconds) used to stagger the nested gss timeouts below so the inner
 * operation always gives up before the outer one. */
#define GSS_TIMEOUT_DELTA       (5)

/* Seconds gss_cred_refresh() waits for lgssd to answer an upcall:
 * obd_timeout minus one step, clamped to at least two steps.
 * Re-evaluated (reads obd_timeout) at every use. */
#define CRED_REFRESH_UPCALL_TIMEOUT                                     \
        ({                                                              \
                int cr_upcall_tmo = obd_timeout - GSS_TIMEOUT_DELTA;    \
                (cr_upcall_tmo < GSS_TIMEOUT_DELTA * 2) ?               \
                        GSS_TIMEOUT_DELTA * 2 : cr_upcall_tmo;          \
        })

/* Timeout handed to ptl_do_rawrpc() for SEC_INIT: one step shorter than
 * the upcall wait, clamped to at least one step. */
#define SECINIT_RPC_TIMEOUT                                             \
        ({                                                              \
                int secinit_tmo = CRED_REFRESH_UPCALL_TIMEOUT -         \
                                  GSS_TIMEOUT_DELTA;                    \
                (secinit_tmo < GSS_TIMEOUT_DELTA) ?                     \
                        GSS_TIMEOUT_DELTA : secinit_tmo;                \
        })

/* Timeout for the SEC_FINI rpc: a single step. */
#define SECFINI_RPC_TIMEOUT     (GSS_TIMEOUT_DELTA)
101
102
103 /**********************************************
104  * gss security init/fini helper              *
105  **********************************************/
106
107 static int secinit_compose_request(struct obd_import *imp,
108                                    char *buf, int bufsize,
109                                    int lustre_srv,
110                                    uid_t uid, gid_t gid,
111                                    long token_size,
112                                    char __user *token)
113 {
114         struct ptlrpcs_wire_hdr *hdr;
115         struct lustre_msg       *lmsg;
116         struct mds_req_sec_desc *secdesc;
117         int                      size = sizeof(*secdesc);
118         __u32                    lmsg_size, *p;
119         int                      rc;
120
121         lmsg_size = lustre_msg_size(1, &size);
122
123         if (sizeof(*hdr) + lmsg_size + size_round(token_size) > bufsize) {
124                 CERROR("token size %ld too large\n", token_size);
125                 return -EINVAL;
126         }
127
128         /* security wire hdr */
129         hdr = buf_to_sec_hdr(buf);
130         hdr->flavor  = cpu_to_le32(PTLRPCS_FLVR_GSS_NONE);
131         hdr->msg_len = cpu_to_le32(lmsg_size);
132         hdr->sec_len = cpu_to_le32(8 * 4 + token_size);
133
134         /* lustre message & secdesc */
135         lmsg = buf_to_lustre_msg(buf);
136
137         lustre_init_msg(lmsg, 1, &size, NULL);
138         secdesc = lustre_msg_buf(lmsg, 0, size);
139         secdesc->rsd_uid = secdesc->rsd_fsuid = uid;
140         secdesc->rsd_gid = secdesc->rsd_fsgid = gid;
141         secdesc->rsd_cap = secdesc->rsd_ngroups = 0;
142
143         lmsg->handle   = imp->imp_remote_handle;
144         lmsg->type     = PTL_RPC_MSG_REQUEST;
145         lmsg->opc      = SEC_INIT;
146         lmsg->flags    = 0;
147         lmsg->conn_cnt = imp->imp_conn_cnt;
148
149         p = (__u32 *) (buf + sizeof(*hdr) + lmsg_size);
150
151         /* gss hdr */
152         *p++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);     /* gss version */
153         *p++ = cpu_to_le32(PTLRPCS_FLVR_KRB5I);         /* subflavor */
154         *p++ = cpu_to_le32(PTLRPCS_GSS_PROC_INIT);      /* proc */
155         *p++ = cpu_to_le32(0);                          /* seq */
156         *p++ = cpu_to_le32(PTLRPCS_GSS_SVC_NONE);       /* service */
157         *p++ = cpu_to_le32(0);                          /* context handle */
158
159         /* plus lustre svc type */
160         *p++ = cpu_to_le32(lustre_srv);
161
162         /* now the token part */
163         *p++ = cpu_to_le32((__u32) token_size);
164         LASSERT(((char *)p - buf) + token_size <= bufsize);
165
166         rc = copy_from_user(p, token, token_size);
167         if (rc) {
168                 CERROR("can't copy token\n");
169                 return -EFAULT;
170         }
171
172         rc = size_round(((char *)p - buf) + token_size);
173         return rc;
174 }
175
176 static int secinit_parse_reply(char *repbuf, int replen,
177                                char __user *outbuf, long outlen)
178 {
179         __u32                   *p = (__u32 *)repbuf;
180         struct ptlrpcs_wire_hdr *hdr = (struct ptlrpcs_wire_hdr *) repbuf;
181         __u32                    lmsg_len, sec_len, status;
182         __u32                    major, minor, seq, obj_len, round_len;
183         __u32                    effective = 0;
184
185         if (replen <= (4 + 6) * 4) {
186                 CERROR("reply size %d too small\n", replen);
187                 return -EINVAL;
188         }
189
190         hdr->flavor = le32_to_cpu(hdr->flavor);
191         hdr->msg_len = le32_to_cpu(hdr->msg_len);
192         hdr->sec_len = le32_to_cpu(hdr->sec_len);
193
194         lmsg_len = le32_to_cpu(p[2]);
195         sec_len = le32_to_cpu(p[3]);
196
197         /* sanity checks */
198         if (hdr->flavor != PTLRPCS_FLVR_GSS_NONE) {
199                 CERROR("unexpected reply\n");
200                 return -EINVAL;
201         }
202         if (hdr->msg_len % 8 ||
203             sizeof(*hdr) + hdr->msg_len + hdr->sec_len > replen) {
204                 CERROR("unexpected reply\n");
205                 return -EINVAL;
206         }
207         if (hdr->sec_len > outlen) {
208                 CERROR("outbuf too small\n");
209                 return -EINVAL;
210         }
211
212         p = (__u32 *) buf_to_sec_data(repbuf);
213         effective = 0;
214
215         p += 2; /* skip the leading unused bytes */
216         seq = le32_to_cpu(*p++);
217         major = le32_to_cpu(*p++);
218         minor = le32_to_cpu(*p++);
219         status = 0;
220
221         effective += 4 * 4;
222
223         if (copy_to_user(outbuf, &status, 4))
224                 return -EFAULT;
225         outbuf += 4;
226         if (copy_to_user(outbuf, &major, 4))
227                 return -EFAULT;
228         outbuf += 4;
229         if (copy_to_user(outbuf, &minor, 4))
230                 return -EFAULT;
231         outbuf += 4;
232         if (copy_to_user(outbuf, &seq, 4))
233                 return -EFAULT;
234         outbuf += 4;
235
236         obj_len = le32_to_cpu(*p++);
237         round_len = (obj_len + 3) & ~ 3;
238         if (copy_to_user(outbuf, &obj_len, 4))
239                 return -EFAULT;
240         outbuf += 4;
241         if (copy_to_user(outbuf, (char *)p, round_len))
242                 return -EFAULT;
243         p += round_len / 4;
244         outbuf += round_len;
245         effective += 4 + round_len;
246
247         obj_len = le32_to_cpu(*p++);
248         round_len = (obj_len + 3) & ~ 3;
249         if (copy_to_user(outbuf, &obj_len, 4))
250                 return -EFAULT;
251         outbuf += 4;
252         if (copy_to_user(outbuf, (char *)p, round_len))
253                 return -EFAULT;
254         p += round_len / 4;
255         outbuf += round_len;
256         effective += 4 + round_len;
257
258         return effective;
259 }
260
/*
 * Argument block exchanged with lgssd for a SEC_INIT request; see
 * gss_send_secinit_rpc().  Its size must match the count passed in, and it
 * is copied back with the "out" fields filled in.  lgssd shares this layout,
 * so fields must not be reordered or resized.
 * XXX move to where lgssd could see
 */
struct lgssd_ioctl_param {
        int      version;          /* in:  must equal GSSD_INTERFACE_VERSION */
        char    *uuid;             /* in:  name of the target obd */
        int      lustre_svc;       /* in:  lustre service type for the gss hdr */
        uid_t    uid;              /* in:  uid placed in the request secdesc */
        gid_t    gid;              /* in:  gid placed in the request secdesc */
        long     send_token_size;  /* in:  number of bytes at send_token */
        char    *send_token;       /* in:  gss token to send (user memory) */
        long     reply_buf_size;   /* in:  capacity of reply_buf */
        char    *reply_buf;        /* in:  user buffer for the parsed reply */
        long     status;           /* out: 0 or negative errno */
        long     reply_length;     /* out: bytes written to reply_buf */
};
275
276 static int gss_send_secinit_rpc(__user char *buffer, unsigned long count)
277 {
278         struct obd_import        *imp;
279         struct ptlrpc_request    *request = NULL;
280         struct lgssd_ioctl_param  param;
281         const int                 reqbuf_size = 1024;
282         const int                 repbuf_size = 1024;
283         char                     *reqbuf, *repbuf;
284         struct obd_device        *obd;
285         char                      obdname[64];
286         long                      lsize;
287         int                       rc, reqlen, replen;
288
289         if (count != sizeof(param)) {
290                 CERROR("ioctl size %lu, expect %d, please check lgssd version\n",
291                         count, sizeof(param));
292                 RETURN(-EINVAL);
293         }
294         if (copy_from_user(&param, buffer, sizeof(param))) {
295                 CERROR("failed copy data from lgssd\n");
296                 RETURN(-EFAULT);
297         }
298
299         if (param.version != GSSD_INTERFACE_VERSION) {
300                 CERROR("gssd interface version %d (expect %d)\n",
301                         param.version, GSSD_INTERFACE_VERSION);
302                 RETURN(-EINVAL);
303         }
304
305         /* take name */
306         if (strncpy_from_user(obdname, param.uuid,
307                               sizeof(obdname)) <= 0) {
308                 CERROR("Invalid obdname pointer\n");
309                 RETURN(-EFAULT);
310         }
311
312         obd = class_name2obd(obdname);
313         if (!obd) {
314                 CERROR("no such obd %s\n", obdname);
315                 RETURN(-EINVAL);
316         }
317
318         imp = class_import_get(obd->u.cli.cl_import);
319
320         OBD_ALLOC(reqbuf, reqbuf_size);
321         OBD_ALLOC(repbuf, reqbuf_size);
322
323         if (!reqbuf || !repbuf) {
324                 CERROR("Can't alloc buffer: %p/%p\n", reqbuf, repbuf);
325                 param.status = -ENOMEM;
326                 goto out_copy;
327         }
328
329         /* get token */
330         reqlen = secinit_compose_request(imp, reqbuf, reqbuf_size,
331                                          param.lustre_svc,
332                                          param.uid, param.gid,
333                                          param.send_token_size,
334                                          param.send_token);
335         if (reqlen < 0) {
336                 param.status = reqlen;
337                 goto out_copy;
338         }
339
340         request = ptl_do_rawrpc(imp, reqbuf, reqbuf_size, reqlen,
341                                 repbuf, repbuf_size, &replen,
342                                 SECINIT_RPC_TIMEOUT, &rc);
343         if (request == NULL || rc) {
344                 param.status = rc;
345                 goto out_copy;
346         }
347
348         if (replen > param.reply_buf_size) {
349                 CERROR("output buffer size %ld too small, need %d\n",
350                         param.reply_buf_size, replen);
351                 param.status = -EINVAL;
352                 goto out_copy;
353         }
354
355         lsize = secinit_parse_reply(repbuf, replen,
356                                     param.reply_buf, param.reply_buf_size);
357         if (lsize < 0) {
358                 param.status = (int) lsize;
359                 goto out_copy;
360         }
361
362         param.status = 0;
363         param.reply_length = lsize;
364
365 out_copy:
366         if (copy_to_user(buffer, &param, sizeof(param)))
367                 rc = -EFAULT;
368         else
369                 rc = 0;
370
371         class_import_put(imp);
372         if (request == NULL) {
373                 if (repbuf)
374                         OBD_FREE(repbuf, repbuf_size);
375                 if (reqbuf)
376                         OBD_FREE(reqbuf, reqbuf_size);
377         } else {
378                 rawrpc_req_finished(request);
379         }
380         RETURN(rc);
381 }
382
383 /**********************************************
384  * structure definitions                      *
385  **********************************************/
386 struct gss_sec {
387         struct ptlrpc_sec       gs_base;
388         struct gss_api_mech    *gs_mech;
389         spinlock_t              gs_lock;
390         struct list_head        gs_upcalls;
391         char                   *gs_pipepath;
392         struct dentry          *gs_depipe;
393 };
394
395 struct gss_upcall_msg_data {
396         __u64                           gum_pag;
397         __u32                           gum_uid;
398         __u32                           gum_svc;
399         __u32                           gum_nal;
400         __u32                           gum_netid;
401         __u64                           gum_nid;
402 };
403
404 struct gss_upcall_msg {
405         struct rpc_pipe_msg             gum_base;
406         atomic_t                        gum_refcount;
407         struct list_head                gum_list;
408         struct gss_sec                 *gum_gsec;
409         wait_queue_head_t               gum_waitq;
410         char                            gum_obdname[64];
411         struct gss_upcall_msg_data      gum_data;
412 };
413
414 #ifdef __KERNEL__
/* Serializes access to gss_cred::gc_ctx.  gss_cred_get_ctx() and
 * gss_cred_is_uptodate_ctx() read it under the read lock.
 * gss_cred_set_ctx() swaps it and sets PTLRPC_CRED_UPTODATE_BIT under the
 * write lock.
 * NOTE(review): this definition is inside #ifdef __KERNEL__, yet
 * gss_cred_get_ctx() and the !__KERNEL__ gss_cred_is_uptodate_ctx() also
 * use it -- confirm the userspace build gets a definition elsewhere. */
415 static rwlock_t gss_ctx_lock = RW_LOCK_UNLOCKED;
416 /**********************************************
417  * rpc_pipe upcall helpers                    *
418  **********************************************/
419 static
420 void gss_release_msg(struct gss_upcall_msg *gmsg)
421 {
422         ENTRY;
423         LASSERT(atomic_read(&gmsg->gum_refcount) > 0);
424
425         if (!atomic_dec_and_test(&gmsg->gum_refcount)) {
426                 CDEBUG(D_SEC, "gmsg %p ref %d\n", gmsg,
427                        atomic_read(&gmsg->gum_refcount));
428                 EXIT;
429                 return;
430         }
431         LASSERT(list_empty(&gmsg->gum_list));
432 #if 0
433         LASSERT(list_empty(&gmsg->gum_base.list));
434 #else
435         /* XXX */
436         if (!list_empty(&gmsg->gum_base.list)) {
437                 int error = gmsg->gum_base.errno;
438                 
439                 CWARN("msg %p: list: %p/%p/%p, copied %d, err %d, wq %d\n",
440                       gmsg, &gmsg->gum_base.list, gmsg->gum_base.list.prev,
441                       gmsg->gum_base.list.next, gmsg->gum_base.copied, error,
442                       list_empty(&gmsg->gum_waitq.task_list));
443                 LBUG();
444         }
445 #endif
446         OBD_FREE(gmsg, sizeof(*gmsg));
447         EXIT;
448 }
449
450 static void
451 gss_unhash_msg_nolock(struct gss_upcall_msg *gmsg)
452 {
453         LASSERT_SPIN_LOCKED(&gmsg->gum_gsec->gs_lock);
454
455         if (list_empty(&gmsg->gum_list))
456                 return;
457
458         list_del_init(&gmsg->gum_list);
459         wake_up(&gmsg->gum_waitq);
460         LASSERT(atomic_read(&gmsg->gum_refcount) > 1);
461         atomic_dec(&gmsg->gum_refcount);
462 }
463
464 static void
465 gss_unhash_msg(struct gss_upcall_msg *gmsg)
466 {
467         struct gss_sec *gsec = gmsg->gum_gsec;
468
469         spin_lock(&gsec->gs_lock);
470         gss_unhash_msg_nolock(gmsg);
471         spin_unlock(&gsec->gs_lock);
472 }
473
474 static
475 struct gss_upcall_msg * gss_find_upcall(struct gss_sec *gsec,
476                                         char *obdname,
477                                         struct gss_upcall_msg_data *gmd)
478 {
479         struct gss_upcall_msg *gmsg;
480         ENTRY;
481
482         LASSERT_SPIN_LOCKED(&gsec->gs_lock);
483
484         list_for_each_entry(gmsg, &gsec->gs_upcalls, gum_list) {
485                 if (memcmp(&gmsg->gum_data, gmd, sizeof(*gmd)))
486                         continue;
487                 if (strcmp(gmsg->gum_obdname, obdname))
488                         continue;
489                 LASSERT(atomic_read(&gmsg->gum_refcount) > 0);
490                 atomic_inc(&gmsg->gum_refcount);
491                 CDEBUG(D_SEC, "found gmsg at %p: obdname %s, uid %d, ref %d\n",
492                        gmsg, obdname, gmd->gum_uid,
493                        atomic_read(&gmsg->gum_refcount));
494                 RETURN(gmsg);
495         }
496         RETURN(NULL);
497 }
498
499 static void gss_init_upcall_msg(struct gss_upcall_msg *gmsg,
500                                 struct gss_sec *gsec, char *obdname,
501                                 struct gss_upcall_msg_data *gmd)
502 {
503         struct rpc_pipe_msg *rpcmsg;
504         ENTRY;
505
506         /* 2 refs: 1 for hash, 1 for current user */
507         init_waitqueue_head(&gmsg->gum_waitq);
508         list_add(&gmsg->gum_list, &gsec->gs_upcalls);
509         atomic_set(&gmsg->gum_refcount, 2);
510         gmsg->gum_gsec = gsec;
511         strncpy(gmsg->gum_obdname, obdname, sizeof(gmsg->gum_obdname));
512         memcpy(&gmsg->gum_data, gmd, sizeof(*gmd));
513
514         rpcmsg = &gmsg->gum_base;
515         INIT_LIST_HEAD(&rpcmsg->list);
516         rpcmsg->data = &gmsg->gum_data;
517         rpcmsg->len = sizeof(gmsg->gum_data);
518         rpcmsg->copied = 0;
519         rpcmsg->errno = 0;
520         EXIT;
521 }
522 #endif /* __KERNEL__ */
523
524 /* this seems to be used only from userspace code */
525 #ifndef __KERNEL__
526 /********************************************
527  * gss cred manipulation helpers            *
528  ********************************************/
529 static
530 int gss_cred_is_uptodate_ctx(struct ptlrpc_cred *cred)
531 {
532         struct gss_cred *gcred = container_of(cred, struct gss_cred, gc_base);
533         int res = 0;
534
535         read_lock(&gss_ctx_lock);
536         if (((cred->pc_flags & PTLRPC_CRED_FLAGS_MASK) ==
537              PTLRPC_CRED_UPTODATE) &&
538             gcred->gc_ctx)
539                 res = 1;
540         read_unlock(&gss_ctx_lock);
541         return res;
542 }
543 #endif
544
545 static inline
546 struct gss_cl_ctx *gss_get_ctx(struct gss_cl_ctx *ctx)
547 {
548         atomic_inc(&ctx->gc_refcount);
549         return ctx;
550 }
551
552 static
553 void gss_destroy_ctx(struct gss_cl_ctx *ctx)
554 {
555         ENTRY;
556
557         CDEBUG(D_SEC, "destroy cl_ctx %p\n", ctx);
558         if (ctx->gc_gss_ctx)
559                 kgss_delete_sec_context(&ctx->gc_gss_ctx);
560
561         if (ctx->gc_wire_ctx.len > 0) {
562                 OBD_FREE(ctx->gc_wire_ctx.data, ctx->gc_wire_ctx.len);
563                 ctx->gc_wire_ctx.len = 0;
564         }
565
566         OBD_FREE(ctx, sizeof(*ctx));
567 }
568
569 static
570 void gss_put_ctx(struct gss_cl_ctx *ctx)
571 {
572         if (atomic_dec_and_test(&ctx->gc_refcount))
573                 gss_destroy_ctx(ctx);
574 }
575
576 static
577 struct gss_cl_ctx *gss_cred_get_ctx(struct ptlrpc_cred *cred)
578 {
579         struct gss_cred *gcred = container_of(cred, struct gss_cred, gc_base);
580         struct gss_cl_ctx *ctx = NULL;
581
582         read_lock(&gss_ctx_lock);
583         if (gcred->gc_ctx)
584                 ctx = gss_get_ctx(gcred->gc_ctx);
585         read_unlock(&gss_ctx_lock);
586         return ctx;
587 }
588
589 static
590 void gss_cred_set_ctx(struct ptlrpc_cred *cred, struct gss_cl_ctx *ctx)
591 {
592         struct gss_cred *gcred = container_of(cred, struct gss_cred, gc_base);
593         struct gss_cl_ctx *old;
594         __u64 ctx_expiry;
595         ENTRY;
596
597         if (kgss_inquire_context(ctx->gc_gss_ctx, &ctx_expiry)) {
598                 CERROR("unable to get expire time\n");
599                 ctx_expiry = 1; /* make it expired now */
600         }
601         cred->pc_expire = (unsigned long) ctx_expiry;
602
603         write_lock(&gss_ctx_lock);
604         old = gcred->gc_ctx;
605         gcred->gc_ctx = ctx;
606         set_bit(PTLRPC_CRED_UPTODATE_BIT, &cred->pc_flags);
607         write_unlock(&gss_ctx_lock);
608         if (old)
609                 gss_put_ctx(old);
610
611         CDEBUG(D_SEC, "client refreshed gss cred %p(uid %u)\n",
612                cred, cred->pc_uid);
613         EXIT;
614 }
615
616 static int
617 simple_get_bytes(char **buf, __u32 *buflen, void *res, __u32 reslen)
618 {
619         if (*buflen < reslen) {
620                 CERROR("buflen %u < %u\n", *buflen, reslen);
621                 return -EINVAL;
622         }
623
624         memcpy(res, *buf, reslen);
625         *buf += reslen;
626         *buflen -= reslen;
627         return 0;
628 }
629
630 /* data passed down:
631  *  - uid
632  *  - timeout
633  *  - gc_win / error
634  *  - wire_ctx (rawobj)
635  *  - mech_ctx? (rawobj)
636  */
637 static
638 int gss_parse_init_downcall(struct gss_api_mech *gm, rawobj_t *buf,
639                             struct gss_cl_ctx **gc,
640                             struct gss_upcall_msg_data *gmd, int *gss_err)
641 {
642         char *p = (char *)buf->data;
643         struct gss_cl_ctx *ctx;
644         __u32 len = buf->len;
645         unsigned int timeout;
646         rawobj_t tmp_buf;
647         int err = -EPERM;
648         ENTRY;
649
650         *gc = NULL;
651         *gss_err = 0;
652
653         OBD_ALLOC(ctx, sizeof(*ctx));
654         if (!ctx)
655                 RETURN(-ENOMEM);
656
657         ctx->gc_proc = RPC_GSS_PROC_DATA;
658         ctx->gc_seq = 0;
659         spin_lock_init(&ctx->gc_seq_lock);
660         atomic_set(&ctx->gc_refcount,1);
661
662         if (simple_get_bytes(&p, &len, &gmd->gum_pag, sizeof(gmd->gum_pag)))
663                 goto err_free_ctx;
664         if (simple_get_bytes(&p, &len, &gmd->gum_uid, sizeof(gmd->gum_uid)))
665                 goto err_free_ctx;
666         if (simple_get_bytes(&p, &len, &gmd->gum_svc, sizeof(gmd->gum_svc)))
667                 goto err_free_ctx;
668         if (simple_get_bytes(&p, &len, &gmd->gum_nal, sizeof(gmd->gum_nal)))
669                 goto err_free_ctx;
670         if (simple_get_bytes(&p, &len, &gmd->gum_netid, sizeof(gmd->gum_netid)))
671                 goto err_free_ctx;
672         if (simple_get_bytes(&p, &len, &gmd->gum_nid, sizeof(gmd->gum_nid)))
673                 goto err_free_ctx;
674         /* FIXME: discarded timeout for now */
675         if (simple_get_bytes(&p, &len, &timeout, sizeof(timeout)))
676                 goto err_free_ctx;
677         if (simple_get_bytes(&p, &len, &ctx->gc_win, sizeof(ctx->gc_win)))
678                 goto err_free_ctx;
679
680         /* lgssd signals an error by passing ctx->gc_win = 0: */
681         if (!ctx->gc_win) {
682                 /* in which case the next 2 int are:
683                  * - rpc error
684                  * - gss error
685                  */
686                 if (simple_get_bytes(&p, &len, &err, sizeof(err))) {
687                         err = -EPERM;
688                         goto err_free_ctx;
689                 }
690                 if (simple_get_bytes(&p, &len, gss_err, sizeof(*gss_err))) {
691                         err = -EPERM;
692                         goto err_free_ctx;
693                 }
694                 if (err == 0 && *gss_err == 0) {
695                         CERROR("no error passed from downcall\n");
696                         err = -EPERM;
697                 }
698                 goto err_free_ctx;
699         }
700
701         if (rawobj_extract_local(&tmp_buf, (__u32 **) ((void *)&p), &len))
702                 goto err_free_ctx;
703         if (rawobj_dup(&ctx->gc_wire_ctx, &tmp_buf)) {
704                 err = -ENOMEM;
705                 goto err_free_ctx;
706         }
707         if (rawobj_extract_local(&tmp_buf, (__u32 **) ((void *)&p), &len))
708                 goto err_free_wire_ctx;
709         if (len) {
710                 CERROR("unexpected trailing %u bytes\n", len);
711                 goto err_free_wire_ctx;
712         }
713         if (kgss_import_sec_context(&tmp_buf, gm, &ctx->gc_gss_ctx))
714                 goto err_free_wire_ctx;
715
716         *gc = ctx;
717         RETURN(0);
718
719 err_free_wire_ctx:
720         if (ctx->gc_wire_ctx.data)
721                 OBD_FREE(ctx->gc_wire_ctx.data, ctx->gc_wire_ctx.len);
722 err_free_ctx:
723         OBD_FREE(ctx, sizeof(*ctx));
724         CDEBUG(D_SEC, "err_code %d, gss code %d\n", err, *gss_err);
725         return err;
726 }
727
728 /***************************************
729  * cred APIs                           *
730  ***************************************/
731 #ifdef __KERNEL__
732 static int gss_cred_refresh(struct ptlrpc_cred *cred)
733 {
734         struct obd_import          *import;
735         struct gss_sec             *gsec;
736         struct gss_upcall_msg      *gss_msg, *gss_new;
737         struct gss_upcall_msg_data  gmd;
738         struct dentry              *dentry;
739         char                       *obdname, *obdtype;
740         wait_queue_t                wait;
741         int                         res;
742         ENTRY;
743
744         might_sleep();
745
746         /* any flags means it has been handled, do nothing */
747         if (cred->pc_flags & PTLRPC_CRED_FLAGS_MASK)
748                 RETURN(0);
749
750         LASSERT(cred->pc_sec);
751         LASSERT(cred->pc_sec->ps_import);
752         LASSERT(cred->pc_sec->ps_import->imp_obd);
753
754         import = cred->pc_sec->ps_import;
755         if (!import->imp_connection) {
756                 CERROR("import has no connection set\n");
757                 RETURN(-EINVAL);
758         }
759
760         gmd.gum_pag = cred->pc_pag;
761         gmd.gum_uid = cred->pc_uid;
762         gmd.gum_nal = import->imp_connection->c_peer.peer_ni->pni_number;
763         gmd.gum_netid = 0;
764         gmd.gum_nid = import->imp_connection->c_peer.peer_id.nid;
765
766         obdtype = import->imp_obd->obd_type->typ_name;
767         if (!strcmp(obdtype, OBD_MDC_DEVICENAME))
768                 gmd.gum_svc = LUSTRE_GSS_SVC_MDS;
769         else if (!strcmp(obdtype, OBD_OSC_DEVICENAME))
770                 gmd.gum_svc = LUSTRE_GSS_SVC_OSS;
771         else {
772                 CERROR("gss on %s?\n", obdtype);
773                 RETURN(-EINVAL);
774         }
775
776         gsec = container_of(cred->pc_sec, struct gss_sec, gs_base);
777         obdname = import->imp_obd->obd_name;
778         dentry = gsec->gs_depipe;
779         gss_new = NULL;
780         res = 0;
781
782         CDEBUG(D_SEC, "Initiate gss context %p(%u@%s)\n",
783                container_of(cred, struct gss_cred, gc_base),
784                cred->pc_uid, import->imp_target_uuid.uuid);
785
786 again:
787         spin_lock(&gsec->gs_lock);
788         gss_msg = gss_find_upcall(gsec, obdname, &gmd);
789         if (gss_msg) {
790                 if (gss_new) {
791                         OBD_FREE(gss_new, sizeof(*gss_new));
792                         gss_new = NULL;
793                 }
794                 GOTO(waiting, res);
795         }
796
797         if (!gss_new) {
798                 spin_unlock(&gsec->gs_lock);
799                 OBD_ALLOC(gss_new, sizeof(*gss_new));
800                 if (!gss_new)
801                         RETURN(-ENOMEM);
802                 goto again;
803         }
804         /* so far we'v created gss_new */
805         gss_init_upcall_msg(gss_new, gsec, obdname, &gmd);
806
807         /* we'v created upcall msg, nobody else should touch the
808          * flag of this cred, unless be set as dead/expire by
809          * administrator via lctl etc.
810          */
811         if (cred->pc_flags & PTLRPC_CRED_FLAGS_MASK) {
812                 CWARN("cred %p("LPU64"/%u) was set flags %lx unexpectedly\n",
813                       cred, cred->pc_pag, cred->pc_uid, cred->pc_flags);
814                 cred->pc_flags |= PTLRPC_CRED_DEAD | PTLRPC_CRED_ERROR;
815                 gss_unhash_msg_nolock(gss_new);
816                 spin_unlock(&gsec->gs_lock);
817                 gss_release_msg(gss_new);
818                 RETURN(0);
819         }
820
821         /* need to make upcall now */
822         spin_unlock(&gsec->gs_lock);
823         res = rpc_queue_upcall(dentry->d_inode, &gss_new->gum_base);
824         if (res) {
825                 CERROR("rpc_queue_upcall failed: %d\n", res);
826                 gss_unhash_msg(gss_new);
827                 gss_release_msg(gss_new);
828                 cred->pc_flags |= PTLRPC_CRED_DEAD | PTLRPC_CRED_ERROR;
829                 RETURN(res);
830         }
831         gss_msg = gss_new;
832         spin_lock(&gsec->gs_lock);
833
834 waiting:
835         /* upcall might finish quickly */
836         if (list_empty(&gss_msg->gum_list)) {
837                 spin_unlock(&gsec->gs_lock);
838                 res = 0;
839                 goto out;
840         }
841
842         init_waitqueue_entry(&wait, current);
843         set_current_state(TASK_INTERRUPTIBLE);
844         add_wait_queue(&gss_msg->gum_waitq, &wait);
845         spin_unlock(&gsec->gs_lock);
846
847         if (gss_new)
848                 res = schedule_timeout(CRED_REFRESH_UPCALL_TIMEOUT * HZ);
849         else {
850                 schedule();
851                 res = 0;
852         }
853
854         remove_wait_queue(&gss_msg->gum_waitq, &wait);
855
856         /* - the one who refresh the cred for us should also be responsible
857          *   to set the status of cred, we can simply return.
858          * - if cred flags has been set, we also don't need to do that again,
859          *   no matter signal pending or timeout etc.
860          */
861         if (!gss_new || cred->pc_flags & PTLRPC_CRED_FLAGS_MASK)
862                 goto out;
863
864         if (signal_pending(current)) {
865                 CERROR("%s: cred %p: interrupted upcall\n",
866                        current->comm, cred);
867                 cred->pc_flags |= PTLRPC_CRED_DEAD | PTLRPC_CRED_ERROR;
868                 res = -EINTR;
869         } else if (res == 0) {
870                 CERROR("cred %p: upcall timedout\n", cred);
871                 set_bit(PTLRPC_CRED_DEAD_BIT, &cred->pc_flags);
872                 res = -ETIMEDOUT;
873         } else
874                 res = 0;
875
876 out:
877         gss_release_msg(gss_msg);
878
879         RETURN(res);
880 }
881 #else /* !__KERNEL__ */
882 extern int lgss_handle_krb5_upcall(uid_t uid, __u32 dest_ip,
883                                    char *obd_name, char *buf, int bufsize,
884                                    int (*callback)(char*, unsigned long));
885
886 static int gss_cred_refresh(struct ptlrpc_cred *cred)
887 {
888         char                    buf[4096];
889         rawobj_t                obj;
890         struct obd_import      *imp;
891         struct gss_sec         *gsec;
892         struct gss_api_mech    *mech;
893         struct gss_cl_ctx      *ctx = NULL;
894         ptl_nid_t               peer_nid;
895         __u32                   dest_ip;
896         __u32                   subflavor;
897         int                     rc, gss_err;
898         struct gss_upcall_msg_data gmd = { 0 };
899
900         LASSERT(cred);
901         LASSERT(cred->pc_sec);
902         LASSERT(cred->pc_sec->ps_import);
903         LASSERT(cred->pc_sec->ps_import->imp_obd);
904
905         if (ptlrpcs_cred_is_uptodate(cred))
906                 RETURN(0);
907
908         imp = cred->pc_sec->ps_import;
909         peer_nid = imp->imp_connection->c_peer.peer_id.nid;
910         dest_ip = (__u32) (peer_nid & 0xFFFFFFFF);
911         subflavor = cred->pc_sec->ps_flavor;
912
913         if (subflavor != PTLRPCS_SUBFLVR_KRB5I) {
914                 CERROR("unknown subflavor %u\n", subflavor);
915                 GOTO(err_out, rc = -EINVAL);
916         }
917
918         rc = lgss_handle_krb5_upcall(cred->pc_uid, dest_ip,
919                                      imp->imp_obd->obd_name,
920                                      buf, sizeof(buf),
921                                      gss_send_secinit_rpc);
922         LASSERT(rc != 0);
923         if (rc < 0)
924                 goto err_out;
925
926         obj.data = buf;
927         obj.len = rc;
928
929         gsec = container_of(cred->pc_sec, struct gss_sec, gs_base);
930         mech = gsec->gs_mech;
931         LASSERT(mech);
932
933         rc = gss_parse_init_downcall(mech, &obj, &ctx, &gmd,
934                                      &gss_err);
935         if (rc || gss_err) {
936                 CERROR("parse init downcall: rpc %d, gss 0x%x\n", rc, gss_err);
937                 if (rc != -ERESTART || gss_err != 0)
938                         set_bit(PTLRPC_CRED_ERROR_BIT, &cred->pc_flags);
939                 if (rc == 0)
940                         rc = -EPERM;
941                 goto err_out;
942         }
943
944         LASSERT(ctx);
945         gss_cred_set_ctx(cred, ctx);
946         LASSERT(gss_cred_is_uptodate_ctx(cred));
947
948         return 0;
949 err_out:
950         set_bit(PTLRPC_CRED_DEAD_BIT, &cred->pc_flags);
951         return rc;
952 }
953 #endif
954
955 static int gss_cred_match(struct ptlrpc_cred *cred,
956                           struct vfs_cred *vcred)
957 {
958         RETURN(cred->pc_pag == vcred->vc_pag);
959 }
960
961 static int gss_cred_sign(struct ptlrpc_cred *cred,
962                          struct ptlrpc_request *req)
963 {
964         struct gss_cred         *gcred;
965         struct gss_cl_ctx       *ctx;
966         rawobj_t                 lmsg, mic;
967         __u32                   *vp, *vpsave, vlen, seclen;
968         __u32                    seqnum, major, rc = 0;
969         ENTRY;
970
971         LASSERT(req->rq_reqbuf);
972         LASSERT(req->rq_cred == cred);
973
974         gcred = container_of(cred, struct gss_cred, gc_base);
975         ctx = gss_cred_get_ctx(cred);
976         if (!ctx) {
977                 CERROR("cred %p("LPU64"/%u) invalidated?\n",
978                         cred, cred->pc_pag, cred->pc_uid);
979                 RETURN(-EPERM);
980         }
981
982         lmsg.len = req->rq_reqlen;
983         lmsg.data = (__u8 *) req->rq_reqmsg;
984
985         vp = (__u32 *) (lmsg.data + lmsg.len);
986         vlen = req->rq_reqbuf_len - sizeof(struct ptlrpcs_wire_hdr) -
987                lmsg.len;
988         seclen = vlen;
989
990         if (vlen < 6 * 4 + size_round4(ctx->gc_wire_ctx.len)) {
991                 CERROR("vlen %d, need %d\n",
992                         vlen, 6 * 4 + size_round4(ctx->gc_wire_ctx.len));
993                 rc = -EIO;
994                 goto out;
995         }
996
997         spin_lock(&ctx->gc_seq_lock);
998         seqnum = ctx->gc_seq++;
999         spin_unlock(&ctx->gc_seq_lock);
1000
1001         *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);    /* version */
1002         *vp++ = cpu_to_le32(PTLRPCS_FLVR_KRB5I);        /* subflavor */
1003         *vp++ = cpu_to_le32(ctx->gc_proc);              /* proc */
1004         *vp++ = cpu_to_le32(seqnum);                    /* seq */
1005         *vp++ = cpu_to_le32(PTLRPCS_GSS_SVC_INTEGRITY); /* service */
1006         vlen -= 5 * 4;
1007
1008         if (rawobj_serialize(&ctx->gc_wire_ctx, &vp, &vlen)) {
1009                 rc = -EIO;
1010                 goto out;
1011         }
1012         CDEBUG(D_SEC, "encoded wire_ctx length %d\n", ctx->gc_wire_ctx.len);
1013
1014         vpsave = vp++;  /* reserve for size */
1015         vlen -= 4;
1016
1017         mic.len = vlen;
1018         mic.data = (unsigned char *)vp;
1019
1020         CDEBUG(D_SEC, "reqbuf at %p, lmsg at %p, len %d, mic at %p, len %d\n",
1021                req->rq_reqbuf, lmsg.data, lmsg.len, mic.data, mic.len);
1022         major = kgss_get_mic(ctx->gc_gss_ctx, GSS_C_QOP_DEFAULT, &lmsg, &mic);
1023         if (major) {
1024                 CERROR("cred %p: req %p compute mic error, major %x\n",
1025                        cred, req, major);
1026                 rc = -EACCES;
1027                 goto out;
1028         }
1029
1030         *vpsave = cpu_to_le32(mic.len);
1031         
1032         seclen = seclen - vlen + mic.len;
1033         buf_to_sec_hdr(req->rq_reqbuf)->sec_len = cpu_to_le32(seclen);
1034         req->rq_reqdata_len += size_round(seclen);
1035         CDEBUG(D_SEC, "msg size %d, checksum size %d, total sec size %d\n",
1036                lmsg.len, mic.len, seclen);
1037 out:
1038         gss_put_ctx(ctx);
1039         RETURN(rc);
1040 }
1041
1042 static int gss_cred_verify(struct ptlrpc_cred *cred,
1043                            struct ptlrpc_request *req)
1044 {
1045         struct gss_cred        *gcred;
1046         struct gss_cl_ctx      *ctx;
1047         struct ptlrpcs_wire_hdr *sec_hdr;
1048         rawobj_t                lmsg, mic;
1049         __u32                   *vp, vlen, subflavor, proc, seq, svc;
1050         __u32                   major, minor, rc;
1051         ENTRY;
1052
1053         LASSERT(req->rq_repbuf);
1054         LASSERT(req->rq_cred == cred);
1055
1056         sec_hdr = buf_to_sec_hdr(req->rq_repbuf);
1057         vp = (__u32 *) (req->rq_repbuf + sizeof(*sec_hdr) + sec_hdr->msg_len);
1058         vlen = sec_hdr->sec_len;
1059
1060         if (vlen < 7 * 4) {
1061                 CERROR("reply sec size %u too small\n", vlen);
1062                 RETURN(-EPROTO);
1063         }
1064
1065         if (*vp++ != cpu_to_le32(PTLRPC_SEC_GSS_VERSION)) {
1066                 CERROR("reply have different gss version\n");
1067                 RETURN(-EPROTO);
1068         }
1069         subflavor = le32_to_cpu(*vp++);
1070         proc = le32_to_cpu(*vp++);
1071         vlen -= 3 * 4;
1072
1073         switch (proc) {
1074         case PTLRPCS_GSS_PROC_DATA:
1075                 seq = le32_to_cpu(*vp++);
1076                 svc = le32_to_cpu(*vp++);
1077                 if (svc != PTLRPCS_GSS_SVC_INTEGRITY) {
1078                         CERROR("Unknown svc %d\n", svc);
1079                         RETURN(-EPROTO);
1080                 }
1081                 if (*vp++ != 0) {
1082                         CERROR("Unexpected ctx handle\n");
1083                         RETURN(-EPROTO);
1084                 }
1085                 mic.len = le32_to_cpu(*vp++);
1086                 vlen -= 4 * 4;
1087                 if (vlen < mic.len) {
1088                         CERROR("vlen %d, mic.len %d\n", vlen, mic.len);
1089                         RETURN(-EINVAL);
1090                 }
1091                 mic.data = (unsigned char *)vp;
1092
1093                 gcred = container_of(cred, struct gss_cred, gc_base);
1094                 ctx = gss_cred_get_ctx(cred);
1095                 LASSERT(ctx);
1096
1097                 lmsg.len = sec_hdr->msg_len;
1098                 lmsg.data = (__u8 *) buf_to_lustre_msg(req->rq_repbuf);
1099
1100                 major = kgss_verify_mic(ctx->gc_gss_ctx, &lmsg, &mic, NULL);
1101                 if (major != GSS_S_COMPLETE) {
1102                         CERROR("cred %p: req %p verify mic error: major %x\n",
1103                                cred, req, major);
1104
1105                         if (major == GSS_S_CREDENTIALS_EXPIRED ||
1106                             major == GSS_S_CONTEXT_EXPIRED) {
1107                                 ptlrpcs_cred_expire(cred);
1108                                 req->rq_ptlrpcs_restart = 1;
1109                                 rc = 0;
1110                         } else
1111                                 rc = -EINVAL;
1112
1113                         GOTO(proc_data_out, rc);
1114                 }
1115
1116                 req->rq_repmsg = (struct lustre_msg *) lmsg.data;
1117                 req->rq_replen = lmsg.len;
1118
1119                 /* here we could check the seq number is the same one
1120                  * we sent to server. but portals has prevent us from
1121                  * replay attack, so maybe we don't need check it again.
1122                  */
1123                 rc = 0;
1124 proc_data_out:
1125                 gss_put_ctx(ctx);
1126                 break;
1127         case PTLRPCS_GSS_PROC_ERR:
1128                 major = le32_to_cpu(*vp++);
1129                 minor = le32_to_cpu(*vp++);
1130                 /* server return NO_CONTEXT might be caused by context expire
1131                  * or server reboot/failover. we refresh the cred transparently
1132                  * to upper layer.
1133                  * In some cases, our gss handle is possible to be incidentally
1134                  * identical to another handle since the handle itself is not
1135                  * fully random. In krb5 case, the GSS_S_BAD_SIG will be
1136                  * returned, maybe other gss error for other mechanism. Here we
1137                  * only consider krb5 mech (FIXME) and try to establish new
1138                  * context.
1139                  */
1140                 if (major == GSS_S_NO_CONTEXT ||
1141                     major == GSS_S_BAD_SIG) {
1142                         CWARN("req %p: server report cred %p %s\n",
1143                                req, cred, (major == GSS_S_NO_CONTEXT) ?
1144                                            "NO_CONTEXT" : "BAD_SIG");
1145
1146                         ptlrpcs_cred_expire(cred);
1147                         req->rq_ptlrpcs_restart = 1;
1148                         rc = 0;
1149                 } else {
1150                         CERROR("req %p: unrecognized gss error (%x/%x)\n",
1151                                 req, major, minor);
1152                         rc = -EACCES;
1153                 }
1154                 break;
1155         default:
1156                 CERROR("unknown gss proc %d\n", proc);
1157                 rc = -EPROTO;
1158         }
1159
1160         RETURN(rc);
1161 }
1162
1163 static int gss_cred_seal(struct ptlrpc_cred *cred,
1164                          struct ptlrpc_request *req)
1165 {
1166         struct gss_cred         *gcred;
1167         struct gss_cl_ctx       *ctx;
1168         struct ptlrpcs_wire_hdr *sec_hdr;
1169         rawobj_buf_t             msg_buf;
1170         rawobj_t                 cipher_buf;
1171         __u32                   *vp, *vpsave, vlen, seclen;
1172         __u32                    major, seqnum, rc = 0;
1173         ENTRY;
1174
1175         LASSERT(req->rq_reqbuf);
1176         LASSERT(req->rq_cred == cred);
1177
1178         gcred = container_of(cred, struct gss_cred, gc_base);
1179         ctx = gss_cred_get_ctx(cred);
1180         if (!ctx) {
1181                 CERROR("cred %p("LPU64"/%u) invalidated?\n",
1182                         cred, cred->pc_pag, cred->pc_uid);
1183                 RETURN(-EPERM);
1184         }
1185
1186         vp = (__u32 *) (req->rq_reqbuf + sizeof(*sec_hdr));
1187         vlen = req->rq_reqbuf_len - sizeof(*sec_hdr);
1188         seclen = vlen;
1189
1190         if (vlen < 6 * 4 + size_round4(ctx->gc_wire_ctx.len)) {
1191                 CERROR("vlen %d, need %d\n",
1192                         vlen, 6 * 4 + size_round4(ctx->gc_wire_ctx.len));
1193                 rc = -EIO;
1194                 goto out;
1195         }
1196
1197         spin_lock(&ctx->gc_seq_lock);
1198         seqnum = ctx->gc_seq++;
1199         spin_unlock(&ctx->gc_seq_lock);
1200
1201         *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);    /* version */
1202         *vp++ = cpu_to_le32(PTLRPCS_FLVR_KRB5P);        /* subflavor */
1203         *vp++ = cpu_to_le32(ctx->gc_proc);              /* proc */
1204         *vp++ = cpu_to_le32(seqnum);                    /* seq */
1205         *vp++ = cpu_to_le32(PTLRPCS_GSS_SVC_PRIVACY);   /* service */
1206         vlen -= 5 * 4;
1207
1208         if (rawobj_serialize(&ctx->gc_wire_ctx, &vp, &vlen)) {
1209                 rc = -EIO;
1210                 goto out;
1211         }
1212         CDEBUG(D_SEC, "encoded wire_ctx length %d\n", ctx->gc_wire_ctx.len);
1213
1214         vpsave = vp++;  /* reserve for size */
1215         vlen -= 4;
1216
1217         msg_buf.buf = (__u8 *) req->rq_reqmsg - GSS_PRIVBUF_PREFIX_LEN;
1218         msg_buf.buflen = req->rq_reqlen + GSS_PRIVBUF_PREFIX_LEN +
1219                                           GSS_PRIVBUF_SUFFIX_LEN;
1220         msg_buf.dataoff = GSS_PRIVBUF_PREFIX_LEN;
1221         msg_buf.datalen = req->rq_reqlen;
1222
1223         cipher_buf.data = (__u8 *) vp;
1224         cipher_buf.len = vlen;
1225
1226         major = kgss_wrap(ctx->gc_gss_ctx, GSS_C_QOP_DEFAULT,
1227                           &msg_buf, &cipher_buf);
1228         if (major) {
1229                 CERROR("cred %p: error wrap: major 0x%x\n", cred, major);
1230                 GOTO(out, rc = -EINVAL);
1231         }
1232
1233         *vpsave = cpu_to_le32(cipher_buf.len);
1234
1235         seclen = seclen - vlen + cipher_buf.len;
1236         sec_hdr = buf_to_sec_hdr(req->rq_reqbuf);
1237         sec_hdr->sec_len = cpu_to_le32(seclen);
1238         req->rq_reqdata_len += size_round(seclen);
1239
1240         CDEBUG(D_SEC, "msg size %d, total sec size %d\n",
1241                req->rq_reqlen, seclen);
1242 out:
1243         gss_put_ctx(ctx);
1244         RETURN(rc);
1245 }
1246
1247 static int gss_cred_unseal(struct ptlrpc_cred *cred,
1248                            struct ptlrpc_request *req)
1249 {
1250         struct gss_cred        *gcred;
1251         struct gss_cl_ctx      *ctx;
1252         struct ptlrpcs_wire_hdr *sec_hdr;
1253         rawobj_t                cipher_text, plain_text;
1254         __u32                   *vp, vlen, subflavor, proc, seq, svc;
1255         __u32                   major, rc;
1256         ENTRY;
1257
1258         LASSERT(req->rq_repbuf);
1259         LASSERT(req->rq_cred == cred);
1260
1261         sec_hdr = buf_to_sec_hdr(req->rq_repbuf);
1262         if (sec_hdr->msg_len != 0) {
1263                 CERROR("unexpected msg_len %u\n", sec_hdr->msg_len);
1264                 RETURN(-EPROTO);
1265         }
1266
1267         vp = (__u32 *) (req->rq_repbuf + sizeof(*sec_hdr));
1268         vlen = sec_hdr->sec_len;
1269
1270         if (vlen < 7 * 4) {
1271                 CERROR("reply sec size %u too small\n", vlen);
1272                 RETURN(-EPROTO);
1273         }
1274
1275         if (*vp++ != cpu_to_le32(PTLRPC_SEC_GSS_VERSION)) {
1276                 CERROR("reply have different gss version\n");
1277                 RETURN(-EPROTO);
1278         }
1279         subflavor = le32_to_cpu(*vp++);
1280         proc = le32_to_cpu(*vp++);
1281         seq = le32_to_cpu(*vp++);
1282         svc = le32_to_cpu(*vp++);
1283         vlen -= 5 * 4;
1284
1285         switch (proc) {
1286         case PTLRPCS_GSS_PROC_DATA:
1287                 if (svc != PTLRPCS_GSS_SVC_PRIVACY) {
1288                         CERROR("Unknown svc %d\n", svc);
1289                         RETURN(-EPROTO);
1290                 }
1291                 if (*vp++ != 0) {
1292                         CERROR("Unexpected ctx handle\n");
1293                         RETURN(-EPROTO);
1294                 }
1295                 vlen -= 4;
1296
1297                 cipher_text.len = le32_to_cpu(*vp++);
1298                 cipher_text.data = (__u8 *) vp;
1299                 vlen -= 4;
1300
1301                 if (vlen < cipher_text.len) {
1302                         CERROR("cipher text to be %u while buf only %u\n",
1303                                 cipher_text.len, vlen);
1304                         RETURN(-EPROTO);
1305                 }
1306
1307                 plain_text = cipher_text;
1308
1309                 gcred = container_of(cred, struct gss_cred, gc_base);
1310                 ctx = gss_cred_get_ctx(cred);
1311                 LASSERT(ctx);
1312
1313                 major = kgss_unwrap(ctx->gc_gss_ctx, GSS_C_QOP_DEFAULT,
1314                                     &cipher_text, &plain_text);
1315                 if (major) {
1316                         CERROR("cred %p: error unwrap: major 0x%x\n",
1317                                cred, major);
1318
1319                         if (major == GSS_S_CREDENTIALS_EXPIRED ||
1320                             major == GSS_S_CONTEXT_EXPIRED) {
1321                                 ptlrpcs_cred_expire(cred);
1322                                 req->rq_ptlrpcs_restart = 1;
1323                                 rc = 0;
1324                         } else
1325                                 rc = -EINVAL;
1326
1327                         GOTO(proc_out, rc);
1328                 }
1329
1330                 req->rq_repmsg = (struct lustre_msg *) vp;
1331                 req->rq_replen = plain_text.len;
1332
1333                 rc = 0;
1334 proc_out:
1335                 gss_put_ctx(ctx);
1336                 break;
1337         default:
1338                 CERROR("unknown gss proc %d\n", proc);
1339                 rc = -EPROTO;
1340         }
1341
1342         RETURN(rc);
1343 }
1344
1345 static void destroy_gss_context(struct ptlrpc_cred *cred)
1346 {
1347         struct ptlrpcs_wire_hdr *hdr;
1348         struct lustre_msg       *lmsg;
1349         struct gss_cred         *gcred;
1350         struct ptlrpc_request    req;
1351         struct obd_import       *imp;
1352         __u32                   *vp, lmsg_size;
1353         struct ptlrpc_request   *raw_req = NULL;
1354         const int                repbuf_len = 256;
1355         char                    *repbuf;
1356         int                      replen, rc;
1357         ENTRY;
1358
1359         imp = cred->pc_sec->ps_import;
1360         LASSERT(imp);
1361
1362         if (test_bit(PTLRPC_CRED_ERROR_BIT, &cred->pc_flags) ||
1363             !test_bit(PTLRPC_CRED_UPTODATE_BIT, &cred->pc_flags)) {
1364                 CDEBUG(D_SEC, "Destroy dead cred %p(%u@%s)\n",
1365                        cred, cred->pc_uid, imp->imp_target_uuid.uuid);
1366                 EXIT;
1367                 return;
1368         }
1369
1370         might_sleep();
1371
1372         /* cred's refcount is 0, steal one */
1373         atomic_inc(&cred->pc_refcount);
1374
1375         gcred = container_of(cred, struct gss_cred, gc_base);
1376         gcred->gc_ctx->gc_proc = PTLRPCS_GSS_PROC_DESTROY;
1377
1378         CDEBUG(D_SEC, "client destroy gss cred %p(%u@%s)\n",
1379                gcred, cred->pc_uid, imp->imp_target_uuid.uuid);
1380
1381         lmsg_size = lustre_msg_size(0, NULL);
1382         req.rq_req_secflvr = cred->pc_sec->ps_flavor;
1383         req.rq_cred = cred;
1384         req.rq_reqbuf_len = sizeof(*hdr) + lmsg_size +
1385                             ptlrpcs_est_req_payload(&req, lmsg_size);
1386
1387         OBD_ALLOC(req.rq_reqbuf, req.rq_reqbuf_len);
1388         if (!req.rq_reqbuf) {
1389                 CERROR("Fail to alloc reqbuf, cancel anyway\n");
1390                 atomic_dec(&cred->pc_refcount);
1391                 EXIT;
1392                 return;
1393         }
1394
1395         /* wire hdr */
1396         hdr = buf_to_sec_hdr(req.rq_reqbuf);
1397         hdr->flavor  = cpu_to_le32(PTLRPCS_FLVR_GSS_AUTH);
1398         hdr->msg_len = cpu_to_le32(lmsg_size);
1399         hdr->sec_len = cpu_to_le32(0);
1400
1401         /* lustre message */
1402         lmsg = buf_to_lustre_msg(req.rq_reqbuf);
1403         lustre_init_msg(lmsg, 0, NULL, NULL);
1404         lmsg->handle   = imp->imp_remote_handle;
1405         lmsg->type     = PTL_RPC_MSG_REQUEST;
1406         lmsg->opc      = SEC_FINI;
1407         lmsg->flags    = 0;
1408         lmsg->conn_cnt = imp->imp_conn_cnt;
1409         /* add this for randomize */
1410         get_random_bytes(&lmsg->last_xid, sizeof(lmsg->last_xid));
1411         get_random_bytes(&lmsg->transno, sizeof(lmsg->transno));
1412
1413         vp = (__u32 *) req.rq_reqbuf;
1414
1415         req.rq_cred = cred;
1416         req.rq_reqmsg = buf_to_lustre_msg(req.rq_reqbuf);
1417         req.rq_reqlen = lmsg_size;
1418         req.rq_reqdata_len = sizeof(*hdr) + lmsg_size;
1419
1420         if (gss_cred_sign(cred, &req)) {
1421                 CERROR("failed to sign, cancel anyway\n");
1422                 atomic_dec(&cred->pc_refcount);
1423                 goto exit;
1424         }
1425         atomic_dec(&cred->pc_refcount);
1426
1427         OBD_ALLOC(repbuf, repbuf_len);
1428         if (!repbuf)
1429                 goto exit;
1430
1431         raw_req = ptl_do_rawrpc(imp, req.rq_reqbuf, req.rq_reqbuf_len,
1432                                 req.rq_reqdata_len, repbuf, repbuf_len, &replen,
1433                                 SECFINI_RPC_TIMEOUT, &rc);
1434         if (!raw_req)
1435                 OBD_FREE(repbuf, repbuf_len);
1436
1437 exit:
1438         if (raw_req == NULL)
1439                 OBD_FREE(req.rq_reqbuf, req.rq_reqbuf_len);
1440         else
1441                 rawrpc_req_finished(raw_req);
1442         EXIT;
1443 }
1444
1445 static void gss_cred_destroy(struct ptlrpc_cred *cred)
1446 {
1447         struct gss_cred *gcred;
1448         ENTRY;
1449
1450         LASSERT(cred);
1451         LASSERT(!atomic_read(&cred->pc_refcount));
1452
1453         gcred = container_of(cred, struct gss_cred, gc_base);
1454         if (gcred->gc_ctx) {
1455                 destroy_gss_context(cred);
1456                 gss_put_ctx(gcred->gc_ctx);
1457         }
1458
1459         CDEBUG(D_SEC, "sec.gss %p: destroy cred %p\n", cred->pc_sec, gcred);
1460
1461         OBD_FREE(gcred, sizeof(*gcred));
1462         EXIT;
1463 }
1464
1465 static struct ptlrpc_credops gss_credops = {
1466         .refresh        = gss_cred_refresh,
1467         .match          = gss_cred_match,
1468         .sign           = gss_cred_sign,
1469         .verify         = gss_cred_verify,
1470         .seal           = gss_cred_seal,
1471         .unseal         = gss_cred_unseal,
1472         .destroy        = gss_cred_destroy,
1473 };
1474
1475 #ifdef __KERNEL__
1476 /*******************************************
1477  * rpc_pipe APIs                           *
1478  *******************************************/
1479 static ssize_t
1480 gss_pipe_upcall(struct file *filp, struct rpc_pipe_msg *msg,
1481                 char *dst, size_t buflen)
1482 {
1483         char *data = (char *)msg->data + msg->copied;
1484         ssize_t mlen = msg->len;
1485         ssize_t left;
1486         ENTRY;
1487
1488         if (mlen > buflen)
1489                 mlen = buflen;
1490         left = copy_to_user(dst, data, mlen);
1491         if (left < 0) {
1492                 msg->errno = left;
1493                 RETURN(left);
1494         }
1495         mlen -= left;
1496         msg->copied += mlen;
1497         msg->errno = 0;
1498         RETURN(mlen);
1499 }
1500
1501 static ssize_t
1502 gss_pipe_downcall(struct file *filp, const char *src, size_t mlen)
1503 {
1504         char *buf;
1505         const int bufsize = 1024;
1506         rawobj_t obj;
1507         struct inode *inode = filp->f_dentry->d_inode;
1508         struct rpc_inode *rpci = RPC_I(inode);
1509         struct obd_import *import;
1510         struct ptlrpc_sec *sec;
1511         struct gss_sec *gsec;
1512         char *obdname;
1513         struct gss_api_mech *mech;
1514         struct vfs_cred vcred = { 0 };
1515         struct ptlrpc_cred *cred;
1516         struct gss_upcall_msg *gss_msg;
1517         struct gss_upcall_msg_data gmd = { 0 };
1518         struct gss_cl_ctx *ctx = NULL;
1519         ssize_t left;
1520         int err, gss_err;
1521         ENTRY;
1522
1523         if (mlen > bufsize) {
1524                 CERROR("mlen %ld > bufsize %d\n", (long)mlen, bufsize);
1525                 RETURN(-ENOSPC);
1526         }
1527
1528         OBD_ALLOC(buf, bufsize);
1529         if (!buf) {
1530                 CERROR("alloc mem failed\n");
1531                 RETURN(-ENOMEM);
1532         }
1533
1534         left = copy_from_user(buf, src, mlen);
1535         if (left)
1536                 GOTO(err_free, err = -EFAULT);
1537
1538         obj.data = (unsigned char *)buf;
1539         obj.len = mlen;
1540
1541         LASSERT(rpci->private);
1542         gsec = (struct gss_sec *)rpci->private;
1543         sec = &gsec->gs_base;
1544         LASSERT(sec->ps_import);
1545         import = class_import_get(sec->ps_import);
1546         LASSERT(import->imp_obd);
1547         obdname = import->imp_obd->obd_name;
1548         mech = gsec->gs_mech;
1549
1550         err = gss_parse_init_downcall(mech, &obj, &ctx, &gmd, &gss_err);
1551         if (err)
1552                 CERROR("parse init downcall err %d\n", err);
1553
1554         vcred.vc_pag = gmd.gum_pag;
1555         vcred.vc_uid = gmd.gum_uid;
1556
1557         cred = ptlrpcs_cred_lookup(sec, &vcred);
1558         if (!cred) {
1559                 CWARN("didn't find cred for uid %u\n", vcred.vc_uid);
1560                 GOTO(err, err = -EINVAL);
1561         }
1562
1563         if (err || gss_err) {
1564                 set_bit(PTLRPC_CRED_DEAD_BIT, &cred->pc_flags);
1565                 if (err != -ERESTART || gss_err != 0)
1566                         set_bit(PTLRPC_CRED_ERROR_BIT, &cred->pc_flags);
1567                 CERROR("cred %p: rpc err %d, gss err 0x%x, fatal %d\n",
1568                        cred, err, gss_err,
1569                        (test_bit(PTLRPC_CRED_ERROR_BIT, &cred->pc_flags) != 0));
1570         } else {
1571                 CDEBUG(D_SEC, "get initial ctx:\n");
1572                 gss_cred_set_ctx(cred, ctx);
1573         }
1574
1575         spin_lock(&gsec->gs_lock);
1576         gss_msg = gss_find_upcall(gsec, obdname, &gmd);
1577         if (gss_msg) {
1578                 gss_unhash_msg_nolock(gss_msg);
1579                 spin_unlock(&gsec->gs_lock);
1580                 gss_release_msg(gss_msg);
1581         } else
1582                 spin_unlock(&gsec->gs_lock);
1583
1584         ptlrpcs_cred_put(cred, 1);
1585         class_import_put(import);
1586         OBD_FREE(buf, bufsize);
1587         RETURN(mlen);
1588 err:
1589         if (ctx)
1590                 gss_destroy_ctx(ctx);
1591         class_import_put(import);
1592 err_free:
1593         OBD_FREE(buf, bufsize);
1594         CDEBUG(D_SEC, "gss_pipe_downcall returning %d\n", err);
1595         RETURN(err);
1596 }
1597
1598 static
1599 void gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
1600 {
1601         struct gss_upcall_msg *gmsg;
1602         static unsigned long ratelimit;
1603         ENTRY;
1604
1605         LASSERT(list_empty(&msg->list));
1606
1607         if (msg->errno >= 0) {
1608                 EXIT;
1609                 return;
1610         }
1611
1612         gmsg = container_of(msg, struct gss_upcall_msg, gum_base);
1613         CDEBUG(D_SEC, "destroy gmsg %p\n", gmsg);
1614         LASSERT(atomic_read(&gmsg->gum_refcount) > 0);
1615         atomic_inc(&gmsg->gum_refcount);
1616         gss_unhash_msg(gmsg);
1617         if (msg->errno == -ETIMEDOUT || msg->errno == -EPIPE) {
1618                 unsigned long now = get_seconds();
1619                 if (time_after(now, ratelimit)) {
1620                         CWARN("sec.gss upcall timed out.\n"
1621                               "Please check user daemon is running!\n");
1622                         ratelimit = now + 15;
1623                 }
1624         }
1625         gss_release_msg(gmsg);
1626         EXIT;
1627 }
1628
1629 static
1630 void gss_pipe_release(struct inode *inode)
1631 {
1632         struct rpc_inode *rpci = RPC_I(inode);
1633         struct ptlrpc_sec *sec;
1634         struct gss_sec *gsec;
1635         ENTRY;
1636
1637         gsec = (struct gss_sec *)rpci->private;
1638         sec = &gsec->gs_base;
1639         spin_lock(&gsec->gs_lock);
1640         while (!list_empty(&gsec->gs_upcalls)) {
1641                 struct gss_upcall_msg *gmsg;
1642
1643                 gmsg = list_entry(gsec->gs_upcalls.next,
1644                                   struct gss_upcall_msg, gum_list);
1645                 LASSERT(list_empty(&gmsg->gum_base.list));
1646                 gmsg->gum_base.errno = -EPIPE;
1647                 atomic_inc(&gmsg->gum_refcount);
1648                 gss_unhash_msg_nolock(gmsg);
1649                 gss_release_msg(gmsg);
1650         }
1651         spin_unlock(&gsec->gs_lock);
1652         EXIT;
1653 }
1654
1655 static struct rpc_pipe_ops gss_upcall_ops = {
1656         .upcall         = gss_pipe_upcall,
1657         .downcall       = gss_pipe_downcall,
1658         .destroy_msg    = gss_pipe_destroy_msg,
1659         .release_pipe   = gss_pipe_release,
1660 };
1661 #endif /* __KERNEL__ */
1662
1663 /*********************************************
1664  * GSS security APIs                         *
1665  *********************************************/
1666
1667 static
1668 struct ptlrpc_sec* gss_create_sec(__u32 flavor,
1669                                   const char *pipe_dir,
1670                                   void *pipe_data)
1671 {
1672         struct gss_sec *gsec;
1673         struct ptlrpc_sec *sec;
1674         uid_t save_uid;
1675
1676 #ifdef __KERNEL__
1677         char *pos;
1678         int   pipepath_len;
1679 #endif
1680         ENTRY;
1681
1682         LASSERT(SEC_FLAVOR_MAJOR(flavor) == PTLRPCS_FLVR_MAJOR_GSS);
1683
1684         OBD_ALLOC(gsec, sizeof(*gsec));
1685         if (!gsec) {
1686                 CERROR("can't alloc gsec\n");
1687                 RETURN(NULL);
1688         }
1689
1690         gsec->gs_mech = kgss_subflavor_to_mech(SEC_FLAVOR_SUB(flavor));
1691         if (!gsec->gs_mech) {
1692                 CERROR("subflavor 0x%x not found\n", flavor);
1693                 goto err_free;
1694         }
1695
1696         /* initialize gss sec */
1697 #ifdef __KERNEL__
1698         INIT_LIST_HEAD(&gsec->gs_upcalls);
1699         spin_lock_init(&gsec->gs_lock);
1700
1701         pipepath_len = strlen(LUSTRE_PIPEDIR) + strlen(pipe_dir) +
1702                        strlen(gsec->gs_mech->gm_name) + 3;
1703         OBD_ALLOC(gsec->gs_pipepath, pipepath_len);
1704         if (!gsec->gs_pipepath)
1705                 goto err_mech_put;
1706
1707         /* pipe rpc require root permission */
1708         save_uid = current->fsuid;
1709         current->fsuid = 0;
1710
1711         sprintf(gsec->gs_pipepath, LUSTRE_PIPEDIR"/%s", pipe_dir);
1712         if (IS_ERR(rpc_mkdir(gsec->gs_pipepath, NULL))) {
1713                 CERROR("can't make pipedir %s\n", gsec->gs_pipepath);
1714                 goto err_free_path;
1715         }
1716
1717         sprintf(gsec->gs_pipepath, LUSTRE_PIPEDIR"/%s/%s", pipe_dir,
1718                 gsec->gs_mech->gm_name); 
1719         gsec->gs_depipe = rpc_mkpipe(gsec->gs_pipepath, gsec,
1720                                      &gss_upcall_ops, RPC_PIPE_WAIT_FOR_OPEN);
1721         if (IS_ERR(gsec->gs_depipe)) {
1722                 CERROR("failed to make rpc_pipe %s: %ld\n",
1723                         gsec->gs_pipepath, PTR_ERR(gsec->gs_depipe));
1724                 goto err_rmdir;
1725         }
1726         CDEBUG(D_SEC, "gss sec %p, pipe path %s\n", gsec, gsec->gs_pipepath);
1727 #endif
1728
1729         sec = &gsec->gs_base;
1730         sec->ps_expire = GSS_CREDCACHE_EXPIRE;
1731         sec->ps_nextgc = get_seconds() + sec->ps_expire;
1732         sec->ps_flags = 0;
1733
1734         current->fsuid = save_uid;
1735
1736         CDEBUG(D_SEC, "Create sec.gss %p\n", gsec);
1737         RETURN(sec);
1738
1739 #ifdef __KERNEL__
1740 err_rmdir:
1741         pos = strrchr(gsec->gs_pipepath, '/');
1742         LASSERT(pos);
1743         *pos = 0;
1744         rpc_rmdir(gsec->gs_pipepath);
1745 err_free_path:
1746         current->fsuid = save_uid;
1747         OBD_FREE(gsec->gs_pipepath, pipepath_len);
1748 err_mech_put:
1749 #endif
1750         kgss_mech_put(gsec->gs_mech);
1751 err_free:
1752         OBD_FREE(gsec, sizeof(*gsec));
1753         RETURN(NULL);
1754 }
1755
1756 static
1757 void gss_destroy_sec(struct ptlrpc_sec *sec)
1758 {
1759         struct gss_sec *gsec;
1760 #ifdef __KERNEL__
1761         char *pos;
1762         int   pipepath_len;
1763 #endif
1764         ENTRY;
1765
1766         gsec = container_of(sec, struct gss_sec, gs_base);
1767         CDEBUG(D_SEC, "Destroy sec.gss %p\n", gsec);
1768
1769         LASSERT(gsec->gs_mech);
1770         LASSERT(!atomic_read(&sec->ps_refcount));
1771         LASSERT(!atomic_read(&sec->ps_credcount));
1772 #ifdef __KERNEL__
1773         pipepath_len = strlen(gsec->gs_pipepath) + 1;
1774         rpc_unlink(gsec->gs_pipepath);
1775         pos = strrchr(gsec->gs_pipepath, '/');
1776         LASSERT(pos);
1777         *pos = 0;
1778         rpc_rmdir(gsec->gs_pipepath);
1779         OBD_FREE(gsec->gs_pipepath, pipepath_len);
1780 #endif
1781
1782         kgss_mech_put(gsec->gs_mech);
1783         OBD_FREE(gsec, sizeof(*gsec));
1784         EXIT;
1785 }
1786
1787 static
1788 struct ptlrpc_cred * gss_create_cred(struct ptlrpc_sec *sec,
1789                                      struct vfs_cred *vcred)
1790 {
1791         struct gss_cred *gcred;
1792         struct ptlrpc_cred *cred;
1793         ENTRY;
1794
1795         OBD_ALLOC(gcred, sizeof(*gcred));
1796         if (!gcred)
1797                 RETURN(NULL);
1798
1799         cred = &gcred->gc_base;
1800         INIT_LIST_HEAD(&cred->pc_hash);
1801         atomic_set(&cred->pc_refcount, 0);
1802         cred->pc_sec = sec;
1803         cred->pc_ops = &gss_credops;
1804         cred->pc_expire = 0;
1805         cred->pc_flags = 0;
1806         cred->pc_pag = vcred->vc_pag;
1807         cred->pc_uid = vcred->vc_uid;
1808         CDEBUG(D_SEC, "create a gss cred at %p("LPU64"/%u)\n",
1809                cred, vcred->vc_pag, vcred->vc_uid);
1810
1811         RETURN(cred);
1812 }
1813
1814 static int gss_estimate_payload(struct ptlrpc_sec *sec,
1815                                 struct ptlrpc_request *req,
1816                                 int msgsize)
1817 {
1818         switch (SEC_FLAVOR_SVC(req->rq_req_secflvr)) {
1819         case PTLRPCS_SVC_AUTH:
1820                 return GSS_MAX_AUTH_PAYLOAD;
1821         case PTLRPCS_SVC_PRIV:
1822                 return size_round16(GSS_MAX_AUTH_PAYLOAD + msgsize +
1823                                     GSS_PRIVBUF_PREFIX_LEN +
1824                                     GSS_PRIVBUF_SUFFIX_LEN);
1825         default:
1826                 LBUG();
1827                 return 0;
1828         }
1829 }
1830
1831 static int gss_alloc_reqbuf(struct ptlrpc_sec *sec,
1832                             struct ptlrpc_request *req,
1833                             int lmsg_size)
1834 {
1835         int msg_payload, sec_payload;
1836         int privacy, rc;
1837         ENTRY;
1838
1839         /* In PRIVACY mode, lustre message is always 0 (already encoded into
1840          * security payload).
1841          */
1842         privacy = (SEC_FLAVOR_SVC(req->rq_req_secflvr) == PTLRPCS_SVC_PRIV);
1843         msg_payload = privacy ? 0 : lmsg_size;
1844         sec_payload = gss_estimate_payload(sec, req, lmsg_size);
1845
1846         rc = sec_alloc_reqbuf(sec, req, msg_payload, sec_payload);
1847         if (rc)
1848                 return rc;
1849
1850         if (privacy) {
1851                 int buflen = lmsg_size + GSS_PRIVBUF_PREFIX_LEN +
1852                              GSS_PRIVBUF_SUFFIX_LEN;
1853                 char *buf;
1854
1855                 OBD_ALLOC(buf, buflen);
1856                 if (!buf) {
1857                         CERROR("Fail to alloc %d\n", buflen);
1858                         sec_free_reqbuf(sec, req);
1859                         RETURN(-ENOMEM);
1860                 }
1861                 req->rq_reqmsg = (struct lustre_msg *)
1862                                         (buf + GSS_PRIVBUF_PREFIX_LEN);
1863         }
1864
1865         RETURN(0);
1866 }
1867
1868 static void gss_free_reqbuf(struct ptlrpc_sec *sec,
1869                             struct ptlrpc_request *req)
1870 {
1871         char *buf;
1872         int privacy;
1873         ENTRY;
1874
1875         LASSERT(req->rq_reqmsg);
1876         LASSERT(req->rq_reqlen);
1877
1878         privacy = SEC_FLAVOR_SVC(req->rq_req_secflvr) == PTLRPCS_SVC_PRIV;
1879         if (privacy) {
1880                 buf = (char *) req->rq_reqmsg - GSS_PRIVBUF_PREFIX_LEN;
1881                 LASSERT(buf < req->rq_reqbuf ||
1882                         buf >= req->rq_reqbuf + req->rq_reqbuf_len);
1883                 OBD_FREE(buf, req->rq_reqlen + GSS_PRIVBUF_PREFIX_LEN +
1884                               GSS_PRIVBUF_SUFFIX_LEN);
1885                 req->rq_reqmsg = NULL;
1886         }
1887
1888         sec_free_reqbuf(sec, req);
1889 }
1890
1891 static struct ptlrpc_secops gss_secops = {
1892         .create_sec             = gss_create_sec,
1893         .destroy_sec            = gss_destroy_sec,
1894         .create_cred            = gss_create_cred,
1895         .est_req_payload        = gss_estimate_payload,
1896         .est_rep_payload        = gss_estimate_payload,
1897         .alloc_reqbuf           = gss_alloc_reqbuf,
1898         .free_reqbuf            = gss_free_reqbuf,
1899 };
1900
1901 static struct ptlrpc_sec_type gss_type = {
1902         .pst_owner      = THIS_MODULE,
1903         .pst_name       = "sec.gss",
1904         .pst_inst       = ATOMIC_INIT(0),
1905         .pst_flavor     = PTLRPCS_FLVR_MAJOR_GSS,
1906         .pst_ops        = &gss_secops,
1907 };
1908
/* Downcall hook installed by ptlrpcs_gss_init(), cleared on module exit. */
extern int (*lustre_secinit_downcall_handler)(char *buffer,
                                              unsigned long count);
1911
1912 int __init ptlrpcs_gss_init(void)
1913 {
1914         int rc;
1915
1916         rc = ptlrpcs_register(&gss_type);
1917         if (rc)
1918                 return rc;
1919
1920 #ifdef __KERNEL__
1921         gss_svc_init();
1922
1923         rc = PTR_ERR(rpc_mkdir(LUSTRE_PIPEDIR, NULL));
1924         if (IS_ERR((void *)rc) && rc != -EEXIST) {
1925                 CERROR("fail to make rpcpipedir for lustre\n");
1926                 gss_svc_exit();
1927                 ptlrpcs_unregister(&gss_type);
1928                 return -1;
1929         }
1930         rc = 0;
1931 #else
1932 #endif
1933         rc = init_kerberos_module();
1934         if (rc) {
1935                 ptlrpcs_unregister(&gss_type);
1936         }
1937
1938         lustre_secinit_downcall_handler = gss_send_secinit_rpc;
1939
1940         return rc;
1941 }
1942
#ifdef __KERNEL__
/*
 * Module teardown: undoes ptlrpcs_gss_init() step by step in reverse
 * order, starting with the secinit downcall hook.
 */
static void __exit ptlrpcs_gss_exit(void)
{
        /* stop routing secinit downcalls into this module */
        lustre_secinit_downcall_handler = NULL;

        /* mechanism, pipe directory, server side, then the sec type */
        cleanup_kerberos_module();
        rpc_rmdir(LUSTRE_PIPEDIR);
        gss_svc_exit();
        ptlrpcs_unregister(&gss_type);
}
#endif
1954
1955 MODULE_AUTHOR("Cluster File Systems, Inc. <info@clusterfs.com>");
1956 MODULE_DESCRIPTION("GSS Security module for Lustre");
1957 MODULE_LICENSE("GPL");
1958
1959 module_init(ptlrpcs_gss_init);
1960 module_exit(ptlrpcs_gss_exit);