Whamcloud - gitweb
gss: be able to parse signed int when lsvcgssd sends an error downcall; some
[fs/lustre-release.git] / lustre / sec / gss / svcsec_gss.c
1 /* -*- mode: c; c-basic-offset: 8; indent-tabs-mode: nil; -*-
2  * vim:expandtab:shiftwidth=8:tabstop=8:
3  *
4  * Modifications for Lustre
5  * Copyright 2004, Cluster File Systems, Inc.
6  * All rights reserved
7  * Author: Eric Mei <ericm@clusterfs.com>
8  */
9
10 /*
11  * Neil Brown <neilb@cse.unsw.edu.au>
12  * J. Bruce Fields <bfields@umich.edu>
13  * Andy Adamson <andros@umich.edu>
14  * Dug Song <dugsong@monkey.org>
15  *
16  * RPCSEC_GSS server authentication.
17  * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078
18  * (gssapi)
19  *
20  * The RPCSEC_GSS involves three stages:
21  *  1/ context creation
22  *  2/ data exchange
23  *  3/ context destruction
24  *
25  * Context creation is handled largely by upcalls to user-space.
26  *  In particular, GSS_Accept_sec_context is handled by an upcall
27  * Data exchange is handled entirely within the kernel
28  *  In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel.
29  * Context destruction is handled in-kernel
30  *  GSS_Delete_sec_context is in-kernel
31  *
32  * Context creation is initiated by a RPCSEC_GSS_INIT request arriving.
33  * The context handle and gss_token are used as a key into the rpcsec_init cache.
34  * The content of this cache includes some of the outputs of GSS_Accept_sec_context,
35  * being major_status, minor_status, context_handle, reply_token.
36  * These are sent back to the client.
 * Sequence window management is handled by the kernel.  The window size is currently
38  * a compile time constant.
39  *
40  * When user-space is happy that a context is established, it places an entry
41  * in the rpcsec_context cache. The key for this cache is the context_handle.
42  * The content includes:
43  *   uid/gidlist - for determining access rights
44  *   mechanism type
45  *   mechanism specific information, such as a key
46  *
47  */
48
49 #define DEBUG_SUBSYSTEM S_SEC
50 #ifdef __KERNEL__
51 #include <linux/types.h>
52 #include <linux/init.h>
53 #include <linux/module.h>
54 #include <linux/slab.h>
55 #include <linux/hash.h>
56 #else
57 #include <liblustre.h>
58 #endif
59
60 #include <linux/sunrpc/cache.h>
61
62 #include <libcfs/kp30.h>
63 #include <linux/obd.h>
64 #include <linux/obd_class.h>
65 #include <linux/obd_support.h>
66 #include <linux/lustre_idl.h>
67 #include <linux/lustre_net.h>
68 #include <linux/lustre_import.h>
69 #include <linux/lustre_sec.h>
70                                                                                                                         
71 #include "gss_err.h"
72 #include "gss_internal.h"
73 #include "gss_api.h"
74
75 static inline unsigned long hash_mem(char *buf, int length, int bits)
76 {
77         unsigned long hash = 0;
78         unsigned long l = 0;
79         int len = 0;
80         unsigned char c;
81         do {
82                 if (len == length) {
83                         c = (char)len; len = -1;
84                 } else
85                         c = *buf++;
86                 l = (l << 8) | c;
87                 len++;
88                 if ((len & (BITS_PER_LONG/8-1))==0)
89                         hash = hash_long(hash^l, BITS_PER_LONG);
90         } while (len);
91         return hash >> (BITS_PER_LONG - bits);
92 }
93
94 /* The rpcsec_init cache is used for mapping RPCSEC_GSS_{,CONT_}INIT requests
95  * into replies.
96  *
97  * Key is context handle (\x if empty) and gss_token.
98  * Content is major_status minor_status (integers) context_handle, reply_token.
99  *
100  */
101
102 #define RSI_HASHBITS    6
103 #define RSI_HASHMAX     (1<<RSI_HASHBITS)
104 #define RSI_HASHMASK    (RSI_HASHMAX-1)
105
/* An entry in the rpcsec_init (upcall) cache: carries a context-init
 * request up to the userspace daemon and the reply back down.
 * in_handle/in_token form the lookup key (see rsi_match()); the out_*
 * fields and status codes are filled in from the downcall (rsi_parse). */
struct rsi {
        struct cache_head       h;
        __u32                   naltype;        /* peer addressing, sent in upcall */
        __u32                   netid;          /* (see rsi_request) */
        __u64                   nid;
        rawobj_t                in_handle, in_token;    /* request: cache key */
        rawobj_t                out_handle, out_token;  /* reply from gssd */
        int                     major_status, minor_status; /* gss status, may be negative */
};
115
116 static struct cache_head *rsi_table[RSI_HASHMAX];
117 static struct cache_detail rsi_cache;
118
119 static void rsi_free(struct rsi *rsii)
120 {
121         rawobj_free(&rsii->in_handle);
122         rawobj_free(&rsii->in_token);
123         rawobj_free(&rsii->out_handle);
124         rawobj_free(&rsii->out_token);
125 }
126
127 static void rsi_put(struct cache_head *item, struct cache_detail *cd)
128 {
129         struct rsi *rsii = container_of(item, struct rsi, h);
130         LASSERT(atomic_read(&item->refcnt) > 0);
131         if (cache_put(item, cd)) {
132                 LASSERT(item->next == NULL);
133                 rsi_free(rsii);
134                 OBD_FREE(rsii, sizeof(*rsii));
135         }
136 }
137
138 static inline int rsi_hash(struct rsi *item)
139 {
140         return hash_mem((char *)item->in_handle.data, item->in_handle.len, RSI_HASHBITS)
141                 ^ hash_mem((char *)item->in_token.data, item->in_token.len, RSI_HASHBITS);
142 }
143
144 static inline int rsi_match(struct rsi *item, struct rsi *tmp)
145 {
146         return (rawobj_equal(&item->in_handle, &tmp->in_handle) &&
147                 rawobj_equal(&item->in_token, &tmp->in_token));
148 }
149
/* Format the upcall request line for the userspace daemon:
 * hex-encoded naltype, netid, nid, context handle and gss token,
 * newline-terminated. */
static void rsi_request(struct cache_detail *cd,
                        struct cache_head *h,
                        char **bpp, int *blen)
{
        struct rsi *rsii = container_of(h, struct rsi, h);

        qword_addhex(bpp, blen, (char *)&rsii->naltype, sizeof(rsii->naltype));
        qword_addhex(bpp, blen, (char *)&rsii->netid, sizeof(rsii->netid));
        qword_addhex(bpp, blen, (char *)&rsii->nid, sizeof(rsii->nid));
        qword_addhex(bpp, blen, (char *)rsii->in_handle.data, rsii->in_handle.len);
        qword_addhex(bpp, blen, (char *)rsii->in_token.data, rsii->in_token.len);
        /* replace the last field separator with the record terminator */
        (*bpp)[-1] = '\n';
}
163
/*
 * Handle a downcall reply: find the pending (not yet CACHE_VALID) rsi
 * entry matching @item in its hash chain, replace it in the chain with
 * @item marked VALID, and release the old entry.  The poller in
 * gssd_upcall() will then find @item.
 * Returns 0 on success, -EINVAL if there is no pending entry or the
 * matched entry is unexpectedly already VALID.
 */
static int
gssd_reply(struct rsi *item)
{
        struct rsi *tmp;
        struct cache_head **hp, **head;
        ENTRY;

        head = &rsi_cache.hash_table[rsi_hash(item)];
        write_lock(&rsi_cache.hash_lock);
        for (hp = head; *hp != NULL; hp = &tmp->h.next) {
                tmp = container_of(*hp, struct rsi, h);
                if (rsi_match(tmp, item)) {
                        /* unhash the pending entry, keeping a temporary ref */
                        cache_get(&tmp->h);
                        clear_bit(CACHE_HASHED, &tmp->h.flags);
                        *hp = tmp->h.next;
                        tmp->h.next = NULL;
                        rsi_cache.entries--;
                        if (test_bit(CACHE_VALID, &tmp->h.flags)) {
                                CERROR("rsi is valid\n");
                                write_unlock(&rsi_cache.hash_lock);
                                rsi_put(&tmp->h, &rsi_cache);
                                RETURN(-EINVAL);
                        }
                        /* splice @item in at the same position; *hp now
                         * points at the old entry's successor */
                        set_bit(CACHE_HASHED, &item->h.flags);
                        item->h.next = *hp;
                        *hp = &item->h;
                        rsi_cache.entries++;
                        set_bit(CACHE_VALID, &item->h.flags);
                        item->h.last_refresh = get_seconds();
                        write_unlock(&rsi_cache.hash_lock);
                        /* expire the old entry immediately, then drop our ref */
                        cache_fresh(&rsi_cache, &tmp->h, 0);
                        rsi_put(&tmp->h, &rsi_cache);
                        RETURN(0);
                }
        }
        write_unlock(&rsi_cache.hash_lock);
        RETURN(-EINVAL);
}
202
/* XXX
 * Here we just wait for the upcall's completion or a timeout.  It's a
 * hack, but it works; we'll come up with a real fix if we decide to
 * stick with the NFS4 cache code.
 */
