/*
 * Retrieved from Whamcloud gitweb
 * commit c5e2ddb6c37bffe28dd649a1ee4ebb4f73f3d536
 * fs/lustre-release.git — lustre/sec/gss/svcsec_gss.c
 */
1 /* -*- mode: c; c-basic-offset: 8; indent-tabs-mode: nil; -*-
2  * vim:expandtab:shiftwidth=8:tabstop=8:
3  *
4  * Modifications for Lustre
5  * Copyright 2004, Cluster File Systems, Inc.
6  * All rights reserved
7  * Author: Eric Mei <ericm@clusterfs.com>
8  */
9
10 /*
11  * Neil Brown <neilb@cse.unsw.edu.au>
12  * J. Bruce Fields <bfields@umich.edu>
13  * Andy Adamson <andros@umich.edu>
14  * Dug Song <dugsong@monkey.org>
15  *
16  * RPCSEC_GSS server authentication.
17  * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078
18  * (gssapi)
19  *
20  * The RPCSEC_GSS involves three stages:
21  *  1/ context creation
22  *  2/ data exchange
23  *  3/ context destruction
24  *
25  * Context creation is handled largely by upcalls to user-space.
26  *  In particular, GSS_Accept_sec_context is handled by an upcall
27  * Data exchange is handled entirely within the kernel
28  *  In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel.
29  * Context destruction is handled in-kernel
30  *  GSS_Delete_sec_context is in-kernel
31  *
32  * Context creation is initiated by a RPCSEC_GSS_INIT request arriving.
33  * The context handle and gss_token are used as a key into the rpcsec_init cache.
34  * The content of this cache includes some of the outputs of GSS_Accept_sec_context,
35  * being major_status, minor_status, context_handle, reply_token.
36  * These are sent back to the client.
 * Sequence window management is handled by the kernel.  The window size is
 * currently a compile time constant.
39  *
40  * When user-space is happy that a context is established, it places an entry
41  * in the rpcsec_context cache. The key for this cache is the context_handle.
42  * The content includes:
43  *   uid/gidlist - for determining access rights
44  *   mechanism type
45  *   mechanism specific information, such as a key
46  *
47  */
48
49 #define DEBUG_SUBSYSTEM S_SEC
50 #ifdef __KERNEL__
51 #include <linux/types.h>
52 #include <linux/init.h>
53 #include <linux/module.h>
54 #include <linux/slab.h>
55 #include <linux/hash.h>
56 #else
57 #include <liblustre.h>
58 #endif
59
60 #include <linux/sunrpc/cache.h>
61
62 #include <libcfs/kp30.h>
63 #include <linux/obd.h>
64 #include <linux/obd_class.h>
65 #include <linux/obd_support.h>
66 #include <linux/lustre_idl.h>
67 #include <linux/lustre_net.h>
68 #include <linux/lustre_import.h>
69 #include <linux/lustre_sec.h>
70                                                                                                                         
71 #include "gss_err.h"
72 #include "gss_internal.h"
73 #include "gss_api.h"
74
75 static inline unsigned long hash_mem(char *buf, int length, int bits)
76 {
77         unsigned long hash = 0;
78         unsigned long l = 0;
79         int len = 0;
80         unsigned char c;
81         do {
82                 if (len == length) {
83                         c = (char)len; len = -1;
84                 } else
85                         c = *buf++;
86                 l = (l << 8) | c;
87                 len++;
88                 if ((len & (BITS_PER_LONG/8-1))==0)
89                         hash = hash_long(hash^l, BITS_PER_LONG);
90         } while (len);
91         return hash >> (BITS_PER_LONG - bits);
92 }
93
94 /* The rpcsec_init cache is used for mapping RPCSEC_GSS_{,CONT_}INIT requests
95  * into replies.
96  *
97  * Key is context handle (\x if empty) and gss_token.
98  * Content is major_status minor_status (integers) context_handle, reply_token.
99  *
100  */
101
102 #define RSI_HASHBITS    6
103 #define RSI_HASHMAX     (1<<RSI_HASHBITS)
104 #define RSI_HASHMASK    (RSI_HASHMAX-1)
105
/*
 * One pending RPCSEC_GSS_{,CONT_}INIT exchange, keyed by (in_handle,
 * in_token); the remaining fields are filled in by the svcgssd userspace
 * daemon with the results of GSS_Accept_sec_context (see rsi_parse()).
 */
struct rsi {
        struct cache_head       h;      /* generic cache entry (refcount, flags, expiry) */
        __u32                   naltype;        /* client NAL type, sent up in rsi_request() */
        __u32                   netid;          /* client net id, sent up in rsi_request() */
        __u64                   nid;            /* client nid, sent up in rsi_request() */
        rawobj_t                in_handle, in_token;    /* request: handle + gss token (cache key) */
        rawobj_t                out_handle, out_token;  /* reply from userspace */
        int                     major_status, minor_status;     /* GSS status codes from userspace */
};

/* hash table + cache descriptor for pending init requests */
static struct cache_head *rsi_table[RSI_HASHMAX];
static struct cache_detail rsi_cache;
118
119 static void rsi_free(struct rsi *rsii)
120 {
121         rawobj_free(&rsii->in_handle);
122         rawobj_free(&rsii->in_token);
123         rawobj_free(&rsii->out_handle);
124         rawobj_free(&rsii->out_token);
125 }
126
127 static void rsi_put(struct cache_head *item, struct cache_detail *cd)
128 {
129         struct rsi *rsii = container_of(item, struct rsi, h);
130         LASSERT(atomic_read(&item->refcnt) > 0);
131         if (cache_put(item, cd)) {
132                 rsi_free(rsii);
133                 OBD_FREE(rsii, sizeof(*rsii));
134         }
135 }
136
137 static inline int rsi_hash(struct rsi *item)
138 {
139         return hash_mem((char *)item->in_handle.data, item->in_handle.len, RSI_HASHBITS)
140                 ^ hash_mem((char *)item->in_token.data, item->in_token.len, RSI_HASHBITS);
141 }
142
143 static inline int rsi_match(struct rsi *item, struct rsi *tmp)
144 {
145         return (rawobj_equal(&item->in_handle, &tmp->in_handle) &&
146                 rawobj_equal(&item->in_token, &tmp->in_token));
147 }
148
/*
 * Render the upcall request line for svcgssd: hex-encoded nal type, net
 * id and nid of the client, followed by the context handle and the gss
 * token, newline-terminated.  Field order is the wire format the daemon
 * parses — do not reorder.
 */
static void rsi_request(struct cache_detail *cd,
                        struct cache_head *h,
                        char **bpp, int *blen)
{
        struct rsi *rsii = container_of(h, struct rsi, h);

        qword_addhex(bpp, blen, (char *)&rsii->naltype, sizeof(rsii->naltype));
        qword_addhex(bpp, blen, (char *)&rsii->netid, sizeof(rsii->netid));
        qword_addhex(bpp, blen, (char *)&rsii->nid, sizeof(rsii->nid));
        qword_addhex(bpp, blen, (char *)rsii->in_handle.data, rsii->in_handle.len);
        qword_addhex(bpp, blen, (char *)rsii->in_token.data, rsii->in_token.len);
        /* replace the separator left after the last field (presumably a
         * space, per the sunrpc qword convention — confirm) with the
         * terminating newline */
        (*bpp)[-1] = '\n';
}
162
/*
 * Called when svcgssd writes a reply down (from rsi_parse()): find the
 * matching pending placeholder entry inserted by gssd_upcall(), replace
 * it in the hash chain with the completed @item, and mark @item VALID so
 * the waiter polling in gssd_upcall() picks it up.
 *
 * Returns 0 on success, -EINVAL if no matching placeholder exists or the
 * entry found is already VALID (nobody is waiting for it).
 */
