Whamcloud - gitweb
- compile fixes for --enable-liblustre
[fs/lustre-release.git] / lustre / sec / gss / sec_gss.c
1 /* -*- mode: c; c-basic-offset: 8; indent-tabs-mode: nil; -*-
2  * vim:expandtab:shiftwidth=8:tabstop=8:
3  *
4  * Modifications for Lustre
5  * Copyright 2004, Cluster File Systems, Inc.
6  * All rights reserved
7  * Author: Eric Mei <ericm@clusterfs.com>
8  */
9
10 /*
11  * linux/net/sunrpc/auth_gss.c
12  *
13  * RPCSEC_GSS client authentication.
14  *
15  *  Copyright (c) 2000 The Regents of the University of Michigan.
16  *  All rights reserved.
17  *
18  *  Dug Song       <dugsong@monkey.org>
19  *  Andy Adamson   <andros@umich.edu>
20  *
21  *  Redistribution and use in source and binary forms, with or without
22  *  modification, are permitted provided that the following conditions
23  *  are met:
24  *
25  *  1. Redistributions of source code must retain the above copyright
26  *     notice, this list of conditions and the following disclaimer.
27  *  2. Redistributions in binary form must reproduce the above copyright
28  *     notice, this list of conditions and the following disclaimer in the
29  *     documentation and/or other materials provided with the distribution.
30  *  3. Neither the name of the University nor the names of its
31  *     contributors may be used to endorse or promote products derived
32  *     from this software without specific prior written permission.
33  *
34  *  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
35  *  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
36  *  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
37  *  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
38  *  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
39  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
40  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
41  *  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
42  *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
43  *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
44  *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
45  *
46  * $Id: sec_gss.c,v 1.4 2005/04/13 09:49:50 yury Exp $
47  */
48
49 #ifndef EXPORT_SYMTAB
50 # define EXPORT_SYMTAB
51 #endif
52 #define DEBUG_SUBSYSTEM S_SEC
53 #ifdef __KERNEL__
54 #include <linux/init.h>
55 #include <linux/module.h>
56 #include <linux/slab.h>
57 #include <linux/dcache.h>
58 #include <linux/fs.h>
59 #include <linux/random.h>
60 /* for rpc_pipefs */
61 struct rpc_clnt;
62 #include <linux/sunrpc/rpc_pipe_fs.h>
63 #else
64 #include <liblustre.h>
65 #endif
66
67 #include <libcfs/kp30.h>
68 #include <linux/obd.h>
69 #include <linux/obd_class.h>
70 #include <linux/obd_support.h>
71 #include <linux/lustre_idl.h>
72 #include <linux/lustre_net.h>
73 #include <linux/lustre_import.h>
74 #include <linux/lustre_sec.h>
75
76 #include "gss_err.h"
77 #include "gss_internal.h"
78 #include "gss_api.h"
79
/* Cache / credential lifetimes; the inline notes give them in seconds. */
#define GSS_CREDCACHE_EXPIRE    (60)               /* 1 minute */
#define GSS_CRED_EXPIRE         (8 * 60 * 60)      /* 8 hours */
/* NOTE(review): the two sizes below are not referenced in this part of the
 * file; presumably byte sizes for sign/verify buffers -- confirm. */
#define GSS_CRED_SIGN_SIZE      (1024)
#define GSS_CRED_VERIFY_SIZE    (56)

/* NOTE(review): not referenced in this part of the file; presumably the
 * rpc_pipefs directory used for gss upcalls -- confirm. */
#define LUSTRE_PIPEDIR          "/lustre"

/**********************************************
 * gss security init/fini helper              *
 **********************************************/

/* Timeouts passed to ptlrpc_do_rawrpc() for the raw SEC_INIT / SEC_FINI
 * RPCs. NOTE(review): unit assumed to be seconds -- confirm against
 * ptlrpc_do_rawrpc(). */
#define SECINIT_RPC_TIMEOUT     (10)
#define SECFINI_RPC_TIMEOUT     (10)
93
94 static int secinit_compose_request(struct obd_import *imp,
95                                    char *buf, int bufsize,
96                                    char __user *token)
97 {
98         struct ptlrpcs_wire_hdr *hdr;
99         struct lustre_msg       *lmsg;
100         char __user             *token_buf;
101         __u64                    token_size;
102         __u32                    lmsg_size, *p;
103         int rc;
104
105         lmsg_size = lustre_msg_size(0, NULL);
106
107         if (copy_from_user(&token_size, token, sizeof(token_size))) {
108                 CERROR("read token error\n");
109                 return -EFAULT;
110         }
111         if (sizeof(*hdr) + lmsg_size + size_round(token_size) > bufsize) {
112                 CERROR("token size "LPU64" too large\n", token_size);
113                 return -EINVAL;
114         }
115
116         if (copy_from_user(&token_buf, (token + sizeof(token_size)),
117                            sizeof(void*))) {
118                 CERROR("read token buf pointer error\n");
119                 return -EFAULT;
120         }
121
122         /* security wire hdr */
123         hdr = buf_to_sec_hdr(buf);
124         hdr->flavor  = cpu_to_le32(PTLRPC_SEC_GSS);
125         hdr->sectype = cpu_to_le32(PTLRPC_SEC_TYPE_NONE);
126         hdr->msg_len = cpu_to_le32(lmsg_size);
127         hdr->sec_len = cpu_to_le32(7 * 4 + token_size);
128
129         /* lustre message */
130         lmsg = buf_to_lustre_msg(buf);
131         lustre_init_msg(lmsg, 0, NULL, NULL);
132         lmsg->handle   = imp->imp_remote_handle;
133         lmsg->type     = PTL_RPC_MSG_REQUEST;
134         lmsg->opc      = SEC_INIT;
135         lmsg->flags    = 0;
136         lmsg->conn_cnt = imp->imp_conn_cnt;
137
138         p = (__u32 *) (buf + sizeof(*hdr) + lmsg_size);
139
140         /* gss hdr */
141         *p++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);     /* gss version */
142         *p++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);       /* subflavor */
143         *p++ = cpu_to_le32(PTLRPC_GSS_PROC_INIT);       /* proc */
144         *p++ = cpu_to_le32(0);                          /* seq */
145         *p++ = cpu_to_le32(PTLRPC_GSS_SVC_NONE);        /* service */
146         *p++ = cpu_to_le32(0);                          /* context handle */
147
148         /* now the token part */
149         *p++ = (__u32)(cpu_to_le64(token_size));
150         LASSERT(((char *)p - buf) + token_size <= bufsize);
151
152         rc = copy_from_user(p, token_buf, token_size);
153         if (rc) {
154                 CERROR("can't copy token\n");
155                 return -EFAULT;
156         }
157
158         rc = size_round(((char *)p - buf) + token_size);
159         return rc;
160 }
161
162 static int secinit_parse_reply(char *repbuf, int replen,
163                                char __user *outbuf, int outlen)
164 {
165         __u32 *p = (__u32 *)repbuf;
166         __u32 lmsg_len, sec_len, status, major, minor, seq, obj_len, round_len;
167         __u32 effective = 0;
168
169         if (replen <= (4 + 6) * 4) {
170                 CERROR("reply size %d too small\n", replen);
171                 return -EINVAL;
172         }
173
174         lmsg_len = le32_to_cpu(p[2]);
175         sec_len = le32_to_cpu(p[3]);
176
177         /* sanity checks */
178         if (p[0] != cpu_to_le32(PTLRPC_SEC_GSS) ||
179             p[1] != cpu_to_le32(PTLRPC_SEC_TYPE_NONE)) {
180                 CERROR("unexpected reply\n");
181                 return -EINVAL;
182         }
183         if (lmsg_len % 8 ||
184             4 * 4 + lmsg_len + sec_len > replen) {
185                 CERROR("unexpected reply\n");
186                 return -EINVAL;
187         }
188         if (sec_len > outlen) {
189                 CERROR("outbuf too small\n");
190                 return -EINVAL;
191         }
192
193         p += 4;                 /* skip hdr */
194         p += lmsg_len / 4;      /* skip lmsg */
195         effective = 0;
196
197         status = le32_to_cpu(*p++);
198         major = le32_to_cpu(*p++);
199         minor = le32_to_cpu(*p++);
200         seq = le32_to_cpu(*p++);
201         effective += 4 * 4;
202
203         copy_to_user(outbuf, &status, 4);
204         outbuf += 4;
205         copy_to_user(outbuf, &major, 4);
206         outbuf += 4;
207         copy_to_user(outbuf, &minor, 4);
208         outbuf += 4;
209         copy_to_user(outbuf, &seq, 4);
210         outbuf += 4;
211
212         obj_len = le32_to_cpu(*p++);
213         round_len = (obj_len + 3) & ~ 3;
214         copy_to_user(outbuf, &obj_len, 4);
215         outbuf += 4;
216         copy_to_user(outbuf, (char *)p, round_len);
217         p += round_len / 4;
218         outbuf += round_len;
219         effective += 4 + round_len;
220
221         obj_len = le32_to_cpu(*p++);
222         round_len = (obj_len + 3) & ~ 3;
223         copy_to_user(outbuf, &obj_len, 4);
224         outbuf += 4;
225         copy_to_user(outbuf, (char *)p, round_len);
226         p += round_len / 4;
227         outbuf += round_len;
228         effective += 4 + round_len;
229
230         return effective;
231 }
232
233 /* input: 
234  *   1. ptr to uuid
235  *   2. ptr to send_token
236  *   3. ptr to output buffer
237  *   4. output buffer size
238  * output:
239  *   1. return code. 0 is success
240  *   2. no meaning
241  *   3. ptr output data
242  *   4. output data size
243  *
244  * return:
245  *   < 0: error
246  *   = 0: success
247  *
248  * FIXME This interface looks strange, should be reimplemented
249  */
250 static int gss_send_secinit_rpc(__user char *buffer, unsigned long count)
251 {
252         struct obd_import *imp;
253         const int reqbuf_size = 1024;
254         const int repbuf_size = 1024;
255         char *reqbuf, *repbuf;
256         struct obd_device *obd;
257         char obdname[64];
258         long inbuf[4], lsize;
259         int rc, reqlen, replen;
260
261         if (count != 4 * sizeof(long)) {
262                 CERROR("count %lu\n", count);
263                 RETURN(-EINVAL);
264         }
265         if (copy_from_user(inbuf, buffer, count)) {
266                 CERROR("Invalid pointer\n");
267                 RETURN(-EFAULT);
268         }
269
270         /* take name */
271         if (strncpy_from_user(obdname, (char *)inbuf[0],
272                               sizeof(obdname)) <= 0) {
273                 CERROR("Invalid obdname pointer\n");
274                 RETURN(-EFAULT);
275         }
276
277         obd = class_name2obd(obdname);
278         if (!obd) {
279                 CERROR("no such obd %s\n", obdname);
280                 RETURN(-EINVAL);
281         }
282         if (strcmp(obd->obd_type->typ_name, "mdc") &&
283             strcmp(obd->obd_type->typ_name, "osc")) {
284                 CERROR("%s not a mdc/osc device\n", obdname);
285                 RETURN(-EINVAL);
286         }
287
288         imp = class_import_get(obd->u.cli.cl_import);
289
290         OBD_ALLOC(reqbuf, reqbuf_size);
291         OBD_ALLOC(repbuf, reqbuf_size);
292
293         if (!reqbuf || !repbuf) {
294                 CERROR("Can't alloc buffer: %p/%p\n", reqbuf, repbuf);
295                 GOTO(out_free, rc = -ENOMEM);
296         }
297
298         /* get token */
299         reqlen = secinit_compose_request(imp, reqbuf, reqbuf_size,
300                                          (char *)inbuf[1]);
301         if (reqlen < 0)
302                 GOTO(out_free, rc = reqlen);
303
304         replen = repbuf_size;
305         rc = ptlrpc_do_rawrpc(imp, reqbuf, reqlen,
306                               repbuf, &replen, SECINIT_RPC_TIMEOUT);
307         if (rc)
308                 GOTO(out_free, rc);
309
310         if (replen > inbuf[3]) {
311                 CERROR("output buffer size %ld too small, need %d\n",
312                         inbuf[3], replen);
313                 GOTO(out_free, rc = -EINVAL);
314         }
315
316         lsize = secinit_parse_reply(repbuf, replen,
317                                     (char *)inbuf[2], (int)inbuf[3]);
318         if (lsize < 0)
319                 GOTO(out_free, rc = (int)lsize);
320
321         copy_to_user(buffer + 3 * sizeof(long), &lsize, sizeof(lsize));
322         lsize = 0;
323         copy_to_user((char*)buffer, &lsize, sizeof(lsize));
324         rc = 0;
325 out_free:
326         class_import_put(imp);
327         if (repbuf)
328                 OBD_FREE(repbuf, repbuf_size);
329         if (reqbuf)
330                 OBD_FREE(reqbuf, reqbuf_size);
331         RETURN(rc);
332 }
333
334 static int gss_send_secfini_rpc(struct obd_import *imp,
335                                 char *reqbuf, int reqlen)
336 {
337         const int repbuf_size = 1024;
338         char *repbuf;
339         int replen = repbuf_size;
340         int rc;
341
342         OBD_ALLOC(repbuf, repbuf_size);
343         if (!repbuf) {
344                 CERROR("Out of memory\n");
345                 return -ENOMEM;
346         }
347
348         rc = ptlrpc_do_rawrpc(imp, reqbuf, reqlen, repbuf, &replen,
349                               SECFINI_RPC_TIMEOUT);
350
351         OBD_FREE(repbuf, repbuf_size);
352         return rc;
353 }
354
355 /**********************************************
356  * structure definitions                      *
357  **********************************************/
/*
 * GSS flavour of a ptlrpc security instance. The generic part is embedded
 * as gs_base; this file recovers the gss_sec with
 * container_of(sec, struct gss_sec, gs_base).
 */