208 static struct rsi *
209 gssd_upcall(struct rsi *item, struct cache_req *chandle)
210 {
211         struct rsi *tmp;
212         struct cache_head **hp, **head;
213         unsigned long starttime;
214         ENTRY;
215
216         head = &rsi_cache.hash_table[rsi_hash(item)];
217         read_lock(&rsi_cache.hash_lock);
218         for (hp = head; *hp != NULL; hp = &tmp->h.next) {
219                 tmp = container_of(*hp, struct rsi, h);
220                 if (rsi_match(tmp, item)) {
221                         LBUG();
222                         if (!test_bit(CACHE_VALID, &tmp->h.flags)) {
223                                 CERROR("found rsi without VALID\n");
224                                 read_unlock(&rsi_cache.hash_lock);
225                                 return NULL;
226                         }
227                         *hp = tmp->h.next;
228                         tmp->h.next = NULL;
229                         rsi_cache.entries--;
230                         cache_get(&tmp->h);
231                         read_unlock(&rsi_cache.hash_lock);
232                         return tmp;
233                 }
234         }
235         cache_get(&item->h);
236         //set_bit(CACHE_HASHED, &item->h.flags);
237         item->h.next = *head;
238         *head = &item->h;
239         rsi_cache.entries++;
240         read_unlock(&rsi_cache.hash_lock);
241         //cache_get(&item->h);
242
243         cache_check(&rsi_cache, &item->h, chandle);
244         starttime = get_seconds();
245         do {
246                 set_current_state(TASK_UNINTERRUPTIBLE);
247                 schedule_timeout(HZ/2);
248                 read_lock(&rsi_cache.hash_lock);
249                 for (hp = head; *hp != NULL; hp = &tmp->h.next) {
250                         tmp = container_of(*hp, struct rsi, h);
251                         if (tmp == item)
252                                 continue;
253                         if (rsi_match(tmp, item)) {
254                                 if (!test_bit(CACHE_VALID, &tmp->h.flags)) {
255                                         read_unlock(&rsi_cache.hash_lock);
256                                         return NULL;
257                                 }
258                                 cache_get(&tmp->h);
259                                 clear_bit(CACHE_HASHED, &tmp->h.flags);
260                                 *hp = tmp->h.next;
261                                 tmp->h.next = NULL;
262                                 rsi_cache.entries--;
263                                 read_unlock(&rsi_cache.hash_lock);
264                                 return tmp;
265                         }
266                 }
267                 read_unlock(&rsi_cache.hash_lock);
268         } while ((get_seconds() - starttime) <= 15);
269         CERROR("15s timeout while waiting cache refill\n");
270         return NULL;
271 }
272
/*
 * Parse a downcall message from the userspace daemon carrying the
 * result of GSS_Accept_sec_context.  Wire format:
 *   in_handle in_token expiry major minor out_handle out_token
 * major/minor are parsed as signed ints because an error downcall may
 * carry negative status values.  On success the entry is installed
 * into rsi_cache via gssd_reply().  Returns 0 or a negative errno.
 */
static int rsi_parse(struct cache_detail *cd,
                     char *mesg, int mlen)
{
        /* context token expiry major minor context token */
        char *buf = mesg;
        char *ep;
        int len;
        struct rsi *rsii;
        time_t expiry;
        int status = -EINVAL;
        ENTRY;

        OBD_ALLOC(rsii, sizeof(*rsii));
        if (!rsii)
                RETURN(-ENOMEM);
        cache_init(&rsii->h);

        /* handle */
        len = qword_get(&mesg, buf, mlen);
        if (len < 0)
                goto out;
        if (rawobj_alloc(&rsii->in_handle, buf, len)) {
                status = -ENOMEM;
                goto out;
        }

        /* token */
        len = qword_get(&mesg, buf, mlen);
        if (len < 0)
                goto out;
        if (rawobj_alloc(&rsii->in_token, buf, len)) {
                status = -ENOMEM;
                goto out;
        }

        /* expiry */
        expiry = get_expiry(&mesg);
        if (expiry == 0)
                goto out;

        /* major status: signed, error downcalls may pass negative values */
        len = qword_get(&mesg, buf, mlen);
        if (len <= 0)
                goto out;
        rsii->major_status = simple_strtol(buf, &ep, 10);
        if (*ep)
                goto out;

        /* minor status: signed for the same reason */
        len = qword_get(&mesg, buf, mlen);
        if (len <= 0)
                goto out;
        rsii->minor_status = simple_strtol(buf, &ep, 10);
        if (*ep)
                goto out;

        /* out_handle */
        len = qword_get(&mesg, buf, mlen);
        if (len < 0)
                goto out;
        if (rawobj_alloc(&rsii->out_handle, buf, len)) {
                status = -ENOMEM;
                goto out;
        }

        /* out_token */
        len = qword_get(&mesg, buf, mlen);
        if (len < 0)
                goto out;
        if (rawobj_alloc(&rsii->out_token, buf, len)) {
                status = -ENOMEM;
                goto out;
        }

        rsii->h.expiry_time = expiry;
        /* install into rsi_cache so the poller in gssd_upcall() finds it */
        status = gssd_reply(rsii);
out:
        /* drops our allocation reference; on success the hash table
         * holds its own */
        if (rsii)
                rsi_put(&rsii->h, &rsi_cache);
        RETURN(status);
}
354
/* Upcall cache descriptor for context-init requests: upcalls are
 * formatted by rsi_request(), downcall replies parsed by rsi_parse(). */
static struct cache_detail rsi_cache = {
        .hash_size      = RSI_HASHMAX,
        .hash_table     = rsi_table,
        .name           = "auth.ptlrpcs.init",
        .cache_put      = rsi_put,
        .cache_request  = rsi_request,
        .cache_parse    = rsi_parse,
};
363
364 /*
365  * The rpcsec_context cache is used to store a context that is
366  * used in data exchange.
367  * The key is a context handle. The content is:
368  *  uid, gidlist, mechanism, service-set, mech-specific-data
369  */
370
371 #define RSC_HASHBITS    10
372 #define RSC_HASHMAX     (1<<RSC_HASHBITS)
373 #define RSC_HASHMASK    (RSC_HASHMAX-1)
374
375 #define GSS_SEQ_WIN     512
376
/* Per-context replay-detection window (RFC 2203 sequence numbers),
 * maintained by gss_check_seq_num(). */
struct gss_svc_seq_data {
        /* highest seq number seen so far: */
        __u32                   sd_max;
        /* for i such that sd_max-GSS_SEQ_WIN < i <= sd_max, the i-th bit of
         * sd_win is nonzero iff sequence number i has been seen already: */
        unsigned long           sd_win[GSS_SEQ_WIN/BITS_PER_LONG];
        /* protects sd_max and sd_win: */
        spinlock_t              sd_lock;
};
385
/* An established server-side security context, keyed by @handle. */
struct rsc {
        struct cache_head       h;
        rawobj_t                handle;         /* context handle: lookup key */
        __u32                   remote_realm;   /* remote flag from the downcall */
        struct vfs_cred         cred;           /* client uid/gid */
        uid_t                   mapped_uid;     /* mapped uid from the downcall */
        struct gss_svc_seq_data seqdata;        /* replay window */
        struct gss_ctx         *mechctx;        /* imported gss mech context */
};
395
396 static struct cache_head *rsc_table[RSC_HASHMAX];
397 static struct cache_detail rsc_cache;
398
399 static void rsc_free(struct rsc *rsci)
400 {
401         rawobj_free(&rsci->handle);
402         if (rsci->mechctx)
403                 kgss_delete_sec_context(&rsci->mechctx);
404 #if 0
405         if (rsci->cred.vc_ginfo)
406                 put_group_info(rsci->cred.vc_ginfo);
407 #endif
408 }
409
410 static void rsc_put(struct cache_head *item, struct cache_detail *cd)
411 {
412         struct rsc *rsci = container_of(item, struct rsc, h);
413
414         LASSERT(atomic_read(&item->refcnt) > 0);
415         if (cache_put(item, cd)) {
416                 LASSERT(item->next == NULL);
417                 rsc_free(rsci);
418                 OBD_FREE(rsci, sizeof(*rsci));
419         }
420 }
421
422 static inline int
423 rsc_hash(struct rsc *rsci)
424 {
425         return hash_mem((char *)rsci->handle.data,
426                         rsci->handle.len, RSC_HASHBITS);
427 }
428
429 static inline int
430 rsc_match(struct rsc *new, struct rsc *tmp)
431 {
432         return rawobj_equal(&new->handle, &tmp->handle);
433 }
434
/*
 * Find an entry matching @item in the rsc cache.
 *
 * @set == 0: plain lookup under the read lock; returns a referenced
 *            match, or NULL if absent.
 * @set != 0: insert @item at the head of its chain under the write
 *            lock, unhashing and releasing any previous match;
 *            returns @item with an extra reference taken.
 */