static int
gssd_reply(struct rsi *item)
{
        struct rsi *tmp;
        struct cache_head **hp, **head;
        ENTRY;

        head = &rsi_cache.hash_table[rsi_hash(item)];
        write_lock(&rsi_cache.hash_lock);
        for (hp = head; *hp != NULL; hp = &tmp->h.next) {
                tmp = container_of(*hp, struct rsi, h);
                if (rsi_match(tmp, item)) {
                        /* take a ref on the placeholder and unhash it */
                        cache_get(&tmp->h);
                        clear_bit(CACHE_HASHED, &tmp->h.flags);
                        *hp = tmp->h.next;
                        tmp->h.next = NULL;
                        rsi_cache.entries--;
                        if (test_bit(CACHE_VALID, &tmp->h.flags)) {
                                /* already completed: unexpected, drop it */
                                CERROR("rsi is valid\n");
                                write_unlock(&rsi_cache.hash_lock);
                                rsi_put(&tmp->h, &rsi_cache);
                                RETURN(-EINVAL);
                        }
                        /* splice the completed item in where the placeholder
                         * was (*hp already points at the old successor) */
                        set_bit(CACHE_HASHED, &item->h.flags);
                        item->h.next = *hp;
                        *hp = &item->h;
                        rsi_cache.entries++;
                        set_bit(CACHE_VALID, &item->h.flags);
                        item->h.last_refresh = get_seconds();
                        write_unlock(&rsi_cache.hash_lock);
                        /* expire the placeholder immediately and drop our ref */
                        cache_fresh(&rsi_cache, &tmp->h, 0);
                        rsi_put(&tmp->h, &rsi_cache);
                        RETURN(0);
                }
        }
        write_unlock(&rsi_cache.hash_lock);
        RETURN(-EINVAL);
}
201
/* XXX
 * We simply wait here for the upcall's completion, or time out.  It's a
 * hack, but it works; we'll come up with a real fix if we decide to
 * stick with the NFSv4 cache code.
 */
207 static struct rsi *
208 gssd_upcall(struct rsi *item, struct cache_req *chandle)
209 {
210         struct rsi *tmp;
211         struct cache_head **hp, **head;
212         unsigned long starttime;
213         ENTRY;
214
215         head = &rsi_cache.hash_table[rsi_hash(item)];
216         read_lock(&rsi_cache.hash_lock);
217         for (hp = head; *hp != NULL; hp = &tmp->h.next) {
218                 tmp = container_of(*hp, struct rsi, h);
219                 if (rsi_match(tmp, item)) {
220                         LBUG();
221                         if (!test_bit(CACHE_VALID, &tmp->h.flags)) {
222                                 CERROR("found rsi without VALID\n");
223                                 read_unlock(&rsi_cache.hash_lock);
224                                 return NULL;
225                         }
226                         *hp = tmp->h.next;
227                         tmp->h.next = NULL;
228                         rsi_cache.entries--;
229                         cache_get(&tmp->h);
230                         read_unlock(&rsi_cache.hash_lock);
231                         return tmp;
232                 }
233         }
234         cache_get(&item->h);
235         //set_bit(CACHE_HASHED, &item->h.flags);
236         item->h.next = *head;
237         *head = &item->h;
238         rsi_cache.entries++;
239         read_unlock(&rsi_cache.hash_lock);
240         //cache_get(&item->h);
241
242         cache_check(&rsi_cache, &item->h, chandle);
243         starttime = get_seconds();
244         do {
245                 set_current_state(TASK_UNINTERRUPTIBLE);
246                 schedule_timeout(HZ/2);
247                 read_lock(&rsi_cache.hash_lock);
248                 for (hp = head; *hp != NULL; hp = &tmp->h.next) {
249                         tmp = container_of(*hp, struct rsi, h);
250                         if (tmp == item)
251                                 continue;
252                         if (rsi_match(tmp, item)) {
253                                 if (!test_bit(CACHE_VALID, &tmp->h.flags)) {
254                                         read_unlock(&rsi_cache.hash_lock);
255                                         return NULL;
256                                 }
257                                 cache_get(&tmp->h);
258                                 clear_bit(CACHE_HASHED, &tmp->h.flags);
259                                 *hp = tmp->h.next;
260                                 tmp->h.next = NULL;
261                                 rsi_cache.entries--;
262                                 read_unlock(&rsi_cache.hash_lock);
263                                 return tmp;
264                         }
265                 }
266                 read_unlock(&rsi_cache.hash_lock);
267         } while ((get_seconds() - starttime) <= 15);
268         CERROR("15s timeout while waiting cache refill\n");
269         return NULL;
270 }
271
272 static int rsi_parse(struct cache_detail *cd,
273                      char *mesg, int mlen)
274 {
275         /* context token expiry major minor context token */
276         char *buf = mesg;
277         char *ep;
278         int len;
279         struct rsi *rsii;
280         time_t expiry;
281         int status = -EINVAL;
282         ENTRY;
283
284         OBD_ALLOC(rsii, sizeof(*rsii));
285         if (!rsii) {
286                 CERROR("failed to alloc rsii\n");
287                 RETURN(-ENOMEM);
288         }
289         cache_init(&rsii->h);
290
291         /* handle */
292         len = qword_get(&mesg, buf, mlen);
293         if (len < 0)
294                 goto out;
295         status = -ENOMEM;
296         if (rawobj_alloc(&rsii->in_handle, buf, len))
297                 goto out;
298
299         /* token */
300         len = qword_get(&mesg, buf, mlen);
301         status = -EINVAL;
302         if (len < 0)
303                 goto out;;
304         status = -ENOMEM;
305         if (rawobj_alloc(&rsii->in_token, buf, len))
306                 goto out;
307
308         /* expiry */
309         expiry = get_expiry(&mesg);
310         status = -EINVAL;
311         if (expiry == 0)
312                 goto out;
313
314         /* major/minor */
315         len = qword_get(&mesg, buf, mlen);
316         if (len < 0)
317                 goto out;
318         if (len == 0) {
319                 goto out;
320         } else {
321                 rsii->major_status = simple_strtoul(buf, &ep, 10);
322                 if (*ep)
323                         goto out;
324                 len = qword_get(&mesg, buf, mlen);
325                 if (len <= 0)
326                         goto out;
327                 rsii->minor_status = simple_strtoul(buf, &ep, 10);
328                 if (*ep)
329                         goto out;
330
331                 /* out_handle */
332                 len = qword_get(&mesg, buf, mlen);
333                 if (len < 0)
334                         goto out;
335                 status = -ENOMEM;
336                 if (rawobj_alloc(&rsii->out_handle, buf, len))
337                         goto out;
338
339                 /* out_token */
340                 len = qword_get(&mesg, buf, mlen);
341                 status = -EINVAL;
342                 if (len < 0)
343                         goto out;
344                 status = -ENOMEM;
345                 if (rawobj_alloc(&rsii->out_token, buf, len))
346                         goto out;
347         }
348         rsii->h.expiry_time = expiry;
349         status = gssd_reply(rsii);
350 out:
351         if (rsii)
352                 rsi_put(&rsii->h, &rsi_cache);
353         RETURN(status);
354 }
355
/* cache_detail for pending init requests: upcall lines are rendered by
 * rsi_request() and daemon replies parsed by rsi_parse(), over the
 * "auth.ptlrpcs.init" channel. */
static struct cache_detail rsi_cache = {
        .hash_size      = RSI_HASHMAX,
        .hash_table     = rsi_table,
        .name           = "auth.ptlrpcs.init",
        .cache_put      = rsi_put,
        .cache_request  = rsi_request,
        .cache_parse    = rsi_parse,
};
364
365 /*
366  * The rpcsec_context cache is used to store a context that is
367  * used in data exchange.
368  * The key is a context handle. The content is:
369  *  uid, gidlist, mechanism, service-set, mech-specific-data
370  */
371
372 #define RSC_HASHBITS    10
373 #define RSC_HASHMAX     (1<<RSC_HASHBITS)
374 #define RSC_HASHMASK    (RSC_HASHMAX-1)
375
376 #define GSS_SEQ_WIN     512
377
/* Per-context replay-detection state: RFC 2203 sliding sequence window,
 * maintained by gss_check_seq_num(). */
struct gss_svc_seq_data {
        /* highest seq number seen so far: */
        __u32                   sd_max;
        /* for i such that sd_max-GSS_SEQ_WIN < i <= sd_max, the i-th bit of
         * sd_win is nonzero iff sequence number i has been seen already: */
        unsigned long           sd_win[GSS_SEQ_WIN/BITS_PER_LONG];
        /* protects sd_max and sd_win */
        spinlock_t              sd_lock;
};
386
/*
 * One established security context, keyed by the context handle.
 */
struct rsc {
        struct cache_head       h;              /* generic cache entry */
        rawobj_t                handle;         /* context handle (sole cache key) */
        __u32                   remote_realm;   /* "remote flag" pushed down by svcgssd (see rsc_parse) */
        struct vfs_cred         cred;           /* uid/gid credentials of the context */
        uid_t                   mapped_uid;     /* mapped uid pushed down by svcgssd */
        struct gss_svc_seq_data seqdata;        /* replay-detection window */
        struct gss_ctx         *mechctx;        /* imported mechanism-specific context */
};