struct gss_sec {
        struct ptlrpc_sec       gs_base;
        struct gss_api_mech    *gs_mech;        /* mechanism used to import contexts */
#ifdef __KERNEL__
        spinlock_t              gs_lock;        /* protects gs_upcalls */
        struct list_head        gs_upcalls;     /* pending gss_upcall_msg, linked via gum_list */
        char                    gs_pipepath[64]; /* NOTE(review): presumably the rpc_pipefs path -- confirm */
        struct dentry          *gs_depipe;      /* pipe whose inode upcalls are queued on */
#endif
};
368
369 #ifdef __KERNEL__
370
/*
 * Guards gss_cred::gc_ctx and the PTLRPC_CRED_UPTODATE flag as read and
 * updated by the gss_cred_* helpers.
 * NOTE(review): this is defined only under __KERNEL__, yet those helpers
 * are also compiled for liblustre; that presumably relies on liblustre's
 * read_lock/write_lock macros discarding their argument -- confirm.
 */
static rwlock_t gss_ctx_lock = RW_LOCK_UNLOCKED;
372
/*
 * One pending context-establishment upcall. It is queued to user space
 * through rpc_pipefs (gum_base) and hashed on gss_sec::gs_upcalls
 * (gum_list). Refreshers of the same (obdname, uid, ip) share a message
 * and sleep on gum_waitq until it is unhashed.
 */
struct gss_upcall_msg {
        struct rpc_pipe_msg             gum_base;
        atomic_t                        gum_refcount;
        struct list_head                gum_list;
        struct gss_sec                 *gum_gsec;
        wait_queue_head_t               gum_waitq;
        char                            gum_obdname[64];
        /* gum_uid..gum_pad are the upcall payload: gum_base.data points at
         * gum_uid and its length spans these four fields, so they must stay
         * adjacent and in this order. */
        uid_t                           gum_uid;
        __u32                           gum_ip; /* XXX IPv6? */
        __u32                           gum_svc;
        __u32                           gum_pad;
};
385
386 /**********************************************
387  * rpc_pipe upcall helpers                    *
388  **********************************************/
389 static
390 void gss_release_msg(struct gss_upcall_msg *gmsg)
391 {
392         ENTRY;
393         LASSERT(atomic_read(&gmsg->gum_refcount) > 0);
394
395         if (!atomic_dec_and_test(&gmsg->gum_refcount)) {
396                 CDEBUG(D_SEC, "gmsg %p ref %d\n", gmsg,
397                        atomic_read(&gmsg->gum_refcount));
398                 EXIT;
399                 return;
400         }
401         LASSERT(list_empty(&gmsg->gum_list));
402         OBD_FREE(gmsg, sizeof(*gmsg));
403         EXIT;
404 }
405
406 static void
407 gss_unhash_msg_nolock(struct gss_upcall_msg *gmsg)
408 {
409         ENTRY;
410         if (list_empty(&gmsg->gum_list)) {
411                 EXIT;
412                 return;
413         }
414         /* FIXME should not do this. when we in upper upcall queue,
415          * downcall will call unhash_msg, thus later put_msg might
416          * free msg buffer while it's not dequeued XXX */
417         list_del_init(&gmsg->gum_base.list);
418         /* FIXME */
419
420         list_del_init(&gmsg->gum_list);
421         wake_up(&gmsg->gum_waitq);
422         atomic_dec(&gmsg->gum_refcount);
423         CDEBUG(D_SEC, "gmsg %p refcount now %d\n",
424                gmsg, atomic_read(&gmsg->gum_refcount));
425         LASSERT(atomic_read(&gmsg->gum_refcount) > 0);
426         EXIT;
427 }
428
429 static void
430 gss_unhash_msg(struct gss_upcall_msg *gmsg)
431 {
432         struct gss_sec *gsec = gmsg->gum_gsec;
433
434         spin_lock(&gsec->gs_lock);
435         gss_unhash_msg_nolock(gmsg);
436         spin_unlock(&gsec->gs_lock);
437 }
438
439 static
440 struct gss_upcall_msg * gss_find_upcall(struct gss_sec *gsec,
441                                         char *obdname,
442                                         uid_t uid, __u32 dest_ip)
443 {
444         struct gss_upcall_msg *gmsg;
445         ENTRY;
446
447         list_for_each_entry(gmsg, &gsec->gs_upcalls, gum_list) {
448                 if (gmsg->gum_uid != uid)
449                         continue;
450                 if (gmsg->gum_ip != dest_ip)
451                         continue;
452                 if (strcmp(gmsg->gum_obdname, obdname))
453                         continue;
454                 atomic_inc(&gmsg->gum_refcount);
455                 CDEBUG(D_SEC, "found gmsg at %p: obdname %s, uid %d, ref %d\n",
456                        gmsg, obdname, uid, atomic_read(&gmsg->gum_refcount));
457                 RETURN(gmsg);
458         }
459         RETURN(NULL);
460 }
461
462 static void gss_init_upcall_msg(struct gss_upcall_msg *gmsg,
463                                 struct gss_sec *gsec,
464                                 char *obdname,
465                                 uid_t uid, __u32 dest_ip, __u32 svc)
466 {
467         struct rpc_pipe_msg *rpcmsg;
468         ENTRY;
469
470         /* 2 refs: 1 for hash, 1 for current user */
471         init_waitqueue_head(&gmsg->gum_waitq);
472         list_add(&gmsg->gum_list, &gsec->gs_upcalls);
473         atomic_set(&gmsg->gum_refcount, 2);
474         gmsg->gum_gsec = gsec;
475         strncpy(gmsg->gum_obdname, obdname, sizeof(gmsg->gum_obdname));
476         gmsg->gum_uid = uid;
477         gmsg->gum_ip = dest_ip;
478         gmsg->gum_svc = svc;
479
480         rpcmsg = &gmsg->gum_base;
481         rpcmsg->data = &gmsg->gum_uid;
482         rpcmsg->len = sizeof(gmsg->gum_uid) + sizeof(gmsg->gum_ip) +
483                       sizeof(gmsg->gum_svc) + sizeof(gmsg->gum_pad);
484         EXIT;
485 }
486 #endif /* __KERNEL__ */
487
/********************************************
 * gss cred manipulation helpers            *
 ********************************************/