static struct rsc *rsc_lookup(struct rsc *item, int set)
{
        struct rsc *tmp = NULL;
        struct cache_head **hp, **head;
        head = &rsc_cache.hash_table[rsc_hash(item)];
        ENTRY;

        if (set)
                write_lock(&rsc_cache.hash_lock);
        else
                read_lock(&rsc_cache.hash_lock);
        for (hp = head; *hp != NULL; hp = &tmp->h.next) {
                tmp = container_of(*hp, struct rsc, h);
                if (!rsc_match(tmp, item))
                        continue;
                cache_get(&tmp->h);
                if (!set)
                        goto out_noset;
                /* replace: unhash the old entry and drop the reference
                 * we just took on it */
                *hp = tmp->h.next;
                tmp->h.next = NULL;
                clear_bit(CACHE_HASHED, &tmp->h.flags);
                rsc_put(&tmp->h, &rsc_cache);
                goto out_set;
        }
        /* Didn't find anything */
        if (!set)
                goto out_nada;
        rsc_cache.entries++;
out_set:
        set_bit(CACHE_HASHED, &item->h.flags);
        item->h.next = *head;
        *head = &item->h;
        write_unlock(&rsc_cache.hash_lock);
        cache_fresh(&rsc_cache, &item->h, item->h.expiry_time);
        /* extra reference returned to the caller */
        cache_get(&item->h);
        RETURN(item);
out_nada:
        tmp = NULL;
out_noset:
        read_unlock(&rsc_cache.hash_lock);
        RETURN(tmp);
}
477                                                                                                                         
/*
 * Parse a context downcall from the userspace daemon.  Wire format:
 *   contexthandle expiry remote_flag mapped_uid uid [ gid mechname
 *   ...mechdata... ]
 * A -ENOENT in the uid field marks the entry negative.  Otherwise the
 * mech-specific blob is imported into a gss context and the entry is
 * installed into rsc_cache.  Returns 0 or a negative errno.
 */
static int rsc_parse(struct cache_detail *cd,
                     char *mesg, int mlen)
{
        /* contexthandle expiry remote_flag mapped_uid [ uid gid
         * mechname ...mechdata... ] */
        char *buf = mesg;
        int len, rv;
        struct rsc *rsci, *res = NULL;
        time_t expiry;
        int status = -EINVAL;

        OBD_ALLOC(rsci, sizeof(*rsci));
        if (!rsci) {
                CERROR("fail to alloc rsci\n");
                return -ENOMEM;
        }
        cache_init(&rsci->h);

        /* context handle */
        len = qword_get(&mesg, buf, mlen);
        if (len < 0) goto out;
        status = -ENOMEM;
        if (rawobj_alloc(&rsci->handle, buf, len))
                goto out;

        /* expiry */
        expiry = get_expiry(&mesg);
        status = -EINVAL;
        if (expiry == 0)
                goto out;

        /* remote flag */
        rv = get_int(&mesg, (int *)&rsci->remote_realm);
        if (rv) {
                CERROR("fail to get remote flag\n");
                goto out;
        }

        /* mapped uid */
        rv = get_int(&mesg, (int *)&rsci->mapped_uid);
        if (rv) {
                CERROR("fail to get mapped uid\n");
                goto out;
        }

        /* uid, or NEGATIVE */
        rv = get_int(&mesg, (int *)&rsci->cred.vc_uid);
        if (rv == -EINVAL)
                goto out;
        if (rv == -ENOENT) {
                CERROR("NOENT? set rsc entry negative\n");
                set_bit(CACHE_NEGATIVE, &rsci->h.flags);
        } else {
                struct gss_api_mech *gm;
                rawobj_t tmp_buf;
                __u64 ctx_expiry;

                /* gid */
                if (get_int(&mesg, (int *)&rsci->cred.vc_gid))
                        goto out;

                /* mech name */
                len = qword_get(&mesg, buf, mlen);
                if (len < 0)
                        goto out;
                gm = kgss_name_to_mech(buf);
                status = -EOPNOTSUPP;
                if (!gm)
                        goto out;

                status = -EINVAL;
                /* mech-specific data: */
                len = qword_get(&mesg, buf, mlen);
                if (len < 0) {
                        kgss_mech_put(gm);
                        goto out;
                }
                tmp_buf.len = len;
                tmp_buf.data = (unsigned char *)buf;
                if (kgss_import_sec_context(&tmp_buf, gm, &rsci->mechctx)) {
                        kgss_mech_put(gm);
                        goto out;
                }

                /* currently the expiry time passed down from user-space
                 * is invalid, here we retrieve it from the mech instead.
                 */
                if (kgss_inquire_context(rsci->mechctx, &ctx_expiry)) {
                        CERROR("unable to get expire time, drop it\n");
                        set_bit(CACHE_NEGATIVE, &rsci->h.flags);
                        kgss_mech_put(gm);
                        goto out;
                }
                expiry = (time_t) ctx_expiry;

                kgss_mech_put(gm);
        }
        rsci->h.expiry_time = expiry;
        spin_lock_init(&rsci->seqdata.sd_lock);
        /* insert; drop the extra reference rsc_lookup(set=1) returned */
        res = rsc_lookup(rsci, 1);
        rsc_put(&res->h, &rsc_cache);
        status = 0;
out:
        /* drops our allocation reference */
        if (rsci)
                rsc_put(&rsci->h, &rsc_cache);
        return status;
}
585
586 /*
587  * flush all entries with @uid. @uid == -1 will match all.
588  * we only know the uid, maybe netid/nid in the future, in all cases
589  * we must search the whole cache
590  */
591 static void rsc_flush(uid_t uid)
592 {
593         struct cache_head **ch;
594         struct rsc *rscp;
595         int n;
596         ENTRY;
597
598         write_lock(&rsc_cache.hash_lock);
599         for (n = 0; n < RSC_HASHMAX; n++) {
600                 for (ch = &rsc_cache.hash_table[n]; *ch;) {
601                         rscp = container_of(*ch, struct rsc, h);
602                         if (uid == -1 || rscp->cred.vc_uid == uid) {
603                                 /* it seems simply set NEGATIVE doesn't work */
604                                 *ch = (*ch)->next;
605                                 rscp->h.next = NULL;
606                                 cache_get(&rscp->h);
607                                 set_bit(CACHE_NEGATIVE, &rscp->h.flags);
608                                 clear_bit(CACHE_HASHED, &rscp->h.flags);
609                                 CDEBUG(D_SEC, "flush rsc %p for uid %u\n",
610                                        rscp, rscp->cred.vc_uid);
611                                 rsc_put(&rscp->h, &rsc_cache);
612                                 rsc_cache.entries--;
613                                 continue;
614                         }
615                         ch = &((*ch)->next);
616                 }
617         }
618         write_unlock(&rsc_cache.hash_lock);
619         EXIT;
620 }
621
/* Established-context cache: downcall-only (no cache_request method) —
 * entries are pushed down by the daemon and parsed by rsc_parse(). */