/* hash table + cache descriptor for established contexts */
static struct cache_head *rsc_table[RSC_HASHMAX];
static struct cache_detail rsc_cache;
399
400 static void rsc_free(struct rsc *rsci)
401 {
402         rawobj_free(&rsci->handle);
403         if (rsci->mechctx)
404                 kgss_delete_sec_context(&rsci->mechctx);
405 #if 0
406         if (rsci->cred.vc_ginfo)
407                 put_group_info(rsci->cred.vc_ginfo);
408 #endif
409 }
410
411 static void rsc_put(struct cache_head *item, struct cache_detail *cd)
412 {
413         struct rsc *rsci = container_of(item, struct rsc, h);
414
415         LASSERT(atomic_read(&item->refcnt) > 0);
416         if (cache_put(item, cd)) {
417                 rsc_free(rsci);
418                 OBD_FREE(rsci, sizeof(*rsci));
419         }
420 }
421
422 static inline int
423 rsc_hash(struct rsc *rsci)
424 {
425         return hash_mem((char *)rsci->handle.data,
426                         rsci->handle.len, RSC_HASHBITS);
427 }
428
429 static inline int
430 rsc_match(struct rsc *new, struct rsc *tmp)
431 {
432         return rawobj_equal(&new->handle, &tmp->handle);
433 }
434
/*
 * Look up @item in the rsc cache.
 *
 * @set == 0: pure lookup under the read lock; returns a referenced
 *            matching entry, or NULL if none.
 * @set != 0: insertion under the write lock; any existing entry with the
 *            same handle is unhashed and released first.  Always returns
 *            @item with an extra reference for the caller.
 */
static struct rsc *rsc_lookup(struct rsc *item, int set)
{
        struct rsc *tmp = NULL;
        struct cache_head **hp, **head;
        head = &rsc_cache.hash_table[rsc_hash(item)];
        ENTRY;

        if (set)
                write_lock(&rsc_cache.hash_lock);
        else
                read_lock(&rsc_cache.hash_lock);
        for (hp = head; *hp != NULL; hp = &tmp->h.next) {
                tmp = container_of(*hp, struct rsc, h);
                if (!rsc_match(tmp, item))
                        continue;
                cache_get(&tmp->h);
                if (!set) {
                        /* lookup hit: hand back the referenced entry */
                        goto out_noset;
                }
                /* replacing: unlink the stale entry and release the ref
                 * taken just above; entries count is unchanged since the
                 * new item takes the old one's slot */
                *hp = tmp->h.next;
                tmp->h.next = NULL;
                clear_bit(CACHE_HASHED, &tmp->h.flags);
                rsc_put(&tmp->h, &rsc_cache);
                goto out_set;
        }
        /* Didn't find anything */
        if (!set)
                goto out_nada;
        rsc_cache.entries++;    /* genuinely new chain member */
out_set:
        set_bit(CACHE_HASHED, &item->h.flags);
        item->h.next = *head;
        *head = &item->h;
        write_unlock(&rsc_cache.hash_lock);
        cache_fresh(&rsc_cache, &item->h, item->h.expiry_time);
        cache_get(&item->h);    /* caller's reference */
        RETURN(item);
out_nada:
        tmp = NULL;
out_noset:
        read_unlock(&rsc_cache.hash_lock);
        RETURN(tmp);
}
478                                                                                                                         
/*
 * Parse one established-context line written down by svcgssd:
 *
 *   contexthandle expiry remote_flag mapped_uid
 *       [ uid gid mechname ...mechdata... ]
 *
 * The bracketed part is absent for a negative entry (uid field reports
 * ENOENT).  Imports the mech-specific context blob and installs the
 * resulting rsc entry via rsc_lookup(set=1).  Returns 0 on success or a
 * negative errno.
 */
static int rsc_parse(struct cache_detail *cd,
                     char *mesg, int mlen)
{
        /* contexthandle expiry [ uid gid N <n gids> mechname
         * ...mechdata... ] */
        char *buf = mesg;
        int len, rv;
        struct rsc *rsci, *res = NULL;
        time_t expiry;
        int status = -EINVAL;

        OBD_ALLOC(rsci, sizeof(*rsci));
        if (!rsci) {
                CERROR("fail to alloc rsci\n");
                return -ENOMEM;
        }
        cache_init(&rsci->h);

        /* context handle */
        len = qword_get(&mesg, buf, mlen);
        if (len < 0) goto out;
        status = -ENOMEM;
        if (rawobj_alloc(&rsci->handle, buf, len))
                goto out;

        /* expiry */
        expiry = get_expiry(&mesg);
        status = -EINVAL;
        if (expiry == 0)
                goto out;

        /* remote flag */
        rv = get_int(&mesg, (int *)&rsci->remote_realm);
        if (rv) {
                CERROR("fail to get remote flag\n");
                goto out;
        }

        /* mapped uid */
        rv = get_int(&mesg, (int *)&rsci->mapped_uid);
        if (rv) {
                CERROR("fail to get mapped uid\n");
                goto out;
        }

        /* uid, or NEGATIVE (ENOENT marks a negative entry) */
        rv = get_int(&mesg, (int *)&rsci->cred.vc_uid);
        if (rv == -EINVAL)
                goto out;
        if (rv == -ENOENT) {
                CERROR("NOENT? set rsc entry negative\n");
                set_bit(CACHE_NEGATIVE, &rsci->h.flags);
        } else {
                struct gss_api_mech *gm;
                rawobj_t tmp_buf;
                __u64 ctx_expiry;

                /* gid */
                if (get_int(&mesg, (int *)&rsci->cred.vc_gid))
                        goto out;

                /* mech name */
                len = qword_get(&mesg, buf, mlen);
                if (len < 0)
                        goto out;
                gm = kgss_name_to_mech(buf);
                status = -EOPNOTSUPP;
                if (!gm)
                        goto out;

                status = -EINVAL;
                /* mech-specific data: */
                len = qword_get(&mesg, buf, mlen);
                if (len < 0) {
                        kgss_mech_put(gm);
                        goto out;
                }
                tmp_buf.len = len;
                tmp_buf.data = (unsigned char *)buf;
                if (kgss_import_sec_context(&tmp_buf, gm, &rsci->mechctx)) {
                        kgss_mech_put(gm);
                        goto out;
                }

                /* currently the expiry time passed down from user-space
                 * is invalid, so retrieve it from the mech instead
                 */
                if (kgss_inquire_context(rsci->mechctx, &ctx_expiry)) {
                        CERROR("unable to get expire time, drop it\n");
                        set_bit(CACHE_NEGATIVE, &rsci->h.flags);
                        kgss_mech_put(gm);
                        goto out;
                }
                expiry = (time_t) ctx_expiry;

                kgss_mech_put(gm);
        }
        rsci->h.expiry_time = expiry;
        spin_lock_init(&rsci->seqdata.sd_lock);
        /* set=1 always returns rsci with an extra ref; drop it at once,
         * the hash table keeps its own */
        res = rsc_lookup(rsci, 1);
        rsc_put(&res->h, &rsc_cache);
        status = 0;
out:
        if (rsci)
                rsc_put(&rsci->h, &rsc_cache);
        return status;
}
586
587 /*
588  * flush all entries with @uid. @uid == -1 will match all.
589  * we only know the uid, maybe netid/nid in the future, in all cases
590  * we must search the whole cache
591  */
static void rsc_flush(uid_t uid)
{
        struct cache_head **ch;
        struct rsc *rscp;
        int n;
        ENTRY;

        write_lock(&rsc_cache.hash_lock);
        /* walk every chain: entries are not hashed by uid */
        for (n = 0; n < RSC_HASHMAX; n++) {
                for (ch = &rsc_cache.hash_table[n]; *ch;) {
                        rscp = container_of(*ch, struct rsc, h);
                        if (uid == -1 || rscp->cred.vc_uid == uid) {
                                /* it seems simply set NEGATIVE doesn't work */
                                *ch = (*ch)->next;      /* unlink from chain */
                                rscp->h.next = NULL;
                                cache_get(&rscp->h);
                                set_bit(CACHE_NEGATIVE, &rscp->h.flags);
                                clear_bit(CACHE_HASHED, &rscp->h.flags);
                                CDEBUG(D_SEC, "flush rsc %p for uid %u\n",
                                       rscp, rscp->cred.vc_uid);
                                /* drop the ref taken above; frees the entry
                                 * once all other holders are done */
                                rsc_put(&rscp->h, &rsc_cache);
                                rsc_cache.entries--;
                                /* *ch already points at the successor */
                                continue;
                        }
                        ch = &((*ch)->next);
                }
        }
        write_unlock(&rsc_cache.hash_lock);
        EXIT;
}
622
/* cache_detail for established contexts.  Entries are only ever pushed
 * down from userspace via rsc_parse(); note there is no .cache_request,
 * so no upcall is generated on a miss. */