491 static
492 int gss_cred_is_uptodate_ctx(struct ptlrpc_cred *cred)
493 {
494         struct gss_cred *gcred = container_of(cred, struct gss_cred, gc_base);
495         int res = 0;
496
497         read_lock(&gss_ctx_lock);
498         if ((cred->pc_flags & PTLRPC_CRED_UPTODATE) && gcred->gc_ctx)
499                 res = 1;
500         read_unlock(&gss_ctx_lock);
501         return res;
502 }
503
504 static inline
505 struct gss_cl_ctx * gss_get_ctx(struct gss_cl_ctx *ctx)
506 {
507         atomic_inc(&ctx->gc_refcount);
508         return ctx;
509 }
510
511 static
512 void gss_destroy_ctx(struct gss_cl_ctx *ctx)
513 {
514         ENTRY;
515
516         CDEBUG(D_SEC, "destroy cl_ctx %p\n", ctx);
517         if (ctx->gc_gss_ctx)
518                 kgss_delete_sec_context(&ctx->gc_gss_ctx);
519
520         if (ctx->gc_wire_ctx.len > 0) {
521                 OBD_FREE(ctx->gc_wire_ctx.data, ctx->gc_wire_ctx.len);
522                 ctx->gc_wire_ctx.len = 0;
523         }
524
525         OBD_FREE(ctx, sizeof(*ctx));
526 }
527
528 static
529 void gss_put_ctx(struct gss_cl_ctx *ctx)
530 {
531         if (atomic_dec_and_test(&ctx->gc_refcount))
532                 gss_destroy_ctx(ctx);
533 }
534
535 static
536 struct gss_cl_ctx *gss_cred_get_ctx(struct ptlrpc_cred *cred)
537 {
538         struct gss_cred *gcred = container_of(cred, struct gss_cred, gc_base);
539         struct gss_cl_ctx *ctx = NULL;
540
541         read_lock(&gss_ctx_lock);
542         if (gcred->gc_ctx)
543                 ctx = gss_get_ctx(gcred->gc_ctx);
544         read_unlock(&gss_ctx_lock);
545         return ctx;
546 }
547
548 static
549 void gss_cred_set_ctx(struct ptlrpc_cred *cred, struct gss_cl_ctx *ctx)
550 {
551         struct gss_cred *gcred = container_of(cred, struct gss_cred, gc_base);
552         struct gss_cl_ctx *old;
553         __u64 ctx_expiry;
554         ENTRY;
555
556         if (kgss_inquire_context(ctx->gc_gss_ctx, &ctx_expiry)) {
557                 CERROR("unable to get expire time\n");
558                 ctx_expiry = 1; /* make it expired now */
559         }
560         cred->pc_expire = (unsigned long) ctx_expiry;
561
562         write_lock(&gss_ctx_lock);
563         old = gcred->gc_ctx;
564         gcred->gc_ctx = ctx;
565         cred->pc_flags |= PTLRPC_CRED_UPTODATE;
566         write_unlock(&gss_ctx_lock);
567         if (old)
568                 gss_put_ctx(old);
569
570         CWARN("client refreshed gss cred %p(uid %u)\n", cred, cred->pc_uid);
571         EXIT;
572 }
573
574 static int
575 simple_get_bytes(char **buf, __u32 *buflen, void *res, __u32 reslen)
576 {
577         if (*buflen < reslen) {
578                 CERROR("buflen %u < %u\n", *buflen, reslen);
579                 return -EINVAL;
580         }
581
582         memcpy(res, *buf, reslen);
583         *buf += reslen;
584         *buflen -= reslen;
585         return 0;
586 }
587
588 /* data passed down:
589  *  - uid
590  *  - timeout
591  *  - gc_win / error
592  *  - wire_ctx (rawobj)
593  *  - mech_ctx? (rawobj)
594  */
595 static
596 int gss_parse_init_downcall(struct gss_api_mech *gm, rawobj_t *buf,
597                             struct gss_cl_ctx **gc, struct vfs_cred *vcred,
598                             __u32 *dest_ip, int *gss_err)
599 {
600         char *p = buf->data;
601         __u32 len = buf->len;
602         struct gss_cl_ctx *ctx;
603         rawobj_t tmp_buf;
604         unsigned int timeout;
605         int err = -EIO;
606         ENTRY;
607
608         *gc = NULL;
609
610         OBD_ALLOC(ctx, sizeof(*ctx));
611         if (!ctx)
612                 RETURN(-ENOMEM);
613
614         ctx->gc_proc = RPC_GSS_PROC_DATA;
615         ctx->gc_seq = 0;
616         spin_lock_init(&ctx->gc_seq_lock);
617         atomic_set(&ctx->gc_refcount,1);
618
619         if (simple_get_bytes(&p, &len, &vcred->vc_uid, sizeof(vcred->vc_uid)))
620                 GOTO(err_free_ctx, err);
621         vcred->vc_pag = vcred->vc_uid; /* FIXME */
622         if (simple_get_bytes(&p, &len, dest_ip, sizeof(*dest_ip)))
623                 GOTO(err_free_ctx, err);
624         /* FIXME: discarded timeout for now */
625         if (simple_get_bytes(&p, &len, &timeout, sizeof(timeout)))
626                 GOTO(err_free_ctx, err);
627         *gss_err = 0;
628         if (simple_get_bytes(&p, &len, &ctx->gc_win, sizeof(ctx->gc_win)))
629                 GOTO(err_free_ctx, err);
630         /* gssd signals an error by passing ctx->gc_win = 0: */
631         if (!ctx->gc_win) {
632                 /* in which case the next int is an error code: */
633                 if (simple_get_bytes(&p, &len, gss_err, sizeof(*gss_err)))
634                         GOTO(err_free_ctx, err);
635                 GOTO(err_free_ctx, err = 0);
636         }
637         if (rawobj_extract_local(&tmp_buf, (__u32 **) ((void *)&p), &len))
638                 GOTO(err_free_ctx, err);
639         if (rawobj_dup(&ctx->gc_wire_ctx, &tmp_buf)) {
640                 GOTO(err_free_ctx, err = -ENOMEM);
641         }
642         if (rawobj_extract_local(&tmp_buf, (__u32 **) ((void *)&p), &len))
643                 GOTO(err_free_wire_ctx, err);
644         if (len) {
645                 CERROR("unexpected trailing %u bytes\n", len);
646                 GOTO(err_free_wire_ctx, err);
647         }
648         if (kgss_import_sec_context(&tmp_buf, gm, &ctx->gc_gss_ctx))
649                 GOTO(err_free_wire_ctx, err);
650
651         *gc = ctx;
652         RETURN(0);
653
654 err_free_wire_ctx:
655         if (ctx->gc_wire_ctx.data)
656                 OBD_FREE(ctx->gc_wire_ctx.data, ctx->gc_wire_ctx.len);
657 err_free_ctx:
658         OBD_FREE(ctx, sizeof(*ctx));
659         CDEBUG(D_SEC, "err_code %d, gss code %d\n", err, *gss_err);
660         return err;
661 }
662
663 /***************************************
664  * cred APIs                           *
665  ***************************************/
666 #ifdef __KERNEL__
667 static int gss_cred_refresh(struct ptlrpc_cred *cred)
668 {
669         struct obd_import          *import;
670         struct gss_sec             *gsec;
671         struct gss_upcall_msg      *gss_msg, *gss_new;
672         struct dentry              *dentry;
673         char                       *obdname, *obdtype;
674         wait_queue_t                wait;
675         uid_t                       uid = cred->pc_uid;
676         ptl_nid_t                   peer_nid;
677         __u32                       dest_ip, svc;
678         int                         res;
679         ENTRY;
680
681         if (ptlrpcs_cred_is_uptodate(cred))
682                 RETURN(0);
683
684         LASSERT(cred->pc_sec);
685         LASSERT(cred->pc_sec->ps_import);
686         LASSERT(cred->pc_sec->ps_import->imp_obd);
687
688         import = cred->pc_sec->ps_import;
689         if (!import->imp_connection) {
690                 CERROR("import has no connection set\n");
691                 RETURN(-EINVAL);
692         }
693
694         peer_nid = import->imp_connection->c_peer.peer_id.nid;
695         dest_ip = (__u32) (peer_nid & 0xFFFFFFFF);
696
697         obdtype = import->imp_obd->obd_type->typ_name;
698         if (!strcmp(obdtype, "mdc"))
699                 svc = 0;
700         else if (!strcmp(obdtype, "osc"))
701                 svc = 1;
702         else {
703                 CERROR("gss on %s?\n", obdtype);
704                 RETURN(-EINVAL);
705         }
706
707         gsec = container_of(cred->pc_sec, struct gss_sec, gs_base);
708         obdname = import->imp_obd->obd_name;
709         dentry = gsec->gs_depipe;
710         gss_new = NULL;
711         res = 0;
712
713         CWARN("Initiate gss context %p(%u@%s)\n",
714                container_of(cred, struct gss_cred, gc_base),
715                uid, import->imp_target_uuid.uuid);
716
717 again:
718         spin_lock(&gsec->gs_lock);
719         gss_msg = gss_find_upcall(gsec, obdname, uid, dest_ip);
720         if (gss_msg) {
721                 spin_unlock(&gsec->gs_lock);
722                 GOTO(waiting, res);
723         }
724         if (!gss_new) {
725                 spin_unlock(&gsec->gs_lock);
726                 OBD_ALLOC(gss_new, sizeof(*gss_new));
727                 if (!gss_new) {
728                         CERROR("fail to alloc memory\n");
729                         RETURN(-ENOMEM);
730                 }
731                 goto again;
732         }
733         /* so far we'v created gss_new */
734         gss_init_upcall_msg(gss_new, gsec, obdname, uid, dest_ip, svc);
735
736         if (gss_cred_is_uptodate_ctx(cred)) {
737                 /* someone else had done it for us, simply cancel
738                  * our own upcall */
739                 CDEBUG(D_SEC, "cred("LPU64"/%u) has been refreshed by someone "
740                        "else, simply drop our request\n",
741                        cred->pc_pag, cred->pc_uid);
742                 gss_unhash_msg_nolock(gss_new);
743                 spin_unlock(&gsec->gs_lock);
744                 gss_release_msg(gss_new);
745                 RETURN(0);
746         }
747
748         /* need to make upcall now */
749         spin_unlock(&gsec->gs_lock);
750         res = rpc_queue_upcall(dentry->d_inode, &gss_new->gum_base);
751         if (res) {
752                 CERROR("rpc_queue_upcall failed: %d\n", res);
753                 gss_unhash_msg(gss_new);
754                 gss_release_msg(gss_new);
755                 RETURN(res);
756         }
757         gss_msg = gss_new;
758
759 waiting:
760         init_waitqueue_entry(&wait, current);
761         spin_lock(&gsec->gs_lock);
762         add_wait_queue(&gss_msg->gum_waitq, &wait);
763         set_current_state(TASK_INTERRUPTIBLE);
764         spin_unlock(&gsec->gs_lock);
765
766         schedule();
767
768         remove_wait_queue(&gss_msg->gum_waitq, &wait);
769         if (signal_pending(current)) {
770                 CERROR("interrupted gss upcall %p\n", gss_msg);
771                 res = -EINTR;
772         }
773         gss_release_msg(gss_msg);
774         RETURN(res);
775 }
776 #else /* !__KERNEL__ */
777 extern int lgss_handle_krb5_upcall(uid_t uid, __u32 dest_ip,
778                                    char *obd_name,
779                                    char *buf, int bufsize,
780                                    int (*callback)(char*, unsigned long));
781
782 static int gss_cred_refresh(struct ptlrpc_cred *cred)
783 {
784         char                    buf[4096];
785         rawobj_t                obj;
786         struct obd_import      *imp;
787         struct gss_sec         *gsec;
788         struct gss_api_mech    *mech;
789         struct gss_cl_ctx      *ctx = NULL;
790         struct vfs_cred         vcred = { 0 };
791         ptl_nid_t               peer_nid;
792         __u32                   dest_ip;
793         __u32                   subflavor;
794         int                     rc, gss_err;
795
796         LASSERT(cred);
797         LASSERT(cred->pc_sec);
798         LASSERT(cred->pc_sec->ps_import);
799         LASSERT(cred->pc_sec->ps_import->imp_obd);
800
801         if (ptlrpcs_cred_is_uptodate(cred))
802                 RETURN(0);
803
804         imp = cred->pc_sec->ps_import;
805         peer_nid = imp->imp_connection->c_peer.peer_id.nid;
806         dest_ip = (__u32) (peer_nid & 0xFFFFFFFF);
807         subflavor = cred->pc_sec->ps_flavor.subflavor;
808
809         if (subflavor != PTLRPC_SEC_GSS_KRB5I) {
810                 CERROR("unknown subflavor %u\n", subflavor);
811                 GOTO(err_out, rc = -EINVAL);
812         }
813
814         rc = lgss_handle_krb5_upcall(cred->pc_uid, dest_ip,
815                                      imp->imp_obd->obd_name,
816                                      buf, sizeof(buf),
817                                      gss_send_secinit_rpc);
818         LASSERT(rc != 0);
819         if (rc < 0)
820                 goto err_out;
821
822         obj.data = buf;
823         obj.len = rc;
824
825         gsec = container_of(cred->pc_sec, struct gss_sec, gs_base);
826         mech = gsec->gs_mech;
827         LASSERT(mech);
828         rc = gss_parse_init_downcall(mech, &obj, &ctx, &vcred, &dest_ip,
829                                      &gss_err);
830         if (rc) {
831                 CERROR("parse init downcall error %d\n", rc);
832                 goto err_out;
833         }
834
835         if (gss_err) {
836                 CERROR("cred fresh got gss error %x\n", gss_err);
837                 rc = -EINVAL;
838                 goto err_out;
839         }
840
841         gss_cred_set_ctx(cred, ctx);
842         LASSERT(gss_cred_is_uptodate_ctx(cred));
843
844         return 0;
845 err_out:
846         cred->pc_flags |= PTLRPC_CRED_DEAD;
847         return rc;
848 }
849 #endif
850
851 static int gss_cred_match(struct ptlrpc_cred *cred,
852                           struct ptlrpc_request *req,
853                           struct vfs_cred *vcred)
854 {
855         RETURN(cred->pc_pag == vcred->vc_pag);
856 }
857
858 static int gss_cred_sign(struct ptlrpc_cred *cred,
859                          struct ptlrpc_request *req)
860 {
861         struct gss_cred         *gcred;
862         struct gss_cl_ctx       *ctx;
863         rawobj_t                 lmsg, mic;
864         __u32                   *vp, *vpsave, vlen, seclen;
865         __u32                    seqnum, major, rc = 0;
866         ENTRY;
867
868         LASSERT(req->rq_reqbuf);
869         LASSERT(req->rq_cred == cred);
870
871         gcred = container_of(cred, struct gss_cred, gc_base);
872         ctx = gss_cred_get_ctx(cred);
873         if (!ctx) {
874                 CERROR("cred %p("LPU64"/%u) invalidated?\n",
875                         cred, cred->pc_pag, cred->pc_uid);
876                 RETURN(-EPERM);
877         }
878
879         lmsg.len = req->rq_reqlen;
880         lmsg.data = (__u8 *) req->rq_reqmsg;
881
882         vp = (__u32 *) (lmsg.data + lmsg.len);
883         vlen = req->rq_reqbuf_len - sizeof(struct ptlrpcs_wire_hdr) -
884                lmsg.len;
885         seclen = vlen;
886
887         if (vlen < 6 * 4 + size_round4(ctx->gc_wire_ctx.len)) {
888                 CERROR("vlen %d, need %d\n",
889                         vlen, 6 * 4 + size_round4(ctx->gc_wire_ctx.len));
890                 rc = -EIO;
891                 goto out;
892         }
893
894         spin_lock(&ctx->gc_seq_lock);
895         seqnum = ctx->gc_seq++;
896         spin_unlock(&ctx->gc_seq_lock);
897
898         *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);    /* version */
899         *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);      /* subflavor */
900         *vp++ = cpu_to_le32(ctx->gc_proc);              /* proc */
901         *vp++ = cpu_to_le32(seqnum);                    /* seq */
902         *vp++ = cpu_to_le32(PTLRPC_GSS_SVC_INTEGRITY);  /* service */
903         vlen -= 5 * 4;
904
905         if (rawobj_serialize(&ctx->gc_wire_ctx, &vp, &vlen)) {
906                 rc = -EIO;
907                 goto out;
908         }
909         CDEBUG(D_SEC, "encoded wire_ctx length %d\n", ctx->gc_wire_ctx.len);
910
911         vpsave = vp++;  /* reserve for size */
912         vlen -= 4;
913
914         mic.len = vlen;
915         mic.data = (char *) vp;
916
917         CDEBUG(D_SEC, "reqbuf at %p, lmsg at %p, len %d, mic at %p, len %d\n",
918                req->rq_reqbuf, lmsg.data, lmsg.len, mic.data, mic.len);
919         major = kgss_get_mic(ctx->gc_gss_ctx, GSS_C_QOP_DEFAULT, &lmsg, &mic);
920         if (major) {
921                 CERROR("gss compute mic error, major %x\n", major);
922                 rc = -EACCES;
923                 goto out;
924         }
925
926         *vpsave = cpu_to_le32(mic.len);
927         
928         seclen = seclen - vlen + mic.len;
929         buf_to_sec_hdr(req->rq_reqbuf)->sec_len = cpu_to_le32(seclen);
930         req->rq_reqdata_len += size_round(seclen);
931         CDEBUG(D_SEC, "msg size %d, checksum size %d, total sec size %d\n",
932                lmsg.len, mic.len, seclen);
933 out:
934         gss_put_ctx(ctx);
935         RETURN(rc);
936 }
937
938 static int gss_cred_verify(struct ptlrpc_cred *cred,
939                            struct ptlrpc_request *req)
940 {
941         struct gss_cred        *gcred;
942         struct gss_cl_ctx      *ctx;
943         struct ptlrpcs_wire_hdr *sec_hdr;
944         rawobj_t                lmsg, mic;
945         __u32                   *vp, vlen, subflavor, proc, seq, svc;
946         __u32                   major, minor, rc;
947         ENTRY;
948
949         LASSERT(req->rq_repbuf);
950         LASSERT(req->rq_cred == cred);
951
952         sec_hdr = buf_to_sec_hdr(req->rq_repbuf);
953         vp = (__u32 *) (req->rq_repbuf + sizeof(*sec_hdr) + sec_hdr->msg_len);
954         vlen = sec_hdr->sec_len;
955
956         if (vlen < 7 * 4) {
957                 CERROR("reply sec size %u too small\n", vlen);
958                 RETURN(-EPROTO);
959         }
960
961         if (*vp++ != cpu_to_le32(PTLRPC_SEC_GSS_VERSION)) {
962                 CERROR("reply have different gss version\n");
963                 RETURN(-EPROTO);
964         }
965         subflavor = le32_to_cpu(*vp++);
966         proc = le32_to_cpu(*vp++);
967         vlen -= 3 * 4;
968
969         switch (proc) {
970         case PTLRPC_GSS_PROC_DATA:
971                 seq = le32_to_cpu(*vp++);
972                 svc = le32_to_cpu(*vp++);
973                 if (svc != PTLRPC_GSS_SVC_INTEGRITY) {
974                         CERROR("Unknown svc %d\n", svc);
975                         RETURN(-EPROTO);
976                 }
977                 if (*vp++ != 0) {
978                         CERROR("Unexpected ctx handle\n");
979                         RETURN(-EPROTO);
980                 }
981                 mic.len = le32_to_cpu(*vp++);
982                 vlen -= 4 * 4;
983                 if (vlen < mic.len) {
984                         CERROR("vlen %d, mic.len %d\n", vlen, mic.len);
985                         RETURN(-EINVAL);
986                 }
987                 mic.data = (char *) vp;
988
989                 gcred = container_of(cred, struct gss_cred, gc_base);
990                 ctx = gss_cred_get_ctx(cred);
991                 LASSERT(ctx);
992
993                 lmsg.len = sec_hdr->msg_len;
994                 lmsg.data = (__u8 *) buf_to_lustre_msg(req->rq_repbuf);
995
996                 major = kgss_verify_mic(ctx->gc_gss_ctx, &lmsg, &mic, NULL);
997                 if (major != GSS_S_COMPLETE) {
998                         CERROR("gss verify mic error: major %x\n", major);
999                         GOTO(proc_data_out, rc = -EINVAL);
1000                 }
1001
1002                 req->rq_repmsg = (struct lustre_msg *) lmsg.data;
1003                 req->rq_replen = lmsg.len;
1004
1005                 /* here we could check the seq number is the same one
1006                  * we sent to server. but portals has prevent us from
1007                  * replay attack, so maybe we don't need check it again.
1008                  */
1009                 rc = 0;
1010 proc_data_out:
1011                 gss_put_ctx(ctx);
1012                 break;
1013         case PTLRPC_GSS_PROC_ERR:
1014                 major = le32_to_cpu(*vp++);
1015                 minor = le32_to_cpu(*vp++);
1016                 /* server return NO_CONTEXT might be caused by context expire
1017                  * or server reboot/failover. we refresh the cred transparently
1018                  * to upper layer.
1019                  * In some cases, our gss handle is possible to be incidentally
1020                  * identical to another handle since the handle itself is not
1021                  * fully random. In krb5 case, the GSS_S_BAD_SIG will be
1022                  * returned, maybe other gss error for other mechanism. Here we
1023                  * only consider krb5 mech (FIXME) and try to establish new
1024                  * context.
1025                  */
1026                 if (major == GSS_S_NO_CONTEXT ||
1027                     major == GSS_S_BAD_SIG) {
1028                         CWARN("req %p: server report cred %p %s, expired?\n",
1029                                req, cred, (major == GSS_S_NO_CONTEXT) ?
1030                                            "NO_CONTEXT" : "BAD_SIG");
1031
1032                         ptlrpcs_cred_die(cred);
1033                         rc = ptlrpcs_req_replace_dead_cred(req);
1034                         if (!rc)
1035                                 req->rq_ptlrpcs_restart = 1;
1036                         else
1037                                 CERROR("replace dead cred failed %d\n", rc);
1038                 } else {
1039                         CERROR("Unrecognized gss error (%x/%x)\n",
1040                                 major, minor);
1041                         rc = -EACCES;
1042                 }
1043                 break;
1044         default:
1045                 CERROR("unknown gss proc %d\n", proc);
1046                 rc = -EPROTO;
1047         }
1048
1049         RETURN(rc);
1050 }
1051
1052 static int gss_cred_seal(struct ptlrpc_cred *cred,
1053                          struct ptlrpc_request *req)
1054 {
1055         struct gss_cred         *gcred;
1056         struct gss_cl_ctx       *ctx;
1057         struct ptlrpcs_wire_hdr *sec_hdr;
1058         rawobj_buf_t             msg_buf;
1059         rawobj_t                 cipher_buf;
1060         __u32                   *vp, *vpsave, vlen, seclen;
1061         __u32                    major, seqnum, rc = 0;
1062         ENTRY;
1063
1064         LASSERT(req->rq_reqbuf);
1065         LASSERT(req->rq_cred == cred);
1066
1067         gcred = container_of(cred, struct gss_cred, gc_base);
1068         ctx = gss_cred_get_ctx(cred);
1069         if (!ctx) {
1070                 CERROR("cred %p("LPU64"/%u) invalidated?\n",
1071                         cred, cred->pc_pag, cred->pc_uid);
1072                 RETURN(-EPERM);
1073         }
1074
1075         vp = (__u32 *) (req->rq_reqbuf + sizeof(*sec_hdr));
1076         vlen = req->rq_reqbuf_len - sizeof(*sec_hdr);
1077         seclen = vlen;
1078
1079         if (vlen < 6 * 4 + size_round4(ctx->gc_wire_ctx.len)) {
1080                 CERROR("vlen %d, need %d\n",
1081                         vlen, 6 * 4 + size_round4(ctx->gc_wire_ctx.len));
1082                 rc = -EIO;
1083                 goto out;
1084         }
1085
1086         spin_lock(&ctx->gc_seq_lock);
1087         seqnum = ctx->gc_seq++;
1088         spin_unlock(&ctx->gc_seq_lock);
1089
1090         *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);    /* version */
1091         *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5P);      /* subflavor */
1092         *vp++ = cpu_to_le32(ctx->gc_proc);              /* proc */
1093         *vp++ = cpu_to_le32(seqnum);                    /* seq */
1094         *vp++ = cpu_to_le32(PTLRPC_GSS_SVC_PRIVACY);    /* service */
1095         vlen -= 5 * 4;
1096
1097         if (rawobj_serialize(&ctx->gc_wire_ctx, &vp, &vlen)) {
1098                 rc = -EIO;
1099                 goto out;
1100         }
1101         CDEBUG(D_SEC, "encoded wire_ctx length %d\n", ctx->gc_wire_ctx.len);
1102
1103         vpsave = vp++;  /* reserve for size */
1104         vlen -= 4;
1105
1106         msg_buf.buf = (__u8 *) req->rq_reqmsg - GSS_PRIVBUF_PREFIX_LEN;
1107         msg_buf.buflen = req->rq_reqlen + GSS_PRIVBUF_PREFIX_LEN + GSS_PRIVBUF_SUFFIX_LEN;
1108         msg_buf.dataoff = GSS_PRIVBUF_PREFIX_LEN;
1109         msg_buf.datalen = req->rq_reqlen;
1110
1111         cipher_buf.data = (__u8 *) vp;
1112         cipher_buf.len = vlen;
1113
1114         major = kgss_wrap(ctx->gc_gss_ctx, GSS_C_QOP_DEFAULT,
1115                           &msg_buf, &cipher_buf);
1116         if (major) {
1117                 CERROR("error wrap: major 0x%x\n", major);
1118                 GOTO(out, rc = -EINVAL);
1119         }
1120
1121         *vpsave = cpu_to_le32(cipher_buf.len);
1122
1123         seclen = seclen - vlen + cipher_buf.len;
1124         sec_hdr = buf_to_sec_hdr(req->rq_reqbuf);
1125         sec_hdr->sec_len = cpu_to_le32(seclen);
1126         req->rq_reqdata_len += size_round(seclen);
1127
1128         CDEBUG(D_SEC, "msg size %d, total sec size %d\n",
1129                req->rq_reqlen, seclen);
1130 out:
1131         gss_put_ctx(ctx);
1132         RETURN(rc);
1133 }
1134
1135 static int gss_cred_unseal(struct ptlrpc_cred *cred,
1136                            struct ptlrpc_request *req)
1137 {
1138         struct gss_cred        *gcred;
1139         struct gss_cl_ctx      *ctx;
1140         struct ptlrpcs_wire_hdr *sec_hdr;
1141         rawobj_t                cipher_text, plain_text;
1142         __u32                   *vp, vlen, subflavor, proc, seq, svc;
1143         int                     rc;
1144         ENTRY;
1145
1146         LASSERT(req->rq_repbuf);
1147         LASSERT(req->rq_cred == cred);
1148
1149         sec_hdr = buf_to_sec_hdr(req->rq_repbuf);
1150         if (sec_hdr->msg_len != 0) {
1151                 CERROR("unexpected msg_len %u\n", sec_hdr->msg_len);
1152                 RETURN(-EPROTO);
1153         }
1154
1155         vp = (__u32 *) (req->rq_repbuf + sizeof(*sec_hdr));
1156         vlen = sec_hdr->sec_len;
1157
1158         if (vlen < 7 * 4) {
1159                 CERROR("reply sec size %u too small\n", vlen);
1160                 RETURN(-EPROTO);
1161         }
1162
1163         if (*vp++ != cpu_to_le32(PTLRPC_SEC_GSS_VERSION)) {
1164                 CERROR("reply have different gss version\n");
1165                 RETURN(-EPROTO);
1166         }
1167         subflavor = le32_to_cpu(*vp++);
1168         proc = le32_to_cpu(*vp++);
1169         seq = le32_to_cpu(*vp++);
1170         svc = le32_to_cpu(*vp++);
1171         vlen -= 5 * 4;
1172
1173         switch (proc) {
1174         case PTLRPC_GSS_PROC_DATA:
1175                 if (svc != PTLRPC_GSS_SVC_PRIVACY) {
1176                         CERROR("Unknown svc %d\n", svc);
1177                         RETURN(-EPROTO);
1178                 }
1179                 if (*vp++ != 0) {
1180                         CERROR("Unexpected ctx handle\n");
1181                         RETURN(-EPROTO);
1182                 }
1183                 vlen -= 4;
1184
1185                 cipher_text.len = le32_to_cpu(*vp++);
1186                 cipher_text.data = (__u8 *) vp;
1187                 vlen -= 4;
1188
1189                 if (vlen < cipher_text.len) {
1190                         CERROR("cipher text to be %u while buf only %u\n",
1191                                 cipher_text.len, vlen);
1192                         RETURN(-EPROTO);
1193                 }
1194
1195                 plain_text = cipher_text;
1196
1197                 gcred = container_of(cred, struct gss_cred, gc_base);
1198                 ctx = gss_cred_get_ctx(cred);
1199                 LASSERT(ctx);
1200
1201                 rc = kgss_unwrap(ctx->gc_gss_ctx, GSS_C_QOP_DEFAULT,
1202                                  &cipher_text, &plain_text);
1203                 if (rc) {
1204                         CERROR("error unwrap: 0x%x\n", rc);
1205                         GOTO(proc_out, rc = -EINVAL);
1206                 }
1207
1208                 req->rq_repmsg = (struct lustre_msg *) vp;
1209                 req->rq_replen = plain_text.len;
1210
1211                 rc = 0;
1212 proc_out:
1213                 gss_put_ctx(ctx);
1214                 break;
1215         default:
1216                 CERROR("unknown gss proc %d\n", proc);
1217                 rc = -EPROTO;
1218         }
1219
1220         RETURN(rc);
1221 }
1222
1223 static void destroy_gss_context(struct ptlrpc_cred *cred)
1224 {
1225         struct ptlrpcs_wire_hdr *hdr;
1226         struct lustre_msg       *lmsg;
1227         struct gss_cred         *gcred;
1228         struct ptlrpc_request    req;
1229         struct obd_import       *imp;
1230         __u32                   *vp, lmsg_size;
1231         ENTRY;
1232
1233         /* cred's refcount is 0, steal one */
1234         atomic_inc(&cred->pc_refcount);
1235
1236         gcred = container_of(cred, struct gss_cred, gc_base);
1237         gcred->gc_ctx->gc_proc = PTLRPC_GSS_PROC_DESTROY;
1238         imp = cred->pc_sec->ps_import;
1239         LASSERT(imp);
1240
1241         if (!(cred->pc_flags & PTLRPC_CRED_UPTODATE)) {
1242                 CWARN("Destroy a dead gss cred %p(%u@%s), don't send rpc\n",
1243                        gcred, cred->pc_uid, imp->imp_target_uuid.uuid);
1244                 atomic_dec(&cred->pc_refcount);
1245                 EXIT;
1246                 return;
1247         }
1248
1249         CWARN("client destroy gss cred %p(%u@%s)\n",
1250                gcred, cred->pc_uid, imp->imp_target_uuid.uuid);
1251
1252         lmsg_size = lustre_msg_size(0, NULL);
1253         req.rq_reqbuf_len = sizeof(*hdr) + lmsg_size +
1254                             ptlrpcs_est_req_payload(cred->pc_sec, lmsg_size);
1255
1256         OBD_ALLOC(req.rq_reqbuf, req.rq_reqbuf_len);
1257         if (!req.rq_reqbuf) {
1258                 CERROR("Fail to alloc reqbuf, cancel anyway\n");
1259                 atomic_dec(&cred->pc_refcount);
1260                 EXIT;
1261                 return;
1262         }
1263
1264         /* wire hdr */
1265         hdr = buf_to_sec_hdr(req.rq_reqbuf);
1266         hdr->flavor  = cpu_to_le32(PTLRPC_SEC_GSS);
1267         hdr->sectype = cpu_to_le32(PTLRPC_SEC_TYPE_AUTH);
1268         hdr->msg_len = cpu_to_le32(lmsg_size);
1269         hdr->sec_len = cpu_to_le32(0);
1270
1271         /* lustre message */
1272         lmsg = buf_to_lustre_msg(req.rq_reqbuf);
1273         lustre_init_msg(lmsg, 0, NULL, NULL);
1274         lmsg->handle   = imp->imp_remote_handle;
1275         lmsg->type     = PTL_RPC_MSG_REQUEST;
1276         lmsg->opc      = SEC_FINI;
1277         lmsg->flags    = 0;
1278         lmsg->conn_cnt = imp->imp_conn_cnt;
1279         /* add this for randomize */
1280         get_random_bytes(&lmsg->last_xid, sizeof(lmsg->last_xid));
1281         get_random_bytes(&lmsg->transno, sizeof(lmsg->transno));
1282
1283         vp = (__u32 *) req.rq_reqbuf;
1284
1285         req.rq_cred = cred;
1286         req.rq_reqmsg = buf_to_lustre_msg(req.rq_reqbuf);
1287         req.rq_reqlen = lmsg_size;
1288         req.rq_reqdata_len = sizeof(*hdr) + lmsg_size;
1289
1290         if (gss_cred_sign(cred, &req)) {
1291                 CERROR("failed to sign, cancel anyway\n");
1292                 atomic_dec(&cred->pc_refcount);
1293                 goto exit;
1294         }
1295         atomic_dec(&cred->pc_refcount);
1296
1297         /* send out */
1298         gss_send_secfini_rpc(imp, req.rq_reqbuf, req.rq_reqdata_len);
1299 exit:
1300         OBD_FREE(req.rq_reqbuf, req.rq_reqbuf_len);
1301         EXIT;
1302 }
1303
1304 static void gss_cred_destroy(struct ptlrpc_cred *cred)
1305 {
1306         struct gss_cred *gcred;
1307         ENTRY;
1308
1309         LASSERT(cred);
1310         LASSERT(!atomic_read(&cred->pc_refcount));
1311
1312         gcred = container_of(cred, struct gss_cred, gc_base);
1313         if (gcred->gc_ctx) {
1314                 destroy_gss_context(cred);
1315                 gss_put_ctx(gcred->gc_ctx);
1316         }
1317
1318         CDEBUG(D_SEC, "GSS_SEC: destroy cred %p\n", gcred);
1319
1320         OBD_FREE(gcred, sizeof(*gcred));
1321         EXIT;
1322 }
1323
1324 static struct ptlrpc_credops gss_credops = {
1325         .refresh        = gss_cred_refresh,
1326         .match          = gss_cred_match,
1327         .sign           = gss_cred_sign,
1328         .verify         = gss_cred_verify,
1329         .seal           = gss_cred_seal,
1330         .unseal         = gss_cred_unseal,
1331         .destroy        = gss_cred_destroy,
1332 };
1333
1334 #ifdef __KERNEL__
1335 /*******************************************
1336  * rpc_pipe APIs                           *
1337  *******************************************/
1338 static ssize_t
1339 gss_pipe_upcall(struct file *filp, struct rpc_pipe_msg *msg,
1340                 char *dst, size_t buflen)
1341 {
1342         char *data = (char *)msg->data + msg->copied;
1343         ssize_t mlen = msg->len;
1344         ssize_t left;
1345         ENTRY;
1346
1347         if (mlen > buflen)
1348                 mlen = buflen;
1349         left = copy_to_user(dst, data, mlen);
1350         if (left < 0) {
1351                 msg->errno = left;
1352                 RETURN(left);
1353         }
1354         mlen -= left;
1355         msg->copied += mlen;
1356         msg->errno = 0;
1357         RETURN(mlen);
1358 }
1359
/*
 * rpc_pipe downcall hook: userspace writes the result of a context
 * negotiation (at most 1024 bytes) into the pipe.
 *
 * The message is parsed by gss_parse_init_downcall().  The cred matching
 * the reported vfs_cred is looked up; it either gets the new context
 * installed or, if the daemon reported a GSS error, is marked
 * PTLRPC_CRED_DEAD.  Any pending upcall message for (obd name, uid,
 * dest_ip) is then unhashed and released.
 *
 * Returns mlen on success, -ENOSPC if the message is larger than the
 * buffer, -ENOMEM / -EFAULT on allocation / copy failure, or the parse
 * error.
 */