static struct cache_detail rsc_cache = {
        .hash_size      = RSC_HASHMAX,
        .hash_table     = rsc_table,
        .name           = "auth.ptlrpcs.context",
        .cache_put      = rsc_put,
        .cache_parse    = rsc_parse,
};
629
630 static struct rsc *
631 gss_svc_searchbyctx(rawobj_t *handle)
632 {
633         struct rsc rsci;
634         struct rsc *found;
635
636         rsci.handle = *handle;
637         found = rsc_lookup(&rsci, 0);
638         if (!found)
639                 return NULL;
640
641         if (cache_check(&rsc_cache, &found->h, NULL))
642                 return NULL;
643
644         return found;
645 }
646
/* Per-request gss state, hung off req->rq_sec_svcdata. */
struct gss_svc_data {
        /* decoded gss client cred: */
        struct rpc_gss_wire_cred        clcred;
        /* internal used status */
        unsigned int                    is_init:1,
                                        is_init_continue:1,
                                        is_err_notify:1,
                                        is_fini:1;
        /* space reserved in the reply for gss data
         * (set/consumed in gss_pack_err_notify) */
        int                             reserve_len;
};
657
/* FIXME
 * again hacking: only try to give the svcgssd a chance to handle
 * upcalls.
 */
/* Deferral hook passed to cache_check(): instead of really deferring
 * the request we just yield the CPU and report no deferral (NULL). */
struct cache_deferred_req* my_defer(struct cache_req *req)
{
        yield();
        return NULL;
}
static struct cache_req my_chandle = {my_defer};
668
/* Implements sequence number algorithm as specified in RFC 2203.
 * Returns 0 if @seq_num is acceptable (and records it in the window),
 * 1 if it is a replay or falls below the GSS_SEQ_WIN-sized window. */
static int
gss_check_seq_num(struct gss_svc_seq_data *sd, __u32 seq_num)
{
        int rc = 0;

        spin_lock(&sd->sd_lock);
        if (seq_num > sd->sd_max) {
                if (seq_num >= sd->sd_max + GSS_SEQ_WIN) {
                        /* jumped past the whole window: restart it */
                        memset(sd->sd_win, 0, sizeof(sd->sd_win));
                        sd->sd_max = seq_num;
                } else {
                        /* slide the window forward, clearing bits that
                         * now represent not-yet-seen sequence numbers */
                        while(sd->sd_max < seq_num) {
                                sd->sd_max++;
                                __clear_bit(sd->sd_max % GSS_SEQ_WIN,
                                            sd->sd_win);
                        }
                }
                __set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win);
                goto exit;
        } else if (seq_num + GSS_SEQ_WIN <= sd->sd_max) {
                /* below the window: too old to verify, reject */
                CERROR("seq %u too low: max %u, win %d\n",
                        seq_num, sd->sd_max, GSS_SEQ_WIN);
                rc = 1;
                goto exit;
        }

        /* inside the window: a set bit means we saw it already */
        if (__test_and_set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win)) {
                CERROR("seq %u is replay: max %u, win %d\n",
                        seq_num, sd->sd_max, GSS_SEQ_WIN);
                rc = 1;
        }
exit:
        spin_unlock(&sd->sd_lock);
        return rc;
}
705
706 static int
707 gss_svc_verify_request(struct ptlrpc_request *req,
708                        struct rsc *rsci,
709                        struct rpc_gss_wire_cred *gc,
710                        __u32 *vp, __u32 vlen)
711 {
712         struct ptlrpcs_wire_hdr *sec_hdr;
713         struct gss_ctx *ctx = rsci->mechctx;
714         __u32 maj_stat;
715         rawobj_t msg;
716         rawobj_t mic;
717         ENTRY;
718
719         sec_hdr = (struct ptlrpcs_wire_hdr *) req->rq_reqbuf;
720
721         req->rq_reqmsg = (struct lustre_msg *) (req->rq_reqbuf + sizeof(*sec_hdr));
722         req->rq_reqlen = sec_hdr->msg_len;
723
724         msg.len = sec_hdr->msg_len;
725         msg.data = (__u8 *)req->rq_reqmsg;
726
727         mic.len = le32_to_cpu(*vp++);
728         mic.data = (unsigned char *)vp;
729         vlen -= 4;
730
731         if (mic.len > vlen) {
732                 CERROR("checksum len %d, while buffer len %d\n",
733                         mic.len, vlen);
734                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
735         }
736
737         if (mic.len > 256) {
738                 CERROR("invalid mic len %d\n", mic.len);
739                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
740         }
741
742         maj_stat = kgss_verify_mic(ctx, &msg, &mic, NULL);
743         if (maj_stat != GSS_S_COMPLETE) {
744                 CERROR("MIC verification error: major %x\n", maj_stat);
745                 RETURN(maj_stat);
746         }
747
748         if (gss_check_seq_num(&rsci->seqdata, gc->gc_seq)) {
749                 CERROR("discard replayed request %p(o%u,x"LPU64",t"LPU64")\n",
750                         req, req->rq_reqmsg->opc, req->rq_xid,
751                         req->rq_reqmsg->transno);
752                 RETURN(GSS_S_DUPLICATE_TOKEN);
753         }
754
755         RETURN(GSS_S_COMPLETE);
756 }
757
/*
 * Decrypt (unwrap) a privacy-protected request body and check the gss
 * sequence number for replay.  @vp/@vlen describe a little-endian
 * 32-bit ciphertext length followed by the ciphertext.  On success
 * req->rq_reqmsg/rq_reqlen point at the recovered plaintext
 * (decrypted in place: plain_text aliases the cipher buffer).
 * Returns GSS_S_COMPLETE or a gss major status on error.
 */
static int
gss_svc_unseal_request(struct ptlrpc_request *req,
                       struct rsc *rsci,
                       struct rpc_gss_wire_cred *gc,
                       __u32 *vp, __u32 vlen)
{
        struct ptlrpcs_wire_hdr *sec_hdr;
        struct gss_ctx *ctx = rsci->mechctx;
        rawobj_t cipher_text, plain_text;
        __u32 major;
        ENTRY;

        sec_hdr = (struct ptlrpcs_wire_hdr *) req->rq_reqbuf;

        /* the length word itself must be present */
        if (vlen < 4) {
                CERROR("vlen only %u\n", vlen);
                RETURN(GSS_S_CALL_BAD_STRUCTURE);
        }

        cipher_text.len = le32_to_cpu(*vp++);
        cipher_text.data = (__u8 *) vp;
        vlen -= 4;
        
        if (cipher_text.len > vlen) {
                CERROR("cipher claimed %u while buf only %u\n",
                        cipher_text.len, vlen);
                RETURN(GSS_S_CALL_BAD_STRUCTURE);
        }

        /* unwrap in place: output shares the input buffer */
        plain_text = cipher_text;

        major = kgss_unwrap(ctx, GSS_C_QOP_DEFAULT, &cipher_text, &plain_text);
        if (major) {
                CERROR("unwrap error 0x%x\n", major);
                RETURN(major);
        }

        if (gss_check_seq_num(&rsci->seqdata, gc->gc_seq)) {
                CERROR("discard replayed request %p(o%u,x"LPU64",t"LPU64")\n",
                        req, req->rq_reqmsg->opc, req->rq_xid,
                        req->rq_reqmsg->transno);
                RETURN(GSS_S_DUPLICATE_TOKEN);
        }

        req->rq_reqmsg = (struct lustre_msg *) (vp);
        req->rq_reqlen = plain_text.len;

        CDEBUG(D_SEC, "msg len %d\n", req->rq_reqlen);

        RETURN(GSS_S_COMPLETE);
}
809
/*
 * Pack an error-notify reply carrying the given gss major/minor status
 * so the client learns its context is broken.  The reply is an empty
 * lustre msg followed by 7 32-bit words of gss error data.
 * Returns 0 on success or a negative errno from reply packing.
 */