static struct cache_detail rsc_cache = {
        .hash_size      = RSC_HASHMAX,
        .hash_table     = rsc_table,
        .name           = "auth.ptlrpcs.context",
        .cache_put      = rsc_put,
        .cache_parse    = rsc_parse,
};
630
/*
 * Find an established context by its wire handle.  Returns a referenced
 * rsc entry, or NULL if absent or not valid.
 *
 * NOTE(review): when cache_check() fails, the reference obtained from
 * rsc_lookup() is not dropped here — verify that cache_check() releases
 * it on error, otherwise this path leaks a refcount.
 */
static struct rsc *
gss_svc_searchbyctx(rawobj_t *handle)
{
        struct rsc rsci;        /* stack key: rsc_lookup only reads .handle */
        struct rsc *found;

        rsci.handle = *handle;
        found = rsc_lookup(&rsci, 0);
        if (!found)
                return NULL;

        /* validity/expiry check on the cached entry */
        if (cache_check(&rsc_cache, &found->h, NULL))
                return NULL;

        return found;
}
647
/* Per-request gss state, hung off ptlrpc_request::rq_sec_svcdata. */
struct gss_svc_data {
        /* decoded gss client cred: */
        struct rpc_gss_wire_cred        clcred;
        /* internal used status: request-classification flags (set during
         * cred processing; only is_err_notify is set in this chunk, by
         * gss_pack_err_notify()) */
        unsigned int                    is_init:1,
                                        is_init_continue:1,
                                        is_err_notify:1,
                                        is_fini:1;
        /* bytes reserved in the reply buffer for security data */
        int                             reserve_len;
};
658
659 /* FIXME
660  * again hacking: only try to give the svcgssd a chance to handle
661  * upcalls.
662  */
/* A "deferral" hook that never actually defers: yield the CPU so svcgssd
 * gets a chance to run, then return NULL to tell the cache code that
 * deferral failed.
 * NOTE(review): not declared static — confirm nothing else links to it. */
struct cache_deferred_req* my_defer(struct cache_req *req)
{
        yield();
        return NULL;
}
/* request handle passed to cache_check()/gssd_upcall() */
static struct cache_req my_chandle = {my_defer};
669
670 /* Implements sequence number algorithm as specified in RFC 2203. */
671 static int
672 gss_check_seq_num(struct gss_svc_seq_data *sd, __u32 seq_num)
673 {
674         int rc = 0;
675
676         spin_lock(&sd->sd_lock);
677         if (seq_num > sd->sd_max) {
678                 if (seq_num >= sd->sd_max + GSS_SEQ_WIN) {
679                         memset(sd->sd_win, 0, sizeof(sd->sd_win));
680                         sd->sd_max = seq_num;
681                 } else {
682                         while(sd->sd_max < seq_num) {
683                                 sd->sd_max++;
684                                 __clear_bit(sd->sd_max % GSS_SEQ_WIN,
685                                             sd->sd_win);
686                         }
687                 }
688                 __set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win);
689                 goto exit;
690         } else if (seq_num + GSS_SEQ_WIN <= sd->sd_max) {
691                 CERROR("seq %u too low: max %u, win %d\n",
692                         seq_num, sd->sd_max, GSS_SEQ_WIN);
693                 rc = 1;
694                 goto exit;
695         }
696
697         if (__test_and_set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win)) {
698                 CERROR("seq %u is replay: max %u, win %d\n",
699                         seq_num, sd->sd_max, GSS_SEQ_WIN);
700                 rc = 1;
701         }
702 exit:
703         spin_unlock(&sd->sd_lock);
704         return rc;
705 }
706
707 static int
708 gss_svc_verify_request(struct ptlrpc_request *req,
709                        struct rsc *rsci,
710                        struct rpc_gss_wire_cred *gc,
711                        __u32 *vp, __u32 vlen)
712 {
713         struct ptlrpcs_wire_hdr *sec_hdr;
714         struct gss_ctx *ctx = rsci->mechctx;
715         __u32 maj_stat;
716         rawobj_t msg;
717         rawobj_t mic;
718         ENTRY;
719
720         sec_hdr = (struct ptlrpcs_wire_hdr *) req->rq_reqbuf;
721
722         req->rq_reqmsg = (struct lustre_msg *) (req->rq_reqbuf + sizeof(*sec_hdr));
723         req->rq_reqlen = sec_hdr->msg_len;
724
725         msg.len = sec_hdr->msg_len;
726         msg.data = (__u8 *)req->rq_reqmsg;
727
728         mic.len = le32_to_cpu(*vp++);
729         mic.data = (unsigned char *)vp;
730         vlen -= 4;
731
732         if (mic.len > vlen) {
733                 CERROR("checksum len %d, while buffer len %d\n",
734                         mic.len, vlen);
735                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
736         }
737
738         if (mic.len > 256) {
739                 CERROR("invalid mic len %d\n", mic.len);
740                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
741         }
742
743         maj_stat = kgss_verify_mic(ctx, &msg, &mic, NULL);
744         if (maj_stat != GSS_S_COMPLETE) {
745                 CERROR("MIC verification error: major %x\n", maj_stat);
746                 RETURN(maj_stat);
747         }
748
749         if (gss_check_seq_num(&rsci->seqdata, gc->gc_seq)) {
750                 CERROR("discard replayed request %p(o%u,x"LPU64",t"LPU64")\n",
751                         req, req->rq_reqmsg->opc, req->rq_xid,
752                         req->rq_reqmsg->transno);
753                 RETURN(GSS_S_DUPLICATE_TOKEN);
754         }
755
756         RETURN(GSS_S_COMPLETE);
757 }
758
/*
 * Privacy (auth/priv) service: unwrap (decrypt) the request payload in
 * place, run replay detection, then point rq_reqmsg at the recovered
 * plain text.
 *
 * @vp points at the verifier data, @vlen is its remaining byte count.
 * Returns GSS_S_COMPLETE on success or a GSS major status on failure.
 */
static int
gss_svc_unseal_request(struct ptlrpc_request *req,
                       struct rsc *rsci,
                       struct rpc_gss_wire_cred *gc,
                       __u32 *vp, __u32 vlen)
{
        struct ptlrpcs_wire_hdr *sec_hdr;
        struct gss_ctx *ctx = rsci->mechctx;
        rawobj_t cipher_text, plain_text;
        __u32 major;
        ENTRY;

        sec_hdr = (struct ptlrpcs_wire_hdr *) req->rq_reqbuf;

        /* need at least the 4-byte cipher length field */
        if (vlen < 4) {
                CERROR("vlen only %u\n", vlen);
                RETURN(GSS_S_CALL_BAD_STRUCTURE);
        }

        cipher_text.len = le32_to_cpu(*vp++);
        cipher_text.data = (__u8 *) vp;
        vlen -= 4;

        if (cipher_text.len > vlen) {
                CERROR("cipher claimed %u while buf only %u\n",
                        cipher_text.len, vlen);
                RETURN(GSS_S_CALL_BAD_STRUCTURE);
        }

        /* unwrap in place: plain text reuses the cipher text buffer */
        plain_text = cipher_text;

        major = kgss_unwrap(ctx, GSS_C_QOP_DEFAULT, &cipher_text, &plain_text);
        if (major) {
                CERROR("unwrap error 0x%x\n", major);
                RETURN(major);
        }

        /* NOTE(review): on the replay path the CERROR below dereferences
         * req->rq_reqmsg, which is only assigned AFTER this check in this
         * function — verify rq_reqmsg is valid here, it may be stale */
        if (gss_check_seq_num(&rsci->seqdata, gc->gc_seq)) {
                CERROR("discard replayed request %p(o%u,x"LPU64",t"LPU64")\n",
                        req, req->rq_reqmsg->opc, req->rq_xid,
                        req->rq_reqmsg->transno);
                RETURN(GSS_S_DUPLICATE_TOKEN);
        }

        /* the recovered lustre message starts right after the length word */
        req->rq_reqmsg = (struct lustre_msg *) (vp);
        req->rq_reqlen = plain_text.len;

        CDEBUG(D_SEC, "msg len %d\n", req->rq_reqlen);

        RETURN(GSS_S_COMPLETE);
}
810
/*
 * Build an error-notify reply carrying the GSS @major/@minor status back
 * to the client.  Reply layout, all 32-bit little-endian words:
 *
 *   sec header: PTLRPC_SEC_GSS, PTLRPC_SEC_TYPE_NONE, rq_replen, secdata_len
 *   (rq_replen bytes of lustre msg, skipped)
 *   gss part:   version, subflavor, PTLRPC_GSS_PROC_ERR, major, minor,
 *               two zero placeholder objects            (7 words = secdata_len)
 *
 * Returns 0 on success or a negative errno from lustre_pack_reply().
 */