static ssize_t
gss_pipe_downcall(struct file *filp, const char *src, size_t mlen)
{
        char *buf;
        const int bufsize = 1024;
        rawobj_t obj;
        struct inode *inode = filp->f_dentry->d_inode;
        struct rpc_inode *rpci = RPC_I(inode);
        struct obd_import *import;
        struct ptlrpc_sec *sec;
        struct gss_sec *gsec;
        char *obdname;
        struct gss_api_mech *mech;
        struct vfs_cred vcred = { 0 };
        struct ptlrpc_cred *cred;
        struct gss_upcall_msg *gss_msg;
        struct gss_cl_ctx *ctx = NULL;
        __u32  dest_ip;
        ssize_t left;
        int err, gss_err;
        ENTRY;

        if (mlen > bufsize) {
                CERROR("mlen %ld > bufsize %d\n", (long)mlen, bufsize);
                RETURN(-ENOSPC);
        }

        OBD_ALLOC(buf, bufsize);
        if (!buf) {
                CERROR("alloc mem failed\n");
                RETURN(-ENOMEM);
        }

        /* copy_from_user() returns the number of bytes it could not copy */
        left = copy_from_user(buf, src, mlen);
        if (left)
                GOTO(err_free, err = -EFAULT);

        obj.data = buf;
        obj.len = mlen;

        /* the pipe's private data is the gss_sec set in gss_create_sec() */
        LASSERT(rpci->private);
        gsec = (struct gss_sec *)rpci->private;
        sec = &gsec->gs_base;
        LASSERT(sec->ps_import);
        /* hold the import while we use its obd name; dropped on every path
         * below the err_free label */
        import = class_import_get(sec->ps_import);
        LASSERT(import->imp_obd);
        obdname = import->imp_obd->obd_name;
        mech = gsec->gs_mech;

        err = gss_parse_init_downcall(mech, &obj, &ctx, &vcred, &dest_ip,
                                      &gss_err);
        if (err) {
                CERROR("parse downcall err %d\n", err);
                GOTO(err, err);
        }
        cred = ptlrpcs_cred_lookup(sec, &vcred);
        if (!cred) {
                /* NOTE(review): err is still 0 here (the parse succeeded),
                 * so this path returns 0 to the writer, destroys ctx, and
                 * leaves any pending upcall message hashed -- confirm this
                 * is intended. */
                CWARN("didn't find cred\n");
                GOTO(err, err);
        }
        if (gss_err) {
                /* NOTE(review): ctx is neither installed nor freed on this
                 * path; presumably the parser does not allocate one when it
                 * reports gss_err -- confirm. */
                CERROR("got gss err %d, set cred %p dead\n", gss_err, cred);
                cred->pc_flags |= PTLRPC_CRED_DEAD;
        } else {
                CDEBUG(D_SEC, "get initial ctx:\n");
                /* the cred takes over ctx; it is not freed below */
                gss_cred_set_ctx(cred, ctx);
        }

        /* retire the upcall this downcall answers; the message is unhashed
         * under gs_lock but released after dropping it */
        spin_lock(&gsec->gs_lock);
        gss_msg = gss_find_upcall(gsec, obdname, vcred.vc_uid, dest_ip);
        if (gss_msg) {
                gss_unhash_msg_nolock(gss_msg);
                spin_unlock(&gsec->gs_lock);
                gss_release_msg(gss_msg);
        } else
                spin_unlock(&gsec->gs_lock);

        ptlrpcs_cred_put(cred, 1);
        class_import_put(import);
        OBD_FREE(buf, bufsize);
        RETURN(mlen);
err:
        if (ctx)
                gss_destroy_ctx(ctx);
        class_import_put(import);
err_free:
        OBD_FREE(buf, bufsize);
        CDEBUG(D_SEC, "gss_pipe_downcall returning %d\n", err);
        RETURN(err);
}
1450
1451 static
1452 void gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
1453 {
1454         struct gss_upcall_msg *gmsg;
1455         static unsigned long ratelimit;
1456         ENTRY;
1457
1458         if (msg->errno >= 0) {
1459                 EXIT;
1460                 return;
1461         }
1462
1463         gmsg = container_of(msg, struct gss_upcall_msg, gum_base);
1464         CDEBUG(D_SEC, "destroy gmsg %p\n", gmsg);
1465         atomic_inc(&gmsg->gum_refcount);
1466         gss_unhash_msg(gmsg);
1467         if (msg->errno == -ETIMEDOUT || msg->errno == -EPIPE) {
1468                 unsigned long now = get_seconds();
1469                 if (time_after(now, ratelimit)) {
1470                         CWARN("GSS_SEC upcall timed out.\n"
1471                               "Please check user daemon is running!\n");
1472                         ratelimit = now + 15;
1473                 }
1474         }
1475         gss_release_msg(gmsg);
1476         EXIT;
1477 }
1478
1479 static
1480 void gss_pipe_release(struct inode *inode)
1481 {
1482         struct rpc_inode *rpci = RPC_I(inode);
1483         struct ptlrpc_sec *sec;
1484         struct gss_sec *gsec;
1485         ENTRY;
1486
1487         gsec = (struct gss_sec *)rpci->private;
1488         sec = &gsec->gs_base;
1489         spin_lock(&gsec->gs_lock);
1490         while (!list_empty(&gsec->gs_upcalls)) {
1491                 struct gss_upcall_msg *gmsg;
1492
1493                 gmsg = list_entry(gsec->gs_upcalls.next,
1494                                   struct gss_upcall_msg, gum_list);
1495                 gmsg->gum_base.errno = -EPIPE;
1496                 atomic_inc(&gmsg->gum_refcount);
1497                 gss_unhash_msg_nolock(gmsg);
1498                 gss_release_msg(gmsg);
1499         }
1500         spin_unlock(&gsec->gs_lock);
1501         EXIT;
1502 }
1503
1504 static struct rpc_pipe_ops gss_upcall_ops = {
1505         .upcall         = gss_pipe_upcall,
1506         .downcall       = gss_pipe_downcall,
1507         .destroy_msg    = gss_pipe_destroy_msg,
1508         .release_pipe   = gss_pipe_release,
1509 };
1510 #endif /* __KERNEL__ */
1511
1512 /*********************************************
1513  * GSS security APIs                         *
1514  *********************************************/
1515
/*
 * secops->create_sec: build a GSS security instance for @flavor.
 *
 * Looks up the GSS mechanism for flavor->subflavor.  In the kernel it also
 * creates the rpc_pipefs directory LUSTRE_PIPEDIR/<pipe_dir> and, inside
 * it, the upcall pipe LUSTRE_PIPEDIR/<pipe_dir>/<mech name> served by
 * gss_upcall_ops.  ps_sectype becomes AUTH for KRB5I and PRIV for KRB5P;
 * any other subflavor hits LBUG().
 *
 * @pipe_data is not used here.
 *
 * Returns the embedded ptlrpc_sec, or NULL on failure (whatever was set up
 * so far is undone through the error labels at the bottom).
 */