static int
gss_pack_err_notify(struct ptlrpc_request *req,
                    __u32 major, __u32 minor)
{
        struct gss_svc_data *svcdata = req->rq_sec_svcdata;
        __u32 reslen, *resp, *reslenp;
        char  nidstr[PTL_NALFMT_SIZE];
        const __u32 secdata_len = 7 * 4;
        int rc;
        ENTRY;

        OBD_FAIL_RETURN(OBD_FAIL_SVCGSS_ERR_NOTIFY|OBD_FAIL_ONCE, -EINVAL);

        LASSERT(svcdata);
        svcdata->is_err_notify = 1;
        svcdata->reserve_len = 7 * 4;

        rc = lustre_pack_reply(req, 0, NULL, NULL);
        if (rc) {
                CERROR("could not pack reply, err %d\n", rc);
                RETURN(rc);
        }

        LASSERT(req->rq_reply_state);
        LASSERT(req->rq_reply_state->rs_repbuf);
        LASSERT(req->rq_reply_state->rs_repbuf_len >= secdata_len);
        resp = (__u32 *) req->rq_reply_state->rs_repbuf;

        /* header */
        *resp++ = cpu_to_le32(PTLRPC_SEC_GSS);
        *resp++ = cpu_to_le32(PTLRPC_SEC_TYPE_NONE);
        *resp++ = cpu_to_le32(req->rq_replen);
        reslenp = resp++;       /* sec data length filled in below */

        /* skip lustre msg */
        resp += req->rq_replen / 4;
        reslen = svcdata->reserve_len;

        /* gss replay:
         * version, subflavor, notify, major, minor,
         * obj1(fake), obj2(fake)
         */
        *resp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
        *resp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);
        *resp++ = cpu_to_le32(PTLRPC_GSS_PROC_ERR);
        *resp++ = cpu_to_le32(major);
        *resp++ = cpu_to_le32(minor);
        *resp++ = 0;
        *resp++ = 0;
        /* NOTE(review): 7 words were written above but only 4*4 is
         * subtracted, and reslen is not used after this point — looks
         * stale; verify before relying on it */
        reslen -= (4 * 4);
        /* the actual sec data length */
        *reslenp = cpu_to_le32(secdata_len);

        req->rq_reply_state->rs_repdata_len += (secdata_len);
        CDEBUG(D_SEC, "prepare gss error notify(0x%x/0x%x) to %s\n",
               major, minor,
               portals_nid2str(req->rq_peer.peer_ni->pni_number,
                               req->rq_peer.peer_id.nid, nidstr));
        RETURN(0);
}
870
/* Debug helper: dump refcount, flags, chain pointer and timestamps of
 * a cache_head. */
static void dump_cache_head(struct cache_head *h)
{
        CWARN("ref %d, fl %lx, n %p, t %ld, %ld\n",
              atomic_read(&h->refcnt), h->flags, h->next,
              h->expiry_time, h->last_refresh);
}
/* Debug helper: dump an rsi entry — cache head, peer addressing, and
 * the lengths/pointers of all four rawobj buffers. */
static void dump_rsi(struct rsi *rsi)
{
        CWARN("dump rsi %p\n", rsi);
        dump_cache_head(&rsi->h);
        CWARN("%x,%x,%llx\n", rsi->naltype, rsi->netid, rsi->nid);
        CWARN("len %d, d %p\n", rsi->in_handle.len, rsi->in_handle.data);
        CWARN("len %d, d %p\n", rsi->in_token.len, rsi->in_token.data);
        CWARN("len %d, d %p\n", rsi->out_handle.len, rsi->out_handle.data);
        CWARN("len %d, d %p\n", rsi->out_token.len, rsi->out_token.data);
}
887
888 static int
889 gss_svcsec_handle_init(struct ptlrpc_request *req,
890                        struct rpc_gss_wire_cred *gc,
891                        __u32 *secdata, __u32 seclen,
892                        enum ptlrpcs_error *res)
893 {
894         struct gss_svc_data *svcdata = req->rq_sec_svcdata;
895         struct rsc          *rsci;
896         struct rsi          *rsikey, *rsip;
897         rawobj_t             tmpobj;
898         __u32 reslen,       *resp, *reslenp;
899         char                 nidstr[PTL_NALFMT_SIZE];
900         int                  rc;
901         ENTRY;
902
903         LASSERT(svcdata);
904
905         CDEBUG(D_SEC, "processing gss init(%d) request from %s\n", gc->gc_proc,
906                portals_nid2str(req->rq_peer.peer_ni->pni_number,
907                                req->rq_peer.peer_id.nid, nidstr));
908
909         *res = PTLRPCS_BADCRED;
910         OBD_FAIL_RETURN(OBD_FAIL_SVCGSS_INIT_REQ|OBD_FAIL_ONCE, SVC_DROP);
911
912         if (gc->gc_proc == RPC_GSS_PROC_INIT &&
913             gc->gc_ctx.len != 0) {
914                 CERROR("proc %d, ctx_len %d: not really init?\n",
915                 gc->gc_proc == RPC_GSS_PROC_INIT, gc->gc_ctx.len);
916                 RETURN(SVC_DROP);
917         }
918
919         OBD_ALLOC(rsikey, sizeof(*rsikey));
920         if (!rsikey) {
921                 CERROR("out of memory\n");
922                 RETURN(SVC_DROP);
923         }
924         cache_init(&rsikey->h);
925
926         if (rawobj_dup(&rsikey->in_handle, &gc->gc_ctx)) {
927                 CERROR("fail to dup context handle\n");
928                 GOTO(out_rsikey, rc = SVC_DROP);
929         }
930         *res = PTLRPCS_BADVERF;
931         if (rawobj_extract(&tmpobj, &secdata, &seclen)) {
932                 CERROR("can't extract token\n");
933                 GOTO(out_rsikey, rc = SVC_DROP);
934         }
935         if (rawobj_dup(&rsikey->in_token, &tmpobj)) {
936                 CERROR("can't duplicate token\n");
937                 GOTO(out_rsikey, rc = SVC_DROP);
938         }
939
940         rsikey->naltype = (__u32) req->rq_peer.peer_ni->pni_number;
941         rsikey->netid = 0;
942         rsikey->nid = (__u64) req->rq_peer.peer_id.nid;
943
944         rsip = gssd_upcall(rsikey, &my_chandle);
945         if (!rsip) {
946                 CERROR("error in gssd_upcall.\n");
947
948                 rc = SVC_COMPLETE;
949                 if (gss_pack_err_notify(req, GSS_S_FAILURE, 0))
950                         rc = SVC_DROP;
951
952                 GOTO(out_rsikey, rc);
953         }
954
955         rsci = gss_svc_searchbyctx(&rsip->out_handle);
956         if (!rsci) {
957                 CERROR("rsci still not mature yet?\n");
958
959                 rc = SVC_COMPLETE;
960                 if (gss_pack_err_notify(req, GSS_S_FAILURE, 0))
961                         rc = SVC_DROP;
962
963                 GOTO(out_rsip, rc);
964         }
965         CDEBUG(D_SEC, "svcsec create gss context %p(%u@%s)\n",
966                rsci, rsci->cred.vc_uid,
967                portals_nid2str(req->rq_peer.peer_ni->pni_number,
968                                req->rq_peer.peer_id.nid, nidstr));
969
970         svcdata->is_init = 1;
971         svcdata->reserve_len = 6 * 4 +
972                 size_round4(rsip->out_handle.len) +
973                 size_round4(rsip->out_token.len);
974
975         rc = lustre_pack_reply(req, 0, NULL, NULL);
976         if (rc) {
977                 CERROR("failed to pack reply, rc = %d\n", rc);
978                 set_bit(CACHE_NEGATIVE, &rsci->h.flags);
979                 GOTO(out, rc = SVC_DROP);
980         }
981
982         /* header */
983         resp = (__u32 *) req->rq_reply_state->rs_repbuf;
984         *resp++ = cpu_to_le32(PTLRPC_SEC_GSS);
985         *resp++ = cpu_to_le32(PTLRPC_SEC_TYPE_NONE);
986         *resp++ = cpu_to_le32(req->rq_replen);
987         reslenp = resp++;
988
989         resp += req->rq_replen / 4;
990         reslen = svcdata->reserve_len;
991
992         /* gss reply:
993          * status, major, minor, seq, out_handle, out_token
994          */
995         *resp++ = cpu_to_le32(PTLRPCS_OK);
996         *resp++ = cpu_to_le32(rsip->major_status);
997         *resp++ = cpu_to_le32(rsip->minor_status);
998         *resp++ = cpu_to_le32(GSS_SEQ_WIN);
999         reslen -= (4 * 4);
1000         if (rawobj_serialize(&rsip->out_handle,
1001                              &resp, &reslen)) {
1002                 dump_rsi(rsip);
1003                 dump_rsi(rsikey);
1004                 LBUG();
1005         }
1006         if (rawobj_serialize(&rsip->out_token,
1007                              &resp, &reslen)) {
1008                 dump_rsi(rsip);
1009                 dump_rsi(rsikey);
1010                 LBUG();
1011         }
1012         /* the actual sec data length */
1013         *reslenp = cpu_to_le32(svcdata->reserve_len - reslen);
1014
1015         req->rq_reply_state->rs_repdata_len += le32_to_cpu(*reslenp);
1016         CDEBUG(D_SEC, "req %p: msgsize %d, authsize %d, "
1017                "total size %d\n", req, req->rq_replen,
1018                le32_to_cpu(*reslenp),
1019                req->rq_reply_state->rs_repdata_len);
1020
1021         *res = PTLRPCS_OK;
1022
1023         req->rq_auth_uid = rsci->cred.vc_uid;
1024         req->rq_remote_realm = rsci->remote_realm;
1025         req->rq_mapped_uid = rsci->mapped_uid;
1026
1027         /* This is simplified since right now we doesn't support
1028          * INIT_CONTINUE yet.
1029          */
1030         if (gc->gc_proc == RPC_GSS_PROC_INIT) {
1031                 struct ptlrpcs_wire_hdr *hdr;
1032
1033                 hdr = buf_to_sec_hdr(req->rq_reqbuf);
1034                 req->rq_reqmsg = buf_to_lustre_msg(req->rq_reqbuf);
1035                 req->rq_reqlen = hdr->msg_len;
1036
1037                 rc = SVC_LOGIN;
1038         } else
1039                 rc = SVC_COMPLETE;
1040
1041 out:
1042         rsc_put(&rsci->h, &rsc_cache);
1043 out_rsip:
1044         rsi_put(&rsip->h, &rsi_cache);
1045 out_rsikey:
1046         rsi_put(&rsikey->h, &rsi_cache);
1047
1048         RETURN(rc);
1049 }
1050
/*
 * Handle a normal RPC_GSS_PROC_DATA request: look up the established
 * context by the wire handle, then verify the MIC (integrity service)
 * or unseal the body (privacy service).
 *
 * On gss failure an error notify reply is prepared for the client
 * (SVC_COMPLETE), falling back to SVC_DROP if even that fails.
 * Returns SVC_OK when the request authenticated cleanly.
 */