static int
gss_pack_err_notify(struct ptlrpc_request *req,
                    __u32 major, __u32 minor)
{
        struct gss_svc_data *svcdata = req->rq_sec_svcdata;
        __u32 reslen, *resp, *reslenp;
        char  nidstr[PTL_NALFMT_SIZE];
        const __u32 secdata_len = 7 * 4;
        int rc;
        ENTRY;

        /* fault-injection point for testing */
        OBD_FAIL_RETURN(OBD_FAIL_SVCGSS_ERR_NOTIFY|OBD_FAIL_ONCE, -EINVAL);

        LASSERT(svcdata);
        svcdata->is_err_notify = 1;
        svcdata->reserve_len = 7 * 4;

        rc = lustre_pack_reply(req, 0, NULL, NULL);
        if (rc) {
                CERROR("could not pack reply, err %d\n", rc);
                RETURN(rc);
        }

        LASSERT(req->rq_reply_state);
        LASSERT(req->rq_reply_state->rs_repbuf);
        LASSERT(req->rq_reply_state->rs_repbuf_len >= secdata_len);
        resp = (__u32 *) req->rq_reply_state->rs_repbuf;

        /* header */
        *resp++ = cpu_to_le32(PTLRPC_SEC_GSS);
        *resp++ = cpu_to_le32(PTLRPC_SEC_TYPE_NONE);
        *resp++ = cpu_to_le32(req->rq_replen);
        reslenp = resp++;       /* sec data length slot, patched below */

        /* skip lustre msg */
        resp += req->rq_replen / 4;
        reslen = svcdata->reserve_len;

        /* gss replay:
         * version, subflavor, notify, major, minor,
         * obj1(fake), obj2(fake)
         */
        *resp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
        *resp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);
        *resp++ = cpu_to_le32(PTLRPC_GSS_PROC_ERR);
        *resp++ = cpu_to_le32(major);
        *resp++ = cpu_to_le32(minor);
        *resp++ = 0;
        *resp++ = 0;
        /* NOTE(review): reslen is decremented by 4*4 although 7 words were
         * written, and it is never read afterwards — dead bookkeeping */
        reslen -= (4 * 4);
        /* the actual sec data length */
        *reslenp = cpu_to_le32(secdata_len);

        req->rq_reply_state->rs_repdata_len += (secdata_len);
        CDEBUG(D_SEC, "prepare gss error notify(0x%x/0x%x) to %s\n",
               major, minor,
               portals_nid2str(req->rq_peer.peer_ni->pni_number,
                               req->rq_peer.peer_id.nid, nidstr));
        RETURN(0);
}
871
872 static void dump_cache_head(struct cache_head *h)
873 {
874         CWARN("ref %d, fl %lx, n %p, t %ld, %ld\n",
875               atomic_read(&h->refcnt), h->flags, h->next,
876               h->expiry_time, h->last_refresh);
877 }
878 static void dump_rsi(struct rsi *rsi)
879 {
880         CWARN("dump rsi %p\n", rsi);
881         dump_cache_head(&rsi->h);
882         CWARN("%x,%x,%llx\n", rsi->naltype, rsi->netid, rsi->nid);
883         CWARN("len %d, d %p\n", rsi->in_handle.len, rsi->in_handle.data);
884         CWARN("len %d, d %p\n", rsi->in_token.len, rsi->in_token.data);
885         CWARN("len %d, d %p\n", rsi->out_handle.len, rsi->out_handle.data);
886         CWARN("len %d, d %p\n", rsi->out_token.len, rsi->out_token.data);
887 }
888
889 static int
890 gss_svcsec_handle_init(struct ptlrpc_request *req,
891                        struct rpc_gss_wire_cred *gc,
892                        __u32 *secdata, __u32 seclen,
893                        enum ptlrpcs_error *res)
894 {
895         struct gss_svc_data *svcdata = req->rq_sec_svcdata;
896         struct rsc          *rsci;
897         struct rsi          *rsikey, *rsip;
898         rawobj_t             tmpobj;
899         __u32 reslen,       *resp, *reslenp;
900         char                 nidstr[PTL_NALFMT_SIZE];
901         int                  rc;
902         ENTRY;
903
904         LASSERT(svcdata);
905
906         CDEBUG(D_SEC, "processing gss init(%d) request from %s\n", gc->gc_proc,
907                portals_nid2str(req->rq_peer.peer_ni->pni_number,
908                                req->rq_peer.peer_id.nid, nidstr));
909
910         *res = PTLRPCS_BADCRED;
911         OBD_FAIL_RETURN(OBD_FAIL_SVCGSS_INIT_REQ|OBD_FAIL_ONCE, SVC_DROP);
912
913         if (gc->gc_proc == RPC_GSS_PROC_INIT &&
914             gc->gc_ctx.len != 0) {
915                 CERROR("proc %d, ctx_len %d: not really init?\n",
916                 gc->gc_proc == RPC_GSS_PROC_INIT, gc->gc_ctx.len);
917                 RETURN(SVC_DROP);
918         }
919
920         OBD_ALLOC(rsikey, sizeof(*rsikey));
921         if (!rsikey) {
922                 CERROR("out of memory\n");
923                 RETURN(SVC_DROP);
924         }
925         cache_init(&rsikey->h);
926
927         if (rawobj_dup(&rsikey->in_handle, &gc->gc_ctx)) {
928                 CERROR("fail to dup context handle\n");
929                 GOTO(out_rsikey, rc = SVC_DROP);
930         }
931         *res = PTLRPCS_BADVERF;
932         if (rawobj_extract(&tmpobj, &secdata, &seclen)) {
933                 CERROR("can't extract token\n");
934                 GOTO(out_rsikey, rc = SVC_DROP);
935         }
936         if (rawobj_dup(&rsikey->in_token, &tmpobj)) {
937                 CERROR("can't duplicate token\n");
938                 GOTO(out_rsikey, rc = SVC_DROP);
939         }
940
941         rsikey->naltype = (__u32) req->rq_peer.peer_ni->pni_number;
942         rsikey->netid = 0;
943         rsikey->nid = (__u64) req->rq_peer.peer_id.nid;
944
945         rsip = gssd_upcall(rsikey, &my_chandle);
946         if (!rsip) {
947                 CERROR("error in gssd_upcall.\n");
948
949                 rc = SVC_COMPLETE;
950                 if (gss_pack_err_notify(req, GSS_S_FAILURE, 0))
951                         rc = SVC_DROP;
952
953                 GOTO(out_rsikey, rc);
954         }
955
956         rsci = gss_svc_searchbyctx(&rsip->out_handle);
957         if (!rsci) {
958                 CERROR("rsci still not mature yet?\n");
959
960                 rc = SVC_COMPLETE;
961                 if (gss_pack_err_notify(req, GSS_S_FAILURE, 0))
962                         rc = SVC_DROP;
963
964                 GOTO(out_rsip, rc);
965         }
966         CDEBUG(D_SEC, "svcsec create gss context %p(%u@%s)\n",
967                rsci, rsci->cred.vc_uid,
968                portals_nid2str(req->rq_peer.peer_ni->pni_number,
969                                req->rq_peer.peer_id.nid, nidstr));
970
971         svcdata->is_init = 1;
972         svcdata->reserve_len = 6 * 4 +
973                 size_round4(rsip->out_handle.len) +
974                 size_round4(rsip->out_token.len);
975
976         rc = lustre_pack_reply(req, 0, NULL, NULL);
977         if (rc) {
978                 CERROR("failed to pack reply, rc = %d\n", rc);
979                 set_bit(CACHE_NEGATIVE, &rsci->h.flags);
980                 GOTO(out, rc = SVC_DROP);
981         }
982
983         /* header */
984         resp = (__u32 *) req->rq_reply_state->rs_repbuf;
985         *resp++ = cpu_to_le32(PTLRPC_SEC_GSS);
986         *resp++ = cpu_to_le32(PTLRPC_SEC_TYPE_NONE);
987         *resp++ = cpu_to_le32(req->rq_replen);
988         reslenp = resp++;
989
990         resp += req->rq_replen / 4;
991         reslen = svcdata->reserve_len;
992
993         /* gss reply:
994          * status, major, minor, seq, out_handle, out_token
995          */
996         *resp++ = cpu_to_le32(PTLRPCS_OK);
997         *resp++ = cpu_to_le32(rsip->major_status);
998         *resp++ = cpu_to_le32(rsip->minor_status);
999         *resp++ = cpu_to_le32(GSS_SEQ_WIN);
1000         reslen -= (4 * 4);
1001         if (rawobj_serialize(&rsip->out_handle,
1002                              &resp, &reslen)) {
1003                 dump_rsi(rsip);
1004                 dump_rsi(rsikey);
1005                 LBUG();
1006         }
1007         if (rawobj_serialize(&rsip->out_token,
1008                              &resp, &reslen)) {
1009                 dump_rsi(rsip);
1010                 dump_rsi(rsikey);
1011                 LBUG();
1012         }
1013         /* the actual sec data length */
1014         *reslenp = cpu_to_le32(svcdata->reserve_len - reslen);
1015
1016         req->rq_reply_state->rs_repdata_len += le32_to_cpu(*reslenp);
1017         CDEBUG(D_SEC, "req %p: msgsize %d, authsize %d, "
1018                "total size %d\n", req, req->rq_replen,
1019                le32_to_cpu(*reslenp),
1020                req->rq_reply_state->rs_repdata_len);
1021
1022         *res = PTLRPCS_OK;
1023
1024         req->rq_auth_uid = rsci->cred.vc_uid;
1025         req->rq_remote_realm = rsci->remote_realm;
1026         req->rq_mapped_uid = rsci->mapped_uid;
1027
1028         /* This is simplified since right now we doesn't support
1029          * INIT_CONTINUE yet.
1030          */
1031         if (gc->gc_proc == RPC_GSS_PROC_INIT) {
1032                 struct ptlrpcs_wire_hdr *hdr;
1033
1034                 hdr = buf_to_sec_hdr(req->rq_reqbuf);
1035                 req->rq_reqmsg = buf_to_lustre_msg(req->rq_reqbuf);
1036                 req->rq_reqlen = hdr->msg_len;
1037
1038                 rc = SVC_LOGIN;
1039         } else
1040                 rc = SVC_COMPLETE;
1041
1042 out:
1043         rsc_put(&rsci->h, &rsc_cache);
1044 out_rsip:
1045         rsi_put(&rsip->h, &rsi_cache);
1046 out_rsikey:
1047         rsi_put(&rsikey->h, &rsi_cache);
1048
1049         RETURN(rc);
1050 }
1051
/*
 * Handle an RPC_GSS_PROC_DATA request: look up the established context
 * and verify (integrity) or unseal (privacy) the incoming message.
 *
 * On success, fills req->rq_auth_uid / rq_remote_realm / rq_mapped_uid
 * from the context credentials and returns SVC_OK with *res PTLRPCS_OK.
 * On GSS failure an error notify reply is packed and SVC_COMPLETE is
 * returned (or SVC_DROP if the notify itself cannot be packed).
 */