static
struct ptlrpc_sec* gss_create_sec(ptlrpcs_flavor_t *flavor,
                                  const char *pipe_dir,
                                  void *pipe_data)
{
        struct gss_sec *gsec;
        struct ptlrpc_sec *sec;
#ifdef __KERNEL__
        char *pos;
#endif
        ENTRY;

        LASSERT(flavor->flavor == PTLRPC_SEC_GSS);

        OBD_ALLOC(gsec, sizeof(*gsec));
        if (!gsec) {
                CERROR("can't alloc gsec\n");
                RETURN(NULL);
        }

        gsec->gs_mech = kgss_subflavor_to_mech(flavor->subflavor);
        if (!gsec->gs_mech) {
                CERROR("subflavor %d not found\n", flavor->subflavor);
                goto err_free;
        }

        /* initialize gss sec */
#ifdef __KERNEL__
        INIT_LIST_HEAD(&gsec->gs_upcalls);
        spin_lock_init(&gsec->gs_lock);

        /* first the per-sec directory ... */
        snprintf(gsec->gs_pipepath, sizeof(gsec->gs_pipepath),
                 LUSTRE_PIPEDIR"/%s", pipe_dir);
        if (IS_ERR(rpc_mkdir(gsec->gs_pipepath, NULL))) {
                CERROR("can't make pipedir %s\n", gsec->gs_pipepath);
                goto err_mech_put;
        }

        /* ... then the pipe inside it; gs_pipepath keeps the full pipe path,
         * which gss_destroy_sec() and err_rmdir below rely on */
        snprintf(gsec->gs_pipepath, sizeof(gsec->gs_pipepath),
                 LUSTRE_PIPEDIR"/%s/%s", pipe_dir, gsec->gs_mech->gm_name); 
        gsec->gs_depipe = rpc_mkpipe(gsec->gs_pipepath, gsec,
                                     &gss_upcall_ops, RPC_PIPE_WAIT_FOR_OPEN);
        if (IS_ERR(gsec->gs_depipe)) {
                CERROR("failed to make rpc_pipe %s: %ld\n",
                        gsec->gs_pipepath, PTR_ERR(gsec->gs_depipe));
                goto err_rmdir;
        }
        CDEBUG(D_SEC, "gss sec %p, pipe path %s\n", gsec, gsec->gs_pipepath);
#endif

        sec = &gsec->gs_base;

        switch (flavor->subflavor) {
        case PTLRPC_SEC_GSS_KRB5I:
                sec->ps_sectype = PTLRPC_SEC_TYPE_AUTH;
                break;
        case PTLRPC_SEC_GSS_KRB5P:
                sec->ps_sectype = PTLRPC_SEC_TYPE_PRIV;
                break;
        default:
                LBUG();
        }

        sec->ps_expire = GSS_CREDCACHE_EXPIRE;
        sec->ps_nextgc = get_seconds() + sec->ps_expire;
        sec->ps_flags = 0;

        CDEBUG(D_SEC, "Create GSS security instance at %p(external %p)\n",
               gsec, sec);
        RETURN(sec);

#ifdef __KERNEL__
err_rmdir:
        /* strip "/<mech name>" to get back the directory made above */
        pos = strrchr(gsec->gs_pipepath, '/');
        LASSERT(pos);
        *pos = 0;
        rpc_rmdir(gsec->gs_pipepath);
err_mech_put:
#endif
        kgss_mech_put(gsec->gs_mech);
err_free:
        OBD_FREE(gsec, sizeof(*gsec));
        RETURN(NULL);
}
1600
1601 static
1602 void gss_destroy_sec(struct ptlrpc_sec *sec)
1603 {
1604         struct gss_sec *gsec;
1605 #ifdef __KERNEL__
1606         char *pos;
1607 #endif
1608         ENTRY;
1609
1610         gsec = container_of(sec, struct gss_sec, gs_base);
1611         CDEBUG(D_SEC, "Destroy GSS security instance at %p\n", gsec);
1612
1613         LASSERT(gsec->gs_mech);
1614         LASSERT(!atomic_read(&sec->ps_refcount));
1615         LASSERT(!atomic_read(&sec->ps_credcount));
1616 #ifdef __KERNEL__
1617         rpc_unlink(gsec->gs_pipepath);
1618         pos = strrchr(gsec->gs_pipepath, '/');
1619         LASSERT(pos);
1620         *pos = 0;
1621         rpc_rmdir(gsec->gs_pipepath);
1622 #endif
1623
1624         kgss_mech_put(gsec->gs_mech);
1625         OBD_FREE(gsec, sizeof(*gsec));
1626         EXIT;
1627 }
1628
1629 static
1630 struct ptlrpc_cred * gss_create_cred(struct ptlrpc_sec *sec,
1631                                      struct ptlrpc_request *req,
1632                                      struct vfs_cred *vcred)
1633 {
1634         struct gss_cred *gcred;
1635         struct ptlrpc_cred *cred;
1636         ENTRY;
1637
1638         OBD_ALLOC(gcred, sizeof(*gcred));
1639         if (!gcred)
1640                 RETURN(NULL);
1641
1642         cred = &gcred->gc_base;
1643         INIT_LIST_HEAD(&cred->pc_hash);
1644         atomic_set(&cred->pc_refcount, 0);
1645         cred->pc_sec = sec;
1646         cred->pc_ops = &gss_credops;
1647         cred->pc_req = req;
1648         cred->pc_expire = get_seconds() + GSS_CRED_EXPIRE;
1649         cred->pc_flags = 0;
1650         cred->pc_pag = vcred->vc_pag;
1651         cred->pc_uid = vcred->vc_uid;
1652         CDEBUG(D_SEC, "create a gss cred at %p("LPU64"/%u)\n",
1653                cred, vcred->vc_pag, vcred->vc_uid);
1654
1655         RETURN(cred);
1656 }
1657
1658 static int gss_estimate_payload(struct ptlrpc_sec *sec, int msgsize)
1659 {
1660         switch (sec->ps_sectype) {
1661         case PTLRPC_SEC_TYPE_AUTH:
1662                 return GSS_MAX_AUTH_PAYLOAD;
1663         case PTLRPC_SEC_TYPE_PRIV:
1664                 return size_round16(GSS_MAX_AUTH_PAYLOAD + msgsize +
1665                                     GSS_PRIVBUF_PREFIX_LEN +
1666                                     GSS_PRIVBUF_SUFFIX_LEN);
1667         default:
1668                 LBUG();
1669                 return 0;
1670         }
1671 }
1672
1673 static int gss_alloc_reqbuf(struct ptlrpc_sec *sec,
1674                             struct ptlrpc_request *req,
1675                             int lmsg_size)
1676 {
1677         int msg_payload, sec_payload;
1678         int privacy, rc;
1679         ENTRY;
1680
1681         /* In PRIVACY mode, lustre message is always 0 (already encoded into
1682          * security payload).
1683          */
1684         privacy = sec->ps_sectype == PTLRPC_SEC_TYPE_PRIV;
1685         msg_payload = privacy ? 0 : lmsg_size;
1686         sec_payload = gss_estimate_payload(sec, lmsg_size);
1687
1688         rc = sec_alloc_reqbuf(sec, req, msg_payload, sec_payload);
1689         if (rc)
1690                 return rc;
1691
1692         if (privacy) {
1693                 int buflen = lmsg_size + GSS_PRIVBUF_PREFIX_LEN +
1694                              GSS_PRIVBUF_SUFFIX_LEN;
1695                 char *buf;
1696
1697                 OBD_ALLOC(buf, buflen);
1698                 if (!buf) {
1699                         CERROR("Fail to alloc %d\n", buflen);
1700                         sec_free_reqbuf(sec, req);
1701                         RETURN(-ENOMEM);
1702                 }
1703                 req->rq_reqmsg = (struct lustre_msg *)
1704                                         (buf + GSS_PRIVBUF_PREFIX_LEN);
1705         }
1706
1707         RETURN(0);
1708 }
1709
1710 static void gss_free_reqbuf(struct ptlrpc_sec *sec,
1711                             struct ptlrpc_request *req)
1712 {
1713         char *buf;
1714         int privacy;
1715         ENTRY;
1716
1717         LASSERT(req->rq_reqmsg);
1718         LASSERT(req->rq_reqlen);
1719
1720         privacy = sec->ps_sectype == PTLRPC_SEC_TYPE_PRIV;
1721         if (privacy) {
1722                 buf = (char *) req->rq_reqmsg - GSS_PRIVBUF_PREFIX_LEN;
1723                 LASSERT(buf < req->rq_reqbuf ||
1724                         buf >= req->rq_reqbuf + req->rq_reqbuf_len);
1725                 OBD_FREE(buf, req->rq_reqlen + GSS_PRIVBUF_PREFIX_LEN +
1726                               GSS_PRIVBUF_SUFFIX_LEN);
1727                 req->rq_reqmsg = NULL;
1728         }
1729
1730         sec_free_reqbuf(sec, req);
1731 }
1732
1733 static struct ptlrpc_secops gss_secops = {
1734         .create_sec             = gss_create_sec,
1735         .destroy_sec            = gss_destroy_sec,
1736         .create_cred            = gss_create_cred,
1737         .est_req_payload        = gss_estimate_payload,
1738         .est_rep_payload        = gss_estimate_payload,
1739         .alloc_reqbuf           = gss_alloc_reqbuf,
1740         .free_reqbuf            = gss_free_reqbuf,
1741 };
1742
1743 static struct ptlrpc_sec_type gss_type = {
1744         .pst_owner      = THIS_MODULE,
1745         .pst_name       = "GSS_SEC",
1746         .pst_inst       = ATOMIC_INIT(0),
1747         .pst_flavor     = {PTLRPC_SEC_GSS, 0},
1748         .pst_ops        = &gss_secops,
1749 };
1750
1751 extern int
1752 (*lustre_secinit_downcall_handler)(char *buffer, unsigned long count);
1753
1754 int __init ptlrpcs_gss_init(void)
1755 {
1756         int rc;
1757
1758         rc = ptlrpcs_register(&gss_type);
1759         if (rc)
1760                 return rc;
1761
1762 #ifdef __KERNEL__
1763         gss_svc_init();
1764
1765         rc = PTR_ERR(rpc_mkdir(LUSTRE_PIPEDIR, NULL));
1766         if (IS_ERR((void *)rc) && rc != -EEXIST) {
1767                 CERROR("fail to make rpcpipedir for lustre\n");
1768                 gss_svc_exit();
1769                 ptlrpcs_unregister(&gss_type);
1770                 return -1;
1771         }
1772         rc = 0;
1773 #else
1774 #endif
1775         rc = init_kerberos_module();
1776         if (rc) {
1777                 ptlrpcs_unregister(&gss_type);
1778         }
1779
1780         lustre_secinit_downcall_handler = gss_send_secinit_rpc;
1781
1782         return rc;
1783 }
1784
1785 #ifdef __KERNEL__
/*
 * Module exit (kernel builds only): undo ptlrpcs_gss_init() in reverse
 * order.  The secinit downcall hook is cleared first so nothing can call
 * into this module while it is torn down; then the kerberos mechanism,
 * the LUSTRE_PIPEDIR rpc_pipefs directory and the gss_svc_* state are
 * cleaned up, and finally the GSS security type is unregistered.
 */
static void __exit ptlrpcs_gss_exit(void)
{
        lustre_secinit_downcall_handler = NULL;

        cleanup_kerberos_module();
        rpc_rmdir(LUSTRE_PIPEDIR);
        gss_svc_exit();
        ptlrpcs_unregister(&gss_type);
}
1795 #endif
1796
MODULE_AUTHOR("Cluster File Systems, Inc. <info@clusterfs.com>");
MODULE_DESCRIPTION("GSS Security module for Lustre");
MODULE_LICENSE("GPL");

/* NOTE(review): ptlrpcs_gss_exit is only defined under __KERNEL__; this
 * presumably relies on module_exit() expanding to nothing in liblustre
 * (userspace) builds -- confirm. */
module_init(ptlrpcs_gss_init);
module_exit(ptlrpcs_gss_exit);