static int
gss_svcsec_handle_data(struct ptlrpc_request *req,
                       struct rpc_gss_wire_cred *gc,
                       __u32 *secdata, __u32 seclen,
                       enum ptlrpcs_error *res)
{
        struct rsc          *rsci;
        char                 nidstr[PTL_NALFMT_SIZE];
        __u32                major;
        int                  rc;
        ENTRY;

        *res = PTLRPCS_GSS_CREDPROBLEM;

        /* look up the established server-side context */
        rsci = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rsci) {
                CWARN("Invalid gss context handle from %s\n",
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                major = GSS_S_NO_CONTEXT;
                goto notify_err;
        }

        switch (gc->gc_svc) {
        case PTLRPC_GSS_SVC_INTEGRITY:
                /* check the MIC covering the request message */
                major = gss_svc_verify_request(req, rsci, gc, secdata, seclen);
                if (major == GSS_S_COMPLETE)
                        break;

                CWARN("fail in verify:0x%x: ctx %p@%s\n", major, rsci,
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                goto notify_err;
        case PTLRPC_GSS_SVC_PRIVACY:
                /* decrypt the sealed request message */
                major = gss_svc_unseal_request(req, rsci, gc, secdata, seclen);
                if (major == GSS_S_COMPLETE)
                        break;

                CWARN("fail in decrypt:0x%x: ctx %p@%s\n", major, rsci,
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                goto notify_err;
        default:
                CERROR("unsupported gss service %d\n", gc->gc_svc);
                GOTO(out, rc = SVC_DROP);
        }

        /* authenticated: expose the caller's identity to upper layers */
        req->rq_auth_uid = rsci->cred.vc_uid;
        req->rq_remote_realm = rsci->remote_realm;
        req->rq_mapped_uid = rsci->mapped_uid;

        *res = PTLRPCS_OK;
        GOTO(out, rc = SVC_OK);

notify_err:
        /* prepare a gss error reply; drop if even that fails */
        if (gss_pack_err_notify(req, major, 0))
                rc = SVC_DROP;
        else
                rc = SVC_COMPLETE;
out:
        if (rsci)
                rsc_put(&rsci->h, &rsc_cache);
        RETURN(rc);
}
1115
/*
 * Handle RPC_GSS_PROC_DESTROY: verify the integrity-protected destroy
 * request, mark the cached context negative so later lookups fail, and
 * prepare an empty reply.
 *
 * Returns SVC_LOGOUT on success, SVC_DROP on any failure.
 */
static int
gss_svcsec_handle_destroy(struct ptlrpc_request *req,
                          struct rpc_gss_wire_cred *gc,
                          __u32 *secdata, __u32 seclen,
                          enum ptlrpcs_error *res)
{
        struct gss_svc_data *svcdata = req->rq_sec_svcdata;
        struct rsc          *rsci;
        char                 nidstr[PTL_NALFMT_SIZE];
        int                  rc;
        ENTRY;

        LASSERT(svcdata);
        *res = PTLRPCS_GSS_CREDPROBLEM;

        rsci = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rsci) {
                CWARN("invalid gss context handle for destroy.\n");
                RETURN(SVC_DROP);
        }

        /* destroy requests must be integrity-protected */
        if (gc->gc_svc != PTLRPC_GSS_SVC_INTEGRITY) {
                CERROR("service %d is not supported in destroy.\n",
                        gc->gc_svc);
                GOTO(out, rc = SVC_DROP);
        }

        *res = gss_svc_verify_request(req, rsci, gc, secdata, seclen);
        if (*res)
                GOTO(out, rc = SVC_DROP);

        /* compose reply, which is actually nothing */
        svcdata->is_fini = 1;
        if (lustre_pack_reply(req, 0, NULL, NULL))
                GOTO(out, rc = SVC_DROP);

        CDEBUG(D_SEC, "svcsec destroy gss context %p(%u@%s)\n",
               rsci, rsci->cred.vc_uid,
               portals_nid2str(req->rq_peer.peer_ni->pni_number,
                               req->rq_peer.peer_id.nid, nidstr));

        /* invalidate the cache entry */
        set_bit(CACHE_NEGATIVE, &rsci->h.flags);
        *res = PTLRPCS_OK;
        rc = SVC_LOGOUT;
out:
        rsc_put(&rsci->h, &rsc_cache);
        RETURN(rc);
}
1164
1165 /*
1166  * let incomming request go through security check:
1167  *  o context establishment: invoke user space helper
1168  *  o data exchange: verify/decrypt
1169  *  o context destruction: mark context invalid
1170  *
1171  * in most cases, error will result to drop the packet silently.
1172  */
static int
gss_svcsec_accept(struct ptlrpc_request *req, enum ptlrpcs_error *res)
{
        struct gss_svc_data *svcdata;
        struct rpc_gss_wire_cred *gc;
        struct ptlrpcs_wire_hdr *sec_hdr;
        __u32 seclen, *secdata, version, subflavor;
        int rc;
        ENTRY;

        CDEBUG(D_SEC, "request %p\n", req);
        LASSERT(req->rq_reqbuf);
        LASSERT(req->rq_reqbuf_len);

        *res = PTLRPCS_BADCRED;

        sec_hdr = buf_to_sec_hdr(req->rq_reqbuf);
        LASSERT(sec_hdr->flavor == PTLRPC_SEC_GSS);

        /* space available for the security payload after the wire header
         * and the lustre message */
        seclen = req->rq_reqbuf_len - sizeof(*sec_hdr) - sec_hdr->msg_len;
        secdata = (__u32 *) buf_to_sec_data(req->rq_reqbuf);

        if (sec_hdr->sec_len > seclen) {
                CERROR("seclen %d, while max buf %d\n",
                        sec_hdr->sec_len, seclen);
                RETURN(SVC_DROP);
        }

        /* minimum: 5-word gss header plus one rawobj length word */
        if (seclen < 6 * 4) {
                CERROR("sec size %d too small\n", seclen);
                RETURN(SVC_DROP);
        }

        /* per-request gss state, freed in gss_svcsec_cleanup_req() or
         * below on SVC_DROP */
        LASSERT(!req->rq_sec_svcdata);
        OBD_ALLOC(svcdata, sizeof(*svcdata));
        if (!svcdata) {
                CERROR("fail to alloc svcdata\n");
                RETURN(SVC_DROP);
        }
        req->rq_sec_svcdata = svcdata;
        gc = &svcdata->clcred;

        /* Now secdata/seclen is what we want to parse
         */
        version = le32_to_cpu(*secdata++);      /* version */
        subflavor = le32_to_cpu(*secdata++);    /* subflavor */
        gc->gc_proc = le32_to_cpu(*secdata++);  /* proc */
        gc->gc_seq = le32_to_cpu(*secdata++);   /* seq */
        gc->gc_svc = le32_to_cpu(*secdata++);   /* service */
        seclen -= 5 * 4;

        CDEBUG(D_SEC, "wire gss_hdr: %u/%u/%u/%u/%u\n",
               version, subflavor, gc->gc_proc, gc->gc_seq, gc->gc_svc);

        if (version != PTLRPC_SEC_GSS_VERSION) {
                CERROR("gss version mismatch: %d - %d\n",
                        version, PTLRPC_SEC_GSS_VERSION);
                GOTO(err_free, rc = SVC_DROP);
        }

        /* gc_ctx points into the request buffer; no copy is made */
        if (rawobj_extract(&gc->gc_ctx, &secdata, &seclen)) {
                CERROR("fail to obtain gss context handle\n");
                GOTO(err_free, rc = SVC_DROP);
        }

        /* dispatch on the gss procedure */
        *res = PTLRPCS_BADVERF;
        switch(gc->gc_proc) {
        case RPC_GSS_PROC_INIT:
        case RPC_GSS_PROC_CONTINUE_INIT:
                rc = gss_svcsec_handle_init(req, gc, secdata, seclen, res);
                break;
        case RPC_GSS_PROC_DATA:
                rc = gss_svcsec_handle_data(req, gc, secdata, seclen, res);
                break;
        case RPC_GSS_PROC_DESTROY:
                rc = gss_svcsec_handle_destroy(req, gc, secdata, seclen, res);
                break;
        default:
                rc = SVC_DROP;
                LBUG();
        }

err_free:
        /* a dropped request never reaches cleanup_req, so release the
         * per-request state here */
        if (rc == SVC_DROP && req->rq_sec_svcdata) {
                OBD_FREE(req->rq_sec_svcdata, sizeof(struct gss_svc_data));
                req->rq_sec_svcdata = NULL;
        }

        RETURN(rc);
}
1263
/*
 * Protect an outgoing reply according to the service the client used:
 * append a MIC (integrity) or encrypt the whole message (privacy).
 * Init/fini/error replies were already fully packed by the accept path
 * and pass through untouched.
 */
static int
gss_svcsec_authorize(struct ptlrpc_request *req)
{
        struct ptlrpc_reply_state *rs = req->rq_reply_state;
        struct gss_svc_data *gsd = (struct gss_svc_data *)req->rq_sec_svcdata;
        struct rpc_gss_wire_cred  *gc = &gsd->clcred;
        struct rsc                *rscp;
        struct ptlrpcs_wire_hdr   *sec_hdr;
        rawobj_buf_t               msg_buf;
        rawobj_t                   cipher_buf;
        __u32                     *vp, *vpsave, major, vlen, seclen;
        rawobj_t                   lmsg, mic;
        int                        ret;
        ENTRY;

        LASSERT(rs);
        LASSERT(rs->rs_repbuf);
        LASSERT(gsd);

        if (gsd->is_init || gsd->is_init_continue ||
            gsd->is_err_notify || gsd->is_fini) {
                /* nothing to do in these cases */
                CDEBUG(D_SEC, "req %p: init/fini/err\n", req);
                RETURN(0);
        }

        if (gc->gc_proc != RPC_GSS_PROC_DATA) {
                CERROR("proc %d not support\n", gc->gc_proc);
                RETURN(-EINVAL);
        }

        rscp = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rscp) {
                CERROR("ctx disapeared under us?\n");
                RETURN(-EINVAL);
        }

        sec_hdr = (struct ptlrpcs_wire_hdr *) rs->rs_repbuf;
        switch (gc->gc_svc) {
        case  PTLRPC_GSS_SVC_INTEGRITY:
                /* prepare various pointers */
                lmsg.len = req->rq_replen;
                lmsg.data = (__u8 *) (rs->rs_repbuf + sizeof(*sec_hdr));
                vp = (__u32 *) (lmsg.data + lmsg.len);
                vlen = rs->rs_repbuf_len - sizeof(*sec_hdr) - lmsg.len;
                seclen = vlen;

                sec_hdr->flavor = cpu_to_le32(PTLRPC_SEC_GSS);
                sec_hdr->sectype = cpu_to_le32(PTLRPC_SEC_TYPE_AUTH);
                sec_hdr->msg_len = cpu_to_le32(req->rq_replen);

                /* standard gss hdr */
                LASSERT(vlen >= 7 * 4);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);
                *vp++ = cpu_to_le32(RPC_GSS_PROC_DATA);
                *vp++ = cpu_to_le32(gc->gc_seq);
                *vp++ = cpu_to_le32(PTLRPC_GSS_SVC_INTEGRITY);
                *vp++ = 0;      /* fake ctx handle */
                vpsave = vp++;  /* reserve size */
                vlen -= 7 * 4;

                /* compute the MIC over the lustre message, writing it
                 * right after the gss header */
                mic.len = vlen;
                mic.data = (unsigned char *)vp;

                major = kgss_get_mic(rscp->mechctx, 0, &lmsg, &mic);
                if (major) {
                        CERROR("fail to get MIC: 0x%x\n", major);
                        GOTO(out, ret = -EINVAL);
                }
                *vpsave = cpu_to_le32(mic.len);
                seclen = seclen - vlen + mic.len;
                sec_hdr->sec_len = cpu_to_le32(seclen);
                rs->rs_repdata_len += size_round(seclen);
                break;
        case  PTLRPC_GSS_SVC_PRIVACY:
                vp = (__u32 *) (rs->rs_repbuf + sizeof(*sec_hdr));
                vlen = rs->rs_repbuf_len - sizeof(*sec_hdr);
                seclen = vlen;

                sec_hdr->flavor = cpu_to_le32(PTLRPC_SEC_GSS);
                sec_hdr->sectype = cpu_to_le32(PTLRPC_SEC_TYPE_PRIV);
                sec_hdr->msg_len = cpu_to_le32(0);

                /* standard gss hdr */
                /* NOTE(review): subflavor is hardcoded to KRB5I here even
                 * on the privacy path — confirm this is intended */
                LASSERT(vlen >= 7 * 4);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);
                *vp++ = cpu_to_le32(RPC_GSS_PROC_DATA);
                *vp++ = cpu_to_le32(gc->gc_seq);
                *vp++ = cpu_to_le32(PTLRPC_GSS_SVC_PRIVACY);
                *vp++ = 0;      /* fake ctx handle */
                vpsave = vp++;  /* reserve size */
                vlen -= 7 * 4;

                /* rs_msg lives in a separate buffer with prefix/suffix
                 * slack (see gss_svcsec_alloc_repbuf) */
                msg_buf.buf = (__u8 *) rs->rs_msg - GSS_PRIVBUF_PREFIX_LEN;
                msg_buf.buflen = req->rq_replen + GSS_PRIVBUF_PREFIX_LEN +
                                 GSS_PRIVBUF_SUFFIX_LEN;
                msg_buf.dataoff = GSS_PRIVBUF_PREFIX_LEN;
                msg_buf.datalen = req->rq_replen;

                cipher_buf.data = (__u8 *) vp;
                cipher_buf.len = vlen;

                major = kgss_wrap(rscp->mechctx, GSS_C_QOP_DEFAULT,
                                &msg_buf, &cipher_buf);
                if (major) {
                        CERROR("failed to wrap: 0x%x\n", major);
                        GOTO(out, ret = -EINVAL);
                }

                *vpsave = cpu_to_le32(cipher_buf.len);
                seclen = seclen - vlen + cipher_buf.len;
                sec_hdr->sec_len = cpu_to_le32(seclen);
                rs->rs_repdata_len += size_round(seclen);
                break;
        default:
                CERROR("Unknown service %d\n", gc->gc_svc);
                GOTO(out, ret = -EINVAL);
        }
        ret = 0;