static int
gss_svcsec_handle_data(struct ptlrpc_request *req,
                       struct rpc_gss_wire_cred *gc,
                       __u32 *secdata, __u32 seclen,
                       enum ptlrpcs_error *res)
{
        struct rsc          *rsci;
        char                 nidstr[PTL_NALFMT_SIZE];
        __u32                major;
        int                  rc;
        ENTRY;

        *res = PTLRPCS_GSS_CREDPROBLEM;

        /* context must already exist in the rsc cache */
        rsci = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rsci) {
                CWARN("Invalid gss context handle from %s\n",
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                major = GSS_S_NO_CONTEXT;
                goto notify_err;
        }

        switch (gc->gc_svc) {
        case PTLRPC_GSS_SVC_INTEGRITY:
                major = gss_svc_verify_request(req, rsci, gc, secdata, seclen);
                if (major == GSS_S_COMPLETE)
                        break;

                CWARN("fail in verify:0x%x: ctx %p@%s\n", major, rsci,
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                goto notify_err;
        case PTLRPC_GSS_SVC_PRIVACY:
                major = gss_svc_unseal_request(req, rsci, gc, secdata, seclen);
                if (major == GSS_S_COMPLETE)
                        break;

                CWARN("fail in decrypt:0x%x: ctx %p@%s\n", major, rsci,
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                goto notify_err;
        default:
                CERROR("unsupported gss service %d\n", gc->gc_svc);
                GOTO(out, rc = SVC_DROP);
        }

        /* verification/unsealing succeeded: expose the authenticated
         * identity to the rest of the request processing */
        req->rq_auth_uid = rsci->cred.vc_uid;
        req->rq_remote_realm = rsci->remote_realm;
        req->rq_mapped_uid = rsci->mapped_uid;

        *res = PTLRPCS_OK;
        GOTO(out, rc = SVC_OK);

notify_err:
        /* tell the client why we failed; drop only if even the error
         * notify cannot be packed */
        if (gss_pack_err_notify(req, major, 0))
                rc = SVC_DROP;
        else
                rc = SVC_COMPLETE;
out:
        /* rsci may be NULL when the context lookup above failed */
        if (rsci)
                rsc_put(&rsci->h, &rsc_cache);
        RETURN(rc);
}
1116
/*
 * Handle an RPC_GSS_PROC_DESTROY request: verify the request under the
 * existing context (integrity service only), then mark the context
 * negative in the rsc cache so it can no longer be used.
 *
 * Returns SVC_LOGOUT on success with an (empty) reply packed;
 * SVC_DROP on any failure.
 */
static int
gss_svcsec_handle_destroy(struct ptlrpc_request *req,
                          struct rpc_gss_wire_cred *gc,
                          __u32 *secdata, __u32 seclen,
                          enum ptlrpcs_error *res)
{
        struct gss_svc_data *svcdata = req->rq_sec_svcdata;
        struct rsc          *rsci;
        char                 nidstr[PTL_NALFMT_SIZE];
        int                  rc;
        ENTRY;

        LASSERT(svcdata);
        *res = PTLRPCS_GSS_CREDPROBLEM;

        rsci = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rsci) {
                CWARN("invalid gss context handle for destroy.\n");
                RETURN(SVC_DROP);
        }

        /* destroy messages are only accepted with integrity protection */
        if (gc->gc_svc != PTLRPC_GSS_SVC_INTEGRITY) {
                CERROR("service %d is not supported in destroy.\n",
                        gc->gc_svc);
                GOTO(out, rc = SVC_DROP);
        }

        /* NOTE(review): this stores a GSS major status into the
         * ptlrpcs_error slot; nonzero means verification failed */
        *res = gss_svc_verify_request(req, rsci, gc, secdata, seclen);
        if (*res)
                GOTO(out, rc = SVC_DROP);

        /* compose reply, which is actually nothing */
        svcdata->is_fini = 1;
        if (lustre_pack_reply(req, 0, NULL, NULL))
                GOTO(out, rc = SVC_DROP);

        CDEBUG(D_SEC, "svcsec destroy gss context %p(%u@%s)\n",
               rsci, rsci->cred.vc_uid,
               portals_nid2str(req->rq_peer.peer_ni->pni_number,
                               req->rq_peer.peer_id.nid, nidstr));

        /* mark the cache entry dead so future lookups fail */
        set_bit(CACHE_NEGATIVE, &rsci->h.flags);
        *res = PTLRPCS_OK;
        rc = SVC_LOGOUT;
out:
        rsc_put(&rsci->h, &rsc_cache);
        RETURN(rc);
}
1165
1166 /*
 * let incoming request go through security check:
1168  *  o context establishment: invoke user space helper
1169  *  o data exchange: verify/decrypt
1170  *  o context destruction: mark context invalid
1171  *
1172  * in most cases, error will result to drop the packet silently.
1173  */
/*
 * Entry point for incoming requests: parse the wire gss header out of
 * the request's security payload, allocate the per-request gss_svc_data,
 * and dispatch on gc_proc to the init/data/destroy handlers.
 *
 * On SVC_DROP the per-request svcdata is freed again before return.
 */
static int
gss_svcsec_accept(struct ptlrpc_request *req, enum ptlrpcs_error *res)
{
        struct gss_svc_data *svcdata;
        struct rpc_gss_wire_cred *gc;
        struct ptlrpcs_wire_hdr *sec_hdr;
        __u32 seclen, *secdata, version, subflavor;
        int rc;
        ENTRY;

        CDEBUG(D_SEC, "request %p\n", req);
        LASSERT(req->rq_reqbuf);
        LASSERT(req->rq_reqbuf_len);

        *res = PTLRPCS_BADCRED;

        sec_hdr = buf_to_sec_hdr(req->rq_reqbuf);
        LASSERT(sec_hdr->flavor == PTLRPC_SEC_GSS);

        /* sec payload is whatever follows the wire header + lustre msg */
        seclen = req->rq_reqbuf_len - sizeof(*sec_hdr) - sec_hdr->msg_len;
        secdata = (__u32 *) buf_to_sec_data(req->rq_reqbuf);

        /* sanity: claimed sec length must fit the remaining buffer */
        if (sec_hdr->sec_len > seclen) {
                CERROR("seclen %d, while max buf %d\n",
                        sec_hdr->sec_len, seclen);
                RETURN(SVC_DROP);
        }

        /* need at least the 5-word gss header plus a length word */
        if (seclen < 6 * 4) {
                CERROR("sec size %d too small\n", seclen);
                RETURN(SVC_DROP);
        }

        LASSERT(!req->rq_sec_svcdata);
        OBD_ALLOC(svcdata, sizeof(*svcdata));
        if (!svcdata) {
                CERROR("fail to alloc svcdata\n");
                RETURN(SVC_DROP);
        }
        req->rq_sec_svcdata = svcdata;
        gc = &svcdata->clcred;

        /* Now secdata/seclen is what we want to parse
         */
        version = le32_to_cpu(*secdata++);      /* version */
        subflavor = le32_to_cpu(*secdata++);    /* subflavor */
        gc->gc_proc = le32_to_cpu(*secdata++);  /* proc */
        gc->gc_seq = le32_to_cpu(*secdata++);   /* seq */
        gc->gc_svc = le32_to_cpu(*secdata++);   /* service */
        seclen -= 5 * 4;

        CDEBUG(D_SEC, "wire gss_hdr: %u/%u/%u/%u/%u\n",
               version, subflavor, gc->gc_proc, gc->gc_seq, gc->gc_svc);

        if (version != PTLRPC_SEC_GSS_VERSION) {
                CERROR("gss version mismatch: %d - %d\n",
                        version, PTLRPC_SEC_GSS_VERSION);
                GOTO(err_free, rc = SVC_DROP);
        }

        /* gc_ctx just points into the request buffer; it is not
         * separately allocated (see gss_svcsec_cleanup_req) */
        if (rawobj_extract(&gc->gc_ctx, &secdata, &seclen)) {
                CERROR("fail to obtain gss context handle\n");
                GOTO(err_free, rc = SVC_DROP);
        }

        *res = PTLRPCS_BADVERF;
        switch(gc->gc_proc) {
        case RPC_GSS_PROC_INIT:
        case RPC_GSS_PROC_CONTINUE_INIT:
                rc = gss_svcsec_handle_init(req, gc, secdata, seclen, res);
                break;
        case RPC_GSS_PROC_DATA:
                rc = gss_svcsec_handle_data(req, gc, secdata, seclen, res);
                break;
        case RPC_GSS_PROC_DESTROY:
                rc = gss_svcsec_handle_destroy(req, gc, secdata, seclen, res);
                break;
        default:
                rc = SVC_DROP;
                LBUG();
        }

        /* note: the switch above falls through here on success too;
         * svcdata is only freed when the request is being dropped */
err_free:
        if (rc == SVC_DROP && req->rq_sec_svcdata) {
                OBD_FREE(req->rq_sec_svcdata, sizeof(struct gss_svc_data));
                req->rq_sec_svcdata = NULL;
        }

        RETURN(rc);
}
1264
/*
 * Finalize an outgoing reply: for integrity service append a MIC over
 * the lustre message; for privacy service wrap (encrypt) the message
 * into the reply buffer. Init/fini/err replies were already fully
 * packed by their handlers and need no further work.
 *
 * Returns 0 on success, -EINVAL on any failure.
 */
static int
gss_svcsec_authorize(struct ptlrpc_request *req)
{
        struct ptlrpc_reply_state *rs = req->rq_reply_state;
        struct gss_svc_data *gsd = (struct gss_svc_data *)req->rq_sec_svcdata;
        struct rpc_gss_wire_cred  *gc = &gsd->clcred;
        struct rsc                *rscp;
        struct ptlrpcs_wire_hdr   *sec_hdr;
        rawobj_buf_t               msg_buf;
        rawobj_t                   cipher_buf;
        __u32                     *vp, *vpsave, major, vlen, seclen;
        rawobj_t                   lmsg, mic;
        int                        ret;
        ENTRY;

        LASSERT(rs);
        LASSERT(rs->rs_repbuf);
        LASSERT(gsd);

        if (gsd->is_init || gsd->is_init_continue ||
            gsd->is_err_notify || gsd->is_fini) {
                /* nothing to do in these cases */
                CDEBUG(D_SEC, "req %p: init/fini/err\n", req);
                RETURN(0);
        }

        if (gc->gc_proc != RPC_GSS_PROC_DATA) {
                CERROR("proc %d not support\n", gc->gc_proc);
                RETURN(-EINVAL);
        }

        /* re-lookup the context; it was validated in accept but the
         * reference was dropped there */
        rscp = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rscp) {
                CERROR("ctx disapeared under us?\n");
                RETURN(-EINVAL);
        }

        sec_hdr = (struct ptlrpcs_wire_hdr *) rs->rs_repbuf;
        switch (gc->gc_svc) {
        case  PTLRPC_GSS_SVC_INTEGRITY:
                /* prepare various pointers: lmsg is the plaintext lustre
                 * message, vp walks the verifier area that follows it */
                lmsg.len = req->rq_replen;
                lmsg.data = (__u8 *) (rs->rs_repbuf + sizeof(*sec_hdr));
                vp = (__u32 *) (lmsg.data + lmsg.len);
                vlen = rs->rs_repbuf_len - sizeof(*sec_hdr) - lmsg.len;
                seclen = vlen;

                sec_hdr->flavor = cpu_to_le32(PTLRPC_SEC_GSS);
                sec_hdr->sectype = cpu_to_le32(PTLRPC_SEC_TYPE_AUTH);
                sec_hdr->msg_len = cpu_to_le32(req->rq_replen);

                /* standard gss hdr */
                LASSERT(vlen >= 7 * 4);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);
                *vp++ = cpu_to_le32(RPC_GSS_PROC_DATA);
                *vp++ = cpu_to_le32(gc->gc_seq);
                *vp++ = cpu_to_le32(PTLRPC_GSS_SVC_INTEGRITY);
                *vp++ = 0;      /* fake ctx handle */
                vpsave = vp++;  /* reserve size */
                vlen -= 7 * 4;

                mic.len = vlen;
                mic.data = (unsigned char *)vp;

                /* compute MIC over the plaintext message into the
                 * reserved tail of the buffer */
                major = kgss_get_mic(rscp->mechctx, 0, &lmsg, &mic);
                if (major) {
                        CERROR("fail to get MIC: 0x%x\n", major);
                        GOTO(out, ret = -EINVAL);
                }
                *vpsave = cpu_to_le32(mic.len);
                /* seclen = header words actually consumed + MIC bytes */
                seclen = seclen - vlen + mic.len;
                sec_hdr->sec_len = cpu_to_le32(seclen);
                rs->rs_repdata_len += size_round(seclen);
                break;
        case  PTLRPC_GSS_SVC_PRIVACY:
                vp = (__u32 *) (rs->rs_repbuf + sizeof(*sec_hdr));
                vlen = rs->rs_repbuf_len - sizeof(*sec_hdr);
                seclen = vlen;

                sec_hdr->flavor = cpu_to_le32(PTLRPC_SEC_GSS);
                sec_hdr->sectype = cpu_to_le32(PTLRPC_SEC_TYPE_PRIV);
                sec_hdr->msg_len = cpu_to_le32(0);

                /* standard gss hdr */
                LASSERT(vlen >= 7 * 4);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);
                *vp++ = cpu_to_le32(RPC_GSS_PROC_DATA);
                *vp++ = cpu_to_le32(gc->gc_seq);
                *vp++ = cpu_to_le32(PTLRPC_GSS_SVC_PRIVACY);
                *vp++ = 0;      /* fake ctx handle */
                vpsave = vp++;  /* reserve size */
                vlen -= 7 * 4;

                /* plaintext lives in the private buffer allocated by
                 * alloc_repbuf; rs_msg sits PREFIX_LEN into it */
                msg_buf.buf = (__u8 *) rs->rs_msg - GSS_PRIVBUF_PREFIX_LEN;
                msg_buf.buflen = req->rq_replen + GSS_PRIVBUF_PREFIX_LEN +
                                 GSS_PRIVBUF_SUFFIX_LEN;
                msg_buf.dataoff = GSS_PRIVBUF_PREFIX_LEN;
                msg_buf.datalen = req->rq_replen;

                cipher_buf.data = (__u8 *) vp;
                cipher_buf.len = vlen;

                major = kgss_wrap(rscp->mechctx, GSS_C_QOP_DEFAULT,
                                &msg_buf, &cipher_buf);
                if (major) {
                        CERROR("failed to wrap: 0x%x\n", major);
                        GOTO(out, ret = -EINVAL);
                }

                *vpsave = cpu_to_le32(cipher_buf.len);
                seclen = seclen - vlen + cipher_buf.len;
                sec_hdr->sec_len = cpu_to_le32(seclen);
                rs->rs_repdata_len += size_round(seclen);
                break;
        default:
                CERROR("Unknown service %d\n", gc->gc_svc);
                GOTO(out, ret = -EINVAL);
        }
        ret = 0;