out:
        rsc_put(&rscp->h, &rsc_cache);

        RETURN(ret);
}
1390
1391 static
1392 void gss_svcsec_cleanup_req(struct ptlrpc_svcsec *svcsec,
1393                             struct ptlrpc_request *req)
1394 {
1395         struct gss_svc_data *gsd = (struct gss_svc_data *) req->rq_sec_svcdata;
1396
1397         if (!gsd) {
1398                 CDEBUG(D_SEC, "no svc_data present. do nothing\n");
1399                 return;
1400         }
1401
1402         /* gsd->clclred.gc_ctx is NOT allocated, just set pointer
1403          * to the incoming packet buffer, so don't need free it
1404          */
1405         OBD_FREE(gsd, sizeof(*gsd));
1406         req->rq_sec_svcdata = NULL;
1407         return;
1408 }
1409
1410 static
1411 int gss_svcsec_est_payload(struct ptlrpc_svcsec *svcsec,
1412                            struct ptlrpc_request *req,
1413                            int msgsize)
1414 {
1415         struct gss_svc_data *svcdata = req->rq_sec_svcdata;
1416         ENTRY;
1417
1418         /* just return the pre-set reserve_len for init/fini/err cases.
1419          */
1420         LASSERT(svcdata);
1421         if (svcdata->is_init) {
1422                 CDEBUG(D_SEC, "is_init, reserver size %d(%d)\n",
1423                        size_round(svcdata->reserve_len),
1424                        svcdata->reserve_len);
1425                 LASSERT(svcdata->reserve_len);
1426                 LASSERT(svcdata->reserve_len % 4 == 0);
1427                 RETURN(size_round(svcdata->reserve_len));
1428         } else if (svcdata->is_err_notify) {
1429                 CDEBUG(D_SEC, "is_err_notify, reserver size %d(%d)\n",
1430                        size_round(svcdata->reserve_len),
1431                        svcdata->reserve_len);
1432                 RETURN(size_round(svcdata->reserve_len));
1433         } else if (svcdata->is_fini) {
1434                 CDEBUG(D_SEC, "is_fini, reserver size 0\n");
1435                 RETURN(0);
1436         } else {
1437                 if (svcdata->clcred.gc_svc == PTLRPC_GSS_SVC_NONE ||
1438                     svcdata->clcred.gc_svc == PTLRPC_GSS_SVC_INTEGRITY)
1439                         RETURN(size_round(GSS_MAX_AUTH_PAYLOAD));
1440                 else if (svcdata->clcred.gc_svc == PTLRPC_GSS_SVC_PRIVACY)
1441                         RETURN(size_round16(GSS_MAX_AUTH_PAYLOAD + msgsize +
1442                                             GSS_PRIVBUF_PREFIX_LEN +
1443                                             GSS_PRIVBUF_SUFFIX_LEN));
1444                 else {
1445                         CERROR("unknown gss svc %u\n", svcdata->clcred.gc_svc);
1446                         *((int *)0) = 0;
1447                         LBUG();
1448                 }
1449         }
1450         RETURN(0);
1451 }
1452
/*
 * Allocate the reply state for a request, sized by the estimated
 * security payload.  For privacy replies the plaintext message lives
 * in a separate buffer (with cipher prefix/suffix slack) so that
 * encryption never overlaps the reply buffer; it is released by
 * gss_svcsec_free_repbuf().
 */