out:
        rsc_put(&rscp->h, &rsc_cache);

        RETURN(ret);
}
1391
1392 static
1393 void gss_svcsec_cleanup_req(struct ptlrpc_svcsec *svcsec,
1394                             struct ptlrpc_request *req)
1395 {
1396         struct gss_svc_data *gsd = (struct gss_svc_data *) req->rq_sec_svcdata;
1397
1398         if (!gsd) {
1399                 CDEBUG(D_SEC, "no svc_data present. do nothing\n");
1400                 return;
1401         }
1402
1403         /* gsd->clclred.gc_ctx is NOT allocated, just set pointer
1404          * to the incoming packet buffer, so don't need free it
1405          */
1406         OBD_FREE(gsd, sizeof(*gsd));
1407         req->rq_sec_svcdata = NULL;
1408         return;
1409 }
1410
1411 static
1412 int gss_svcsec_est_payload(struct ptlrpc_svcsec *svcsec,
1413                            struct ptlrpc_request *req,
1414                            int msgsize)
1415 {
1416         struct gss_svc_data *svcdata = req->rq_sec_svcdata;
1417         ENTRY;
1418
1419         /* just return the pre-set reserve_len for init/fini/err cases.
1420          */
1421         LASSERT(svcdata);
1422         if (svcdata->is_init) {
1423                 CDEBUG(D_SEC, "is_init, reserver size %d(%d)\n",
1424                        size_round(svcdata->reserve_len),
1425                        svcdata->reserve_len);
1426                 LASSERT(svcdata->reserve_len);
1427                 LASSERT(svcdata->reserve_len % 4 == 0);
1428                 RETURN(size_round(svcdata->reserve_len));
1429         } else if (svcdata->is_err_notify) {
1430                 CDEBUG(D_SEC, "is_err_notify, reserver size %d(%d)\n",
1431                        size_round(svcdata->reserve_len),
1432                        svcdata->reserve_len);
1433                 RETURN(size_round(svcdata->reserve_len));
1434         } else if (svcdata->is_fini) {
1435                 CDEBUG(D_SEC, "is_fini, reserver size 0\n");
1436                 RETURN(0);
1437         } else {
1438                 if (svcdata->clcred.gc_svc == PTLRPC_GSS_SVC_NONE ||
1439                     svcdata->clcred.gc_svc == PTLRPC_GSS_SVC_INTEGRITY)
1440                         RETURN(size_round(GSS_MAX_AUTH_PAYLOAD));
1441                 else if (svcdata->clcred.gc_svc == PTLRPC_GSS_SVC_PRIVACY)
1442                         RETURN(size_round16(GSS_MAX_AUTH_PAYLOAD + msgsize +
1443                                             GSS_PRIVBUF_PREFIX_LEN +
1444                                             GSS_PRIVBUF_SUFFIX_LEN));
1445                 else {
1446                         CERROR("unknown gss svc %u\n", svcdata->clcred.gc_svc);
1447                         *((int *)0) = 0;
1448                         LBUG();
1449                 }
1450         }
1451         RETURN(0);
1452 }
1453
1454 int gss_svcsec_alloc_repbuf(struct ptlrpc_svcsec *svcsec,
1455                             struct ptlrpc_request *req,
1456                             int msgsize)
1457 {
1458         struct gss_svc_data *gsd = (struct gss_svc_data *) req->rq_sec_svcdata;
1459         struct ptlrpc_reply_state *rs;
1460         int msg_payload, sec_payload;
1461         int privacy, rc;
1462         ENTRY;
1463
1464         /* determine the security type: none/auth or priv, we have
1465          * different pack scheme for them.
1466          * init/fini/err will always be treated as none/auth.
1467          */
1468         LASSERT(gsd);
1469         if (!gsd->is_init && !gsd->is_init_continue &&
1470             !gsd->is_fini && !gsd->is_err_notify &&
1471             gsd->clcred.gc_svc == PTLRPC_GSS_SVC_PRIVACY)
1472                 privacy = 1;
1473         else
1474                 privacy = 0;
1475
1476         msg_payload = privacy ? 0 : msgsize;
1477         sec_payload = gss_svcsec_est_payload(svcsec, req, msgsize);
1478
1479         rc = svcsec_alloc_reply_state(req, msg_payload, sec_payload);
1480         if (rc)
1481                 RETURN(rc);
1482
1483         rs = req->rq_reply_state;
1484         LASSERT(rs);
1485         rs->rs_msg_len = msgsize;
1486
1487         if (privacy) {
1488                 /* we can choose to let msg simply point to the rear of the
1489                  * buffer, which lead to buffer overlap when doing encryption.
1490                  * usually it's ok and it indeed passed all existing tests.
1491                  * but not sure if there will be subtle problems in the future.
1492                  * so right now we choose to alloc another new buffer. we'll
1493                  * see how it works.
1494                  */
1495 #if 0
1496                 rs->rs_msg = (struct lustre_msg *)
1497                              (rs->rs_repbuf + rs->rs_repbuf_len -
1498                               msgsize - GSS_PRIVBUF_SUFFIX_LEN);
1499 #endif
1500                 char *msgbuf;
1501
1502                 msgsize += GSS_PRIVBUF_PREFIX_LEN + GSS_PRIVBUF_SUFFIX_LEN;
1503                 OBD_ALLOC(msgbuf, msgsize);
1504                 if (!msgbuf) {
1505                         CERROR("can't alloc %d\n", msgsize);
1506                         svcsec_free_reply_state(rs);
1507                         req->rq_reply_state = NULL;
1508                         RETURN(-ENOMEM);
1509                 }
1510                 rs->rs_msg = (struct lustre_msg *)
1511                                 (msgbuf + GSS_PRIVBUF_PREFIX_LEN);
1512         }
1513
1514         req->rq_repmsg = rs->rs_msg;
1515
1516         RETURN(0);
1517 }
1518
1519 static
1520 void gss_svcsec_free_repbuf(struct ptlrpc_svcsec *svcsec,
1521                             struct ptlrpc_reply_state *rs)
1522 {
1523         unsigned long p1 = (unsigned long) rs->rs_msg;
1524         unsigned long p2 = (unsigned long) rs->rs_buf;
1525
1526         LASSERT(rs->rs_buf);
1527         LASSERT(rs->rs_msg);
1528         LASSERT(rs->rs_msg_len);
1529
1530         if (p1 < p2 || p1 >= p2 + rs->rs_buf_len) {
1531                 char *start = (char*) rs->rs_msg - GSS_PRIVBUF_PREFIX_LEN;
1532                 int size = rs->rs_msg_len + GSS_PRIVBUF_PREFIX_LEN +
1533                            GSS_PRIVBUF_SUFFIX_LEN;
1534                 OBD_FREE(start, size);
1535         }
1536
1537         svcsec_free_reply_state(rs);
1538 }
1539
/* Method table registering the GSS implementation with the generic
 * ptlrpc server-side security framework. */
struct ptlrpc_svcsec svcsec_gss = {
        .pss_owner              = THIS_MODULE,
        .pss_name               = "GSS_SVCSEC",
        .pss_flavor             = {PTLRPC_SEC_GSS, 0},
        .accept                 = gss_svcsec_accept,
        .authorize              = gss_svcsec_authorize,
        .alloc_repbuf           = gss_svcsec_alloc_repbuf,
        .free_repbuf            = gss_svcsec_free_repbuf,
        .cleanup_req            = gss_svcsec_cleanup_req,
};
1550
/* XXX hacking */
/* Drop every entry from both the init (rsi) and context (rsc) caches. */
void lgss_svc_cache_purge_all(void)
{
        cache_purge(&rsi_cache);
        cache_purge(&rsc_cache);
}
EXPORT_SYMBOL(lgss_svc_cache_purge_all);
1558
/* Flush cached server contexts belonging to the given uid. */
void lgss_svc_cache_flush(__u32 uid)
{
        rsc_flush(uid);
}
EXPORT_SYMBOL(lgss_svc_cache_flush);
1564
1565 int gss_svc_init(void)
1566 {
1567         int rc;
1568
1569         rc = svcsec_register(&svcsec_gss);
1570         if (!rc) {
1571                 cache_register(&rsc_cache);
1572                 cache_register(&rsi_cache);
1573         }
1574         return rc;
1575 }
1576
1577 void gss_svc_exit(void)
1578 {
1579         int rc;
1580         if ((rc = cache_unregister(&rsi_cache)))
1581                 CERROR("unregister rsi cache: %d\n", rc);
1582         if ((rc = cache_unregister(&rsc_cache)))
1583                 CERROR("unregister rsc cache: %d\n", rc);
1584         if ((rc = svcsec_unregister(&svcsec_gss)))
1585                 CERROR("unregister svcsec_gss: %d\n", rc);
1586 }