int gss_svcsec_alloc_repbuf(struct ptlrpc_svcsec *svcsec,
                            struct ptlrpc_request *req,
                            int msgsize)
{
        struct gss_svc_data *gsd = (struct gss_svc_data *) req->rq_sec_svcdata;
        struct ptlrpc_reply_state *rs;
        int msg_payload, sec_payload;
        int privacy, rc;
        ENTRY;

        /* determine the security type: none/auth or priv, we have
         * different pack scheme for them.
         * init/fini/err will always be treated as none/auth.
         */
        LASSERT(gsd);
        if (!gsd->is_init && !gsd->is_init_continue &&
            !gsd->is_fini && !gsd->is_err_notify &&
            gsd->clcred.gc_svc == PTLRPC_GSS_SVC_PRIVACY)
                privacy = 1;
        else
                privacy = 0;

        /* for privacy the message is carried inside the sec payload,
         * so no separate message space in the reply buffer */
        msg_payload = privacy ? 0 : msgsize;
        sec_payload = gss_svcsec_est_payload(svcsec, req, msgsize);

        rc = svcsec_alloc_reply_state(req, msg_payload, sec_payload);
        if (rc)
                RETURN(rc);

        rs = req->rq_reply_state;
        LASSERT(rs);
        rs->rs_msg_len = msgsize;

        if (privacy) {
                /* we can choose to let msg simply point to the rear of the
                 * buffer, which lead to buffer overlap when doing encryption.
                 * usually it's ok and it indeed passed all existing tests.
                 * but not sure if there will be subtle problems in the future.
                 * so right now we choose to alloc another new buffer. we'll
                 * see how it works.
                 */
#if 0
                rs->rs_msg = (struct lustre_msg *)
                             (rs->rs_repbuf + rs->rs_repbuf_len -
                              msgsize - GSS_PRIVBUF_SUFFIX_LEN);
#endif
                char *msgbuf;

                msgsize += GSS_PRIVBUF_PREFIX_LEN + GSS_PRIVBUF_SUFFIX_LEN;
                OBD_ALLOC(msgbuf, msgsize);
                if (!msgbuf) {
                        CERROR("can't alloc %d\n", msgsize);
                        svcsec_free_reply_state(rs);
                        req->rq_reply_state = NULL;
                        RETURN(-ENOMEM);
                }
                /* rs_msg points past the prefix slack inside msgbuf */
                rs->rs_msg = (struct lustre_msg *)
                                (msgbuf + GSS_PRIVBUF_PREFIX_LEN);
        }

        req->rq_repmsg = rs->rs_msg;

        RETURN(0);
}
1517
1518 static
1519 void gss_svcsec_free_repbuf(struct ptlrpc_svcsec *svcsec,
1520                             struct ptlrpc_reply_state *rs)
1521 {
1522         unsigned long p1 = (unsigned long) rs->rs_msg;
1523         unsigned long p2 = (unsigned long) rs->rs_buf;
1524
1525         LASSERT(rs->rs_buf);
1526         LASSERT(rs->rs_msg);
1527         LASSERT(rs->rs_msg_len);
1528
1529         if (p1 < p2 || p1 >= p2 + rs->rs_buf_len) {
1530                 char *start = (char*) rs->rs_msg - GSS_PRIVBUF_PREFIX_LEN;
1531                 int size = rs->rs_msg_len + GSS_PRIVBUF_PREFIX_LEN +
1532                            GSS_PRIVBUF_SUFFIX_LEN;
1533                 OBD_FREE(start, size);
1534         }
1535
1536         svcsec_free_reply_state(rs);
1537 }
1538
/* GSS server security policy descriptor, registered with the generic
 * svcsec layer by gss_svc_init(). */
struct ptlrpc_svcsec svcsec_gss = {
        .pss_owner              = THIS_MODULE,
        .pss_name               = "GSS_SVCSEC",
        .pss_flavor             = {PTLRPC_SEC_GSS, 0},
        .accept                 = gss_svcsec_accept,
        .authorize              = gss_svcsec_authorize,
        .alloc_repbuf           = gss_svcsec_alloc_repbuf,
        .free_repbuf            = gss_svcsec_free_repbuf,
        .cleanup_req            = gss_svcsec_cleanup_req,
};
1549
/* XXX hacking: forcibly purge both the init (rsi) and context (rsc)
 * caches */
void lgss_svc_cache_purge_all(void)
{
        cache_purge(&rsi_cache);
        cache_purge(&rsc_cache);
}
EXPORT_SYMBOL(lgss_svc_cache_purge_all);
1557
/* Flush cached server-side contexts for the given uid; semantics are
 * those of rsc_flush() (defined elsewhere in this file). */
void lgss_svc_cache_flush(__u32 uid)
{
        rsc_flush(uid);
}
EXPORT_SYMBOL(lgss_svc_cache_flush);
1563
1564 int gss_svc_init(void)
1565 {
1566         int rc;
1567
1568         rc = svcsec_register(&svcsec_gss);
1569         if (!rc) {
1570                 cache_register(&rsc_cache);
1571                 cache_register(&rsi_cache);
1572         }
1573         return rc;
1574 }
1575
1576 void gss_svc_exit(void)
1577 {
1578         int rc;
1579         if ((rc = cache_unregister(&rsi_cache)))
1580                 CERROR("unregister rsi cache: %d\n", rc);
1581         if ((rc = cache_unregister(&rsc_cache)))
1582                 CERROR("unregister rsc cache: %d\n", rc);
1583         if ((rc = svcsec_unregister(&svcsec_gss)))
1584                 CERROR("unregister svcsec_gss: %d\n", rc);
1585 }