Whamcloud - gitweb
da84fad0aebe6cdf6b462caa5943e99e7810ceb0
[fs/lustre-release.git] / lustre / ptlrpc / gss / gss_svc_upcall.c
1 /*
2  * Modifications for Lustre
3  *
4  * Copyright (c) 2007, 2010, Oracle and/or its affiliates. All rights reserved.
5  *
6  * Copyright (c) 2012, 2014, Intel Corporation.
7  *
8  * Author: Eric Mei <ericm@clusterfs.com>
9  */
10
11 /*
12  * Neil Brown <neilb@cse.unsw.edu.au>
13  * J. Bruce Fields <bfields@umich.edu>
14  * Andy Adamson <andros@umich.edu>
15  * Dug Song <dugsong@monkey.org>
16  *
17  * RPCSEC_GSS server authentication.
18  * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078
19  * (gssapi)
20  *
21  * The RPCSEC_GSS involves three stages:
22  *  1/ context creation
23  *  2/ data exchange
24  *  3/ context destruction
25  *
26  * Context creation is handled largely by upcalls to user-space.
27  *  In particular, GSS_Accept_sec_context is handled by an upcall
28  * Data exchange is handled entirely within the kernel
29  *  In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel.
30  * Context destruction is handled in-kernel
31  *  GSS_Delete_sec_context is in-kernel
32  *
33  * Context creation is initiated by a RPCSEC_GSS_INIT request arriving.
34  * The context handle and gss_token are used as a key into the rpcsec_init cache.
35  * The content of this cache includes some of the outputs of GSS_Accept_sec_context,
36  * being major_status, minor_status, context_handle, reply_token.
37  * These are sent back to the client.
38  * Sequence window management is handled by the kernel.  The window size if currently
39  * a compile time constant.
40  *
41  * When user-space is happy that a context is established, it places an entry
42  * in the rpcsec_context cache. The key for this cache is the context_handle.
43  * The content includes:
44  *   uid/gidlist - for determining access rights
45  *   mechanism type
46  *   mechanism specific information, such as a key
47  *
48  */
49
50 #define DEBUG_SUBSYSTEM S_SEC
51 #include <linux/types.h>
52 #include <linux/init.h>
53 #include <linux/module.h>
54 #include <linux/random.h>
55 #include <linux/slab.h>
56 #include <linux/hash.h>
57 #include <linux/mutex.h>
58 #include <linux/sunrpc/cache.h>
59 #include <net/sock.h>
60
61 #include <obd.h>
62 #include <obd_class.h>
63 #include <obd_support.h>
64 #include <lustre_import.h>
65 #include <lustre_net.h>
66 #include <lustre_nodemap.h>
67 #include <lustre_sec.h>
68
69 #include "gss_err.h"
70 #include "gss_internal.h"
71 #include "gss_api.h"
72 #include "gss_crypto.h"
73
74 #define GSS_SVC_UPCALL_TIMEOUT  (20)
75
76 static DEFINE_SPINLOCK(__ctx_index_lock);
77 static __u64 __ctx_index;
78
79 unsigned int krb5_allow_old_client_csum;
80
81 __u64 gss_get_next_ctx_index(void)
82 {
83         __u64 idx;
84
85         spin_lock(&__ctx_index_lock);
86         idx = __ctx_index++;
87         spin_unlock(&__ctx_index_lock);
88
89         return idx;
90 }
91
92 static inline unsigned long hash_mem(char *buf, int length, int bits)
93 {
94         unsigned long hash = 0;
95         unsigned long l = 0;
96         int len = 0;
97         unsigned char c;
98
99         do {
100                 if (len == length) {
101                         c = (char) len;
102                         len = -1;
103                 } else
104                         c = *buf++;
105
106                 l = (l << 8) | c;
107                 len++;
108
109                 if ((len & (BITS_PER_LONG/8-1)) == 0)
110                         hash = hash_long(hash^l, BITS_PER_LONG);
111         } while (len);
112
113         return hash >> (BITS_PER_LONG - bits);
114 }
115
116 /****************************************
117  * rpc sec init (rsi) cache *
118  ****************************************/
119
120 #define RSI_HASHBITS    (6)
121 #define RSI_HASHMAX     (1 << RSI_HASHBITS)
122 #define RSI_HASHMASK    (RSI_HASHMAX - 1)
123
124 struct rsi {
125         struct cache_head       h;
126         __u32                   lustre_svc;
127         lnet_nid_t              nid4; /* FIXME Support larger NID */
128         char                    nm_name[LUSTRE_NODEMAP_NAME_LENGTH + 1];
129         wait_queue_head_t       waitq;
130         rawobj_t                in_handle, in_token;
131         rawobj_t                out_handle, out_token;
132         int                     major_status, minor_status;
133 #ifdef HAVE_CACHE_HASH_SPINLOCK
134         struct rcu_head         rcu_head;
135 #endif
136 };
137
138 #ifdef HAVE_CACHE_HEAD_HLIST
139 static struct hlist_head rsi_table[RSI_HASHMAX];
140 #else
141 static struct cache_head *rsi_table[RSI_HASHMAX];
142 #endif
143 static struct cache_detail rsi_cache;
144 static struct rsi *rsi_update(struct rsi *new, struct rsi *old);
145 static struct rsi *rsi_lookup(struct rsi *item);
146
147 #ifdef HAVE_CACHE_DETAIL_WRITERS
148 static inline int channel_users(struct cache_detail *cd)
149 {
150         return atomic_read(&cd->writers);
151 }
152 #else
153 static inline int channel_users(struct cache_detail *cd)
154 {
155         return atomic_read(&cd->readers);
156 }
157 #endif
158
159 static inline int rsi_hash(struct rsi *item)
160 {
161         return hash_mem((char *)item->in_handle.data, item->in_handle.len,
162                         RSI_HASHBITS) ^
163                hash_mem((char *)item->in_token.data, item->in_token.len,
164                         RSI_HASHBITS);
165 }
166
167 static inline int __rsi_match(struct rsi *item, struct rsi *tmp)
168 {
169         return (rawobj_equal(&item->in_handle, &tmp->in_handle) &&
170                 rawobj_equal(&item->in_token, &tmp->in_token));
171 }
172
173 static void rsi_free(struct rsi *rsi)
174 {
175         rawobj_free(&rsi->in_handle);
176         rawobj_free(&rsi->in_token);
177         rawobj_free(&rsi->out_handle);
178         rawobj_free(&rsi->out_token);
179 }
180
181 /* See handle_channel_req() userspace for where the upcall data is read */
182 static void rsi_request(struct cache_detail *cd,
183                         struct cache_head *h,
184                         char **bpp, int *blen)
185 {
186         struct rsi *rsi = container_of(h, struct rsi, h);
187         __u64 index = 0;
188
189         /* if in_handle is null, provide kernel suggestion */
190         if (rsi->in_handle.len == 0)
191                 index = gss_get_next_ctx_index();
192
193         qword_addhex(bpp, blen, (char *) &rsi->lustre_svc,
194                         sizeof(rsi->lustre_svc));
195         qword_addhex(bpp, blen, (char *) &rsi->nid4, sizeof(rsi->nid4));
196         qword_addhex(bpp, blen, (char *) &index, sizeof(index));
197         qword_addhex(bpp, blen, (char *) rsi->nm_name,
198                      strlen(rsi->nm_name) + 1);
199         qword_addhex(bpp, blen, rsi->in_handle.data, rsi->in_handle.len);
200         qword_addhex(bpp, blen, rsi->in_token.data, rsi->in_token.len);
201         (*bpp)[-1] = '\n';
202 }
203
204 static inline void __rsi_init(struct rsi *new, struct rsi *item)
205 {
206         new->out_handle = RAWOBJ_EMPTY;
207         new->out_token = RAWOBJ_EMPTY;
208
209         new->in_handle = item->in_handle;
210         item->in_handle = RAWOBJ_EMPTY;
211         new->in_token = item->in_token;
212         item->in_token = RAWOBJ_EMPTY;
213
214         new->lustre_svc = item->lustre_svc;
215         new->nid4 = item->nid4;
216         memcpy(new->nm_name, item->nm_name, sizeof(item->nm_name));
217         init_waitqueue_head(&new->waitq);
218 }
219
220 static inline void __rsi_update(struct rsi *new, struct rsi *item)
221 {
222         LASSERT(new->out_handle.len == 0);
223         LASSERT(new->out_token.len == 0);
224
225         new->out_handle = item->out_handle;
226         item->out_handle = RAWOBJ_EMPTY;
227         new->out_token = item->out_token;
228         item->out_token = RAWOBJ_EMPTY;
229
230         new->major_status = item->major_status;
231         new->minor_status = item->minor_status;
232 }
233
234 #ifdef HAVE_CACHE_HASH_SPINLOCK
235 static void rsi_free_rcu(struct rcu_head *head)
236 {
237         struct rsi *rsi = container_of(head, struct rsi, rcu_head);
238
239 #ifdef HAVE_CACHE_HEAD_HLIST
240         LASSERT(hlist_unhashed(&rsi->h.cache_list));
241 #else
242         LASSERT(rsi->h.next == NULL);
243 #endif
244         rsi_free(rsi);
245         OBD_FREE_PTR(rsi);
246 }
247
248 static void rsi_put(struct kref *ref)
249 {
250         struct rsi *rsi = container_of(ref, struct rsi, h.ref);
251
252         call_rcu(&rsi->rcu_head, rsi_free_rcu);
253 }
254 #else /* !HAVE_CACHE_HASH_SPINLOCK */
255 static void rsi_put(struct kref *ref)
256 {
257         struct rsi *rsi = container_of(ref, struct rsi, h.ref);
258
259 #ifdef HAVE_CACHE_HEAD_HLIST
260         LASSERT(hlist_unhashed(&rsi->h.cache_list));
261 #else
262         LASSERT(rsi->h.next == NULL);
263 #endif
264         rsi_free(rsi);
265         OBD_FREE_PTR(rsi);
266 }
267 #endif /* HAVE_CACHE_HASH_SPINLOCK */
268
269 static int rsi_match(struct cache_head *a, struct cache_head *b)
270 {
271         struct rsi *item = container_of(a, struct rsi, h);
272         struct rsi *tmp = container_of(b, struct rsi, h);
273
274         return __rsi_match(item, tmp);
275 }
276
277 static void rsi_init(struct cache_head *cnew, struct cache_head *citem)
278 {
279         struct rsi *new = container_of(cnew, struct rsi, h);
280         struct rsi *item = container_of(citem, struct rsi, h);
281
282         __rsi_init(new, item);
283 }
284
285 static void update_rsi(struct cache_head *cnew, struct cache_head *citem)
286 {
287         struct rsi *new = container_of(cnew, struct rsi, h);
288         struct rsi *item = container_of(citem, struct rsi, h);
289
290         __rsi_update(new, item);
291 }
292
293 static struct cache_head *rsi_alloc(void)
294 {
295         struct rsi *rsi;
296
297         OBD_ALLOC_PTR(rsi);
298         if (rsi) 
299                 return &rsi->h;
300         else
301                 return NULL;
302 }
303
304 static int rsi_parse(struct cache_detail *cd, char *mesg, int mlen)
305 {
306         char *buf = mesg;
307         int len;
308         struct rsi rsii, *rsip = NULL;
309         time64_t expiry;
310         int status = -EINVAL;
311         ENTRY;
312
313         memset(&rsii, 0, sizeof(rsii));
314
315         /* handle */
316         len = qword_get(&mesg, buf, mlen);
317         if (len < 0)
318                 goto out;
319         if (rawobj_alloc(&rsii.in_handle, buf, len)) {
320                 status = -ENOMEM;
321                 goto out;
322         }
323
324         /* token */
325         len = qword_get(&mesg, buf, mlen);
326         if (len < 0)
327                 goto out;
328         if (rawobj_alloc(&rsii.in_token, buf, len)) {
329                 status = -ENOMEM;
330                 goto out;
331         }
332
333         rsip = rsi_lookup(&rsii);
334         if (!rsip)
335                 goto out;
336         if (!test_bit(CACHE_PENDING, &rsip->h.flags)) {
337                 /* If this is not a pending request, it probably means
338                  * someone wrote arbitrary data to the init channel.
339                  * Directly return -EINVAL in this case.
340                  */
341                 status = -EINVAL;
342                 goto out;
343         }
344
345         rsii.h.flags = 0;
346         /* expiry */
347         expiry = get_expiry(&mesg);
348         if (expiry == 0)
349                 goto out;
350
351         len = qword_get(&mesg, buf, mlen);
352         if (len <= 0)
353                 goto out;
354
355         /* major */
356         status = kstrtoint(buf, 10, &rsii.major_status);
357         if (status)
358                 goto out;
359
360         /* minor */
361         len = qword_get(&mesg, buf, mlen);
362         if (len <= 0) {
363                 status = -EINVAL;
364                 goto out;
365         }
366
367         status = kstrtoint(buf, 10, &rsii.minor_status);
368         if (status)
369                 goto out;
370
371         /* out_handle */
372         len = qword_get(&mesg, buf, mlen);
373         if (len < 0)
374                 goto out;
375         if (rawobj_alloc(&rsii.out_handle, buf, len)) {
376                 status = -ENOMEM;
377                 goto out;
378         }
379
380         /* out_token */
381         len = qword_get(&mesg, buf, mlen);
382         if (len < 0)
383                 goto out;
384         if (rawobj_alloc(&rsii.out_token, buf, len)) {
385                 status = -ENOMEM;
386                 goto out;
387         }
388
389         rsii.h.expiry_time = expiry;
390         rsip = rsi_update(&rsii, rsip);
391         status = 0;
392 out:
393         rsi_free(&rsii);
394         if (rsip) {
395                 wake_up(&rsip->waitq);
396                 cache_put(&rsip->h, &rsi_cache);
397         } else {
398                 status = -ENOMEM;
399         }
400
401         if (status)
402                 CERROR("rsi parse error %d\n", status);
403         RETURN(status);
404 }
405
406 static struct cache_detail rsi_cache = {
407         .hash_size      = RSI_HASHMAX,
408         .hash_table     = rsi_table,
409         .name           = "auth.sptlrpc.init",
410         .cache_put      = rsi_put,
411         .cache_request  = rsi_request,
412         .cache_upcall   = sunrpc_cache_pipe_upcall,
413         .cache_parse    = rsi_parse,
414         .match          = rsi_match,
415         .init           = rsi_init,
416         .update         = update_rsi,
417         .alloc          = rsi_alloc,
418 };
419
420 static struct rsi *rsi_lookup(struct rsi *item)
421 {
422         struct cache_head *ch;
423         int hash = rsi_hash(item);
424
425         ch = sunrpc_cache_lookup(&rsi_cache, &item->h, hash);
426         if (ch)
427                 return container_of(ch, struct rsi, h);
428         else
429                 return NULL;
430 }
431
432 static struct rsi *rsi_update(struct rsi *new, struct rsi *old)
433 {
434         struct cache_head *ch;
435         int hash = rsi_hash(new);
436
437         ch = sunrpc_cache_update(&rsi_cache, &new->h, &old->h, hash);
438         if (ch)
439                 return container_of(ch, struct rsi, h);
440         else
441                 return NULL;
442 }
443
444 /****************************************
445  * rpc sec context (rsc) cache                            *
446  ****************************************/
447
448 #define RSC_HASHBITS    (10)
449 #define RSC_HASHMAX     (1 << RSC_HASHBITS)
450 #define RSC_HASHMASK    (RSC_HASHMAX - 1)
451
452 struct rsc {
453         struct cache_head       h;
454         struct obd_device      *target;
455         rawobj_t                handle;
456         struct gss_svc_ctx      ctx;
457 #ifdef HAVE_CACHE_HASH_SPINLOCK
458         struct rcu_head         rcu_head;
459 #endif
460 };
461
462 #ifdef HAVE_CACHE_HEAD_HLIST
463 static struct hlist_head rsc_table[RSC_HASHMAX];
464 #else
465 static struct cache_head *rsc_table[RSC_HASHMAX];
466 #endif
467 static struct cache_detail rsc_cache;
468 static struct rsc *rsc_update(struct rsc *new, struct rsc *old);
469 static struct rsc *rsc_lookup(struct rsc *item);
470
471 static void rsc_free(struct rsc *rsci)
472 {
473         rawobj_free(&rsci->handle);
474         rawobj_free(&rsci->ctx.gsc_rvs_hdl);
475         lgss_delete_sec_context(&rsci->ctx.gsc_mechctx);
476 }
477
478 static inline int rsc_hash(struct rsc *rsci)
479 {
480         return hash_mem((char *)rsci->handle.data,
481                         rsci->handle.len, RSC_HASHBITS);
482 }
483
484 static inline int __rsc_match(struct rsc *new, struct rsc *tmp)
485 {
486         return rawobj_equal(&new->handle, &tmp->handle);
487 }
488
489 static inline void __rsc_init(struct rsc *new, struct rsc *tmp)
490 {
491         new->handle = tmp->handle;
492         tmp->handle = RAWOBJ_EMPTY;
493
494         new->target = NULL;
495         memset(&new->ctx, 0, sizeof(new->ctx));
496         new->ctx.gsc_rvs_hdl = RAWOBJ_EMPTY;
497 }
498
499 static inline void __rsc_update(struct rsc *new, struct rsc *tmp)
500 {
501         new->ctx = tmp->ctx;
502         memset(&tmp->ctx, 0, sizeof(tmp->ctx));
503         tmp->ctx.gsc_rvs_hdl = RAWOBJ_EMPTY;
504         tmp->ctx.gsc_mechctx = NULL;
505         tmp->target = NULL;
506
507         memset(&new->ctx.gsc_seqdata, 0, sizeof(new->ctx.gsc_seqdata));
508         spin_lock_init(&new->ctx.gsc_seqdata.ssd_lock);
509 }
510
511 #ifdef HAVE_CACHE_HASH_SPINLOCK
512 static void rsc_free_rcu(struct rcu_head *head)
513 {
514         struct rsc *rsci = container_of(head, struct rsc, rcu_head);
515
516 #ifdef HAVE_CACHE_HEAD_HLIST
517         LASSERT(hlist_unhashed(&rsci->h.cache_list));
518 #else
519         LASSERT(rsci->h.next == NULL);
520 #endif
521         rawobj_free(&rsci->handle);
522         OBD_FREE_PTR(rsci);
523 }
524
525 static void rsc_put(struct kref *ref)
526 {
527         struct rsc *rsci = container_of(ref, struct rsc, h.ref);
528
529         rawobj_free(&rsci->ctx.gsc_rvs_hdl);
530         lgss_delete_sec_context(&rsci->ctx.gsc_mechctx);
531         call_rcu(&rsci->rcu_head, rsc_free_rcu);
532 }
533 #else /* !HAVE_CACHE_HASH_SPINLOCK */
534 static void rsc_put(struct kref *ref)
535 {
536         struct rsc *rsci = container_of(ref, struct rsc, h.ref);
537
538 #ifdef HAVE_CACHE_HEAD_HLIST
539         LASSERT(hlist_unhashed(&rsci->h.cache_list));
540 #else
541         LASSERT(rsci->h.next == NULL);
542 #endif
543         rsc_free(rsci);
544         OBD_FREE_PTR(rsci);
545 }
546 #endif /* HAVE_CACHE_HASH_SPINLOCK */
547
548 static int rsc_match(struct cache_head *a, struct cache_head *b)
549 {
550         struct rsc *new = container_of(a, struct rsc, h);
551         struct rsc *tmp = container_of(b, struct rsc, h);
552
553         return __rsc_match(new, tmp);
554 }
555
556 static void rsc_init(struct cache_head *cnew, struct cache_head *ctmp)
557 {
558         struct rsc *new = container_of(cnew, struct rsc, h);
559         struct rsc *tmp = container_of(ctmp, struct rsc, h);
560
561         __rsc_init(new, tmp);
562 }
563
564 static void update_rsc(struct cache_head *cnew, struct cache_head *ctmp)
565 {
566         struct rsc *new = container_of(cnew, struct rsc, h);
567         struct rsc *tmp = container_of(ctmp, struct rsc, h);
568
569         __rsc_update(new, tmp);
570 }
571
572 static struct cache_head * rsc_alloc(void)
573 {
574         struct rsc *rsc;
575
576         OBD_ALLOC_PTR(rsc);
577         if (rsc)
578                 return &rsc->h;
579         else
580                 return NULL;
581 }
582
583 static int rsc_parse(struct cache_detail *cd, char *mesg, int mlen)
584 {
585         char                *buf = mesg;
586         int                  len, rv, tmp_int;
587         struct rsc           rsci, *rscp = NULL;
588         time64_t expiry;
589         int                  status = -EINVAL;
590         struct gss_api_mech *gm = NULL;
591
592         memset(&rsci, 0, sizeof(rsci));
593
594         /* context handle */
595         len = qword_get(&mesg, buf, mlen);
596         if (len < 0) goto out;
597         status = -ENOMEM;
598         if (rawobj_alloc(&rsci.handle, buf, len))
599                 goto out;
600
601         rsci.h.flags = 0;
602         /* expiry */
603         expiry = get_expiry(&mesg);
604         status = -EINVAL;
605         if (expiry == 0)
606                 goto out;
607
608         /* remote flag */
609         rv = get_int(&mesg, &tmp_int);
610         if (rv) {
611                 CERROR("fail to get remote flag\n");
612                 goto out;
613         }
614         rsci.ctx.gsc_remote = (tmp_int != 0);
615
616         /* root user flag */
617         rv = get_int(&mesg, &tmp_int);
618         if (rv) {
619                 CERROR("fail to get root user flag\n");
620                 goto out;
621         }
622         rsci.ctx.gsc_usr_root = (tmp_int != 0);
623
624         /* mds user flag */
625         rv = get_int(&mesg, &tmp_int);
626         if (rv) {
627                 CERROR("fail to get mds user flag\n");
628                 goto out;
629         }
630         rsci.ctx.gsc_usr_mds = (tmp_int != 0);
631
632         /* oss user flag */
633         rv = get_int(&mesg, &tmp_int);
634         if (rv) {
635                 CERROR("fail to get oss user flag\n");
636                 goto out;
637         }
638         rsci.ctx.gsc_usr_oss = (tmp_int != 0);
639
640         /* mapped uid */
641         rv = get_int(&mesg, (int *) &rsci.ctx.gsc_mapped_uid);
642         if (rv) {
643                 CERROR("fail to get mapped uid\n");
644                 goto out;
645         }
646
647         rscp = rsc_lookup(&rsci);
648         if (!rscp)
649                 goto out;
650
651         /* uid, or NEGATIVE */
652         rv = get_int(&mesg, (int *) &rsci.ctx.gsc_uid);
653         if (rv == -EINVAL)
654                 goto out;
655         if (rv == -ENOENT) {
656                 CERROR("NOENT? set rsc entry negative\n");
657                 set_bit(CACHE_NEGATIVE, &rsci.h.flags);
658         } else {
659                 rawobj_t tmp_buf;
660                 time64_t ctx_expiry;
661
662                 /* gid */
663                 if (get_int(&mesg, (int *) &rsci.ctx.gsc_gid))
664                         goto out;
665
666                 /* mech name */
667                 len = qword_get(&mesg, buf, mlen);
668                 if (len < 0)
669                         goto out;
670                 gm = lgss_name_to_mech(buf);
671                 status = -EOPNOTSUPP;
672                 if (!gm)
673                         goto out;
674
675                 status = -EINVAL;
676                 /* mech-specific data: */
677                 len = qword_get(&mesg, buf, mlen);
678                 if (len < 0)
679                         goto out;
680
681                 tmp_buf.len = len;
682                 tmp_buf.data = (unsigned char *)buf;
683                 if (lgss_import_sec_context(&tmp_buf, gm,
684                                             &rsci.ctx.gsc_mechctx))
685                         goto out;
686
687                 /* set to seconds since machine booted */
688                 expiry = ktime_get_seconds();
689
690                 /* currently the expiry time passed down from user-space
691                  * is invalid, here we retrive it from mech.
692                  */
693                 if (lgss_inquire_context(rsci.ctx.gsc_mechctx, &ctx_expiry)) {
694                         CERROR("unable to get expire time, drop it\n");
695                         goto out;
696                 }
697
698                 /* ctx_expiry is the number of seconds since Jan 1 1970.
699                  * We want just the  number of seconds into the future.
700                  */
701                 expiry += ctx_expiry - ktime_get_real_seconds();
702         }
703
704         rsci.h.expiry_time = expiry;
705         rscp = rsc_update(&rsci, rscp);
706         status = 0;
707 out:
708         if (gm)
709                 lgss_mech_put(gm);
710         rsc_free(&rsci);
711         if (rscp)
712                 cache_put(&rscp->h, &rsc_cache);
713         else
714                 status = -ENOMEM;
715
716         if (status)
717                 CERROR("parse rsc error %d\n", status);
718         return status;
719 }
720
721 static struct cache_detail rsc_cache = {
722         .hash_size      = RSC_HASHMAX,
723         .hash_table     = rsc_table,
724         .name           = "auth.sptlrpc.context",
725         .cache_put      = rsc_put,
726         .cache_parse    = rsc_parse,
727         .match          = rsc_match,
728         .init           = rsc_init,
729         .update         = update_rsc,
730         .alloc          = rsc_alloc,
731 };
732
733 static struct rsc *rsc_lookup(struct rsc *item)
734 {
735         struct cache_head *ch;
736         int                hash = rsc_hash(item);
737
738         ch = sunrpc_cache_lookup(&rsc_cache, &item->h, hash);
739         if (ch)
740                 return container_of(ch, struct rsc, h);
741         else
742                 return NULL;
743 }
744
745 static struct rsc *rsc_update(struct rsc *new, struct rsc *old)
746 {
747         struct cache_head *ch;
748         int                hash = rsc_hash(new);
749
750         ch = sunrpc_cache_update(&rsc_cache, &new->h, &old->h, hash);
751         if (ch)
752                 return container_of(ch, struct rsc, h);
753         else
754                 return NULL;
755 }
756
757 #define COMPAT_RSC_PUT(item, cd)        cache_put((item), (cd))
758
759 /****************************************
760  * rsc cache flush                      *
761  ****************************************/
762
763 static struct rsc *gss_svc_searchbyctx(rawobj_t *handle)
764 {
765         struct rsc  rsci;
766         struct rsc *found;
767
768         memset(&rsci, 0, sizeof(rsci));
769         if (rawobj_dup(&rsci.handle, handle))
770                 return NULL;
771
772         found = rsc_lookup(&rsci);
773         rsc_free(&rsci);
774         if (!found)
775                 return NULL;
776         if (cache_check(&rsc_cache, &found->h, NULL))
777                 return NULL;
778         return found;
779 }
780
781 int gss_svc_upcall_install_rvs_ctx(struct obd_import *imp,
782                                    struct gss_sec *gsec,
783                                    struct gss_cli_ctx *gctx)
784 {
785         struct rsc      rsci, *rscp = NULL;
786         time64_t ctx_expiry;
787         __u32           major;
788         int             rc;
789         ENTRY;
790
791         memset(&rsci, 0, sizeof(rsci));
792
793         if (rawobj_alloc(&rsci.handle, (char *) &gsec->gs_rvs_hdl,
794                          sizeof(gsec->gs_rvs_hdl)))
795                 GOTO(out, rc = -ENOMEM);
796
797         rscp = rsc_lookup(&rsci);
798         if (rscp == NULL)
799                 GOTO(out, rc = -ENOMEM);
800
801         major = lgss_copy_reverse_context(gctx->gc_mechctx,
802                                           &rsci.ctx.gsc_mechctx);
803         if (major != GSS_S_COMPLETE)
804                 GOTO(out, rc = -ENOMEM);
805
806         if (lgss_inquire_context(rsci.ctx.gsc_mechctx, &ctx_expiry)) {
807                 CERROR("unable to get expire time, drop it\n");
808                 GOTO(out, rc = -EINVAL);
809         }
810         rsci.h.expiry_time = ctx_expiry;
811
812         switch (imp->imp_obd->u.cli.cl_sp_to) {
813         case LUSTRE_SP_MDT:
814                 rsci.ctx.gsc_usr_mds = 1;
815                 break;
816         case LUSTRE_SP_OST:
817                 rsci.ctx.gsc_usr_oss = 1;
818                 break;
819         case LUSTRE_SP_CLI:
820                 rsci.ctx.gsc_usr_root = 1;
821                 break;
822         case LUSTRE_SP_MGS:
823                 /* by convention, all 3 set to 1 means MGS */
824                 rsci.ctx.gsc_usr_mds = 1;
825                 rsci.ctx.gsc_usr_oss = 1;
826                 rsci.ctx.gsc_usr_root = 1;
827                 break;
828         default:
829                 break;
830         }
831
832         rscp = rsc_update(&rsci, rscp);
833         if (rscp == NULL)
834                 GOTO(out, rc = -ENOMEM);
835
836         rscp->target = imp->imp_obd;
837         rawobj_dup(&gctx->gc_svc_handle, &rscp->handle);
838
839         CWARN("create reverse svc ctx %p to %s: idx %#llx\n",
840               &rscp->ctx, obd2cli_tgt(imp->imp_obd), gsec->gs_rvs_hdl);
841         rc = 0;
842 out:
843         if (rscp)
844                 cache_put(&rscp->h, &rsc_cache);
845         rsc_free(&rsci);
846
847         if (rc)
848                 CERROR("create reverse svc ctx: idx %#llx, rc %d\n",
849                        gsec->gs_rvs_hdl, rc);
850         RETURN(rc);
851 }
852
853 int gss_svc_upcall_expire_rvs_ctx(rawobj_t *handle)
854 {
855         const time64_t expire = 20;
856         struct rsc *rscp;
857
858         rscp = gss_svc_searchbyctx(handle);
859         if (rscp) {
860                 CDEBUG(D_SEC, "reverse svcctx %p (rsc %p) expire soon\n",
861                        &rscp->ctx, rscp);
862
863                 rscp->h.expiry_time = ktime_get_real_seconds() + expire;
864                 COMPAT_RSC_PUT(&rscp->h, &rsc_cache);
865         }
866         return 0;
867 }
868
869 int gss_svc_upcall_dup_handle(rawobj_t *handle, struct gss_svc_ctx *ctx)
870 {
871         struct rsc *rscp = container_of(ctx, struct rsc, ctx);
872
873         return rawobj_dup(handle, &rscp->handle);
874 }
875
876 int gss_svc_upcall_update_sequence(rawobj_t *handle, __u32 seq)
877 {
878         struct rsc             *rscp;
879
880         rscp = gss_svc_searchbyctx(handle);
881         if (rscp) {
882                 CDEBUG(D_SEC, "reverse svcctx %p (rsc %p) update seq to %u\n",
883                        &rscp->ctx, rscp, seq + 1);
884
885                 rscp->ctx.gsc_rvs_seq = seq + 1;
886                 COMPAT_RSC_PUT(&rscp->h, &rsc_cache);
887         }
888         return 0;
889 }
890
891 static struct cache_deferred_req* cache_upcall_defer(struct cache_req *req)
892 {
893         return NULL;
894 }
895 static struct cache_req cache_upcall_chandle = { cache_upcall_defer };
896
897 int gss_svc_upcall_handle_init(struct ptlrpc_request *req,
898                                struct gss_svc_reqctx *grctx,
899                                struct gss_wire_ctx *gw,
900                                struct obd_device *target,
901                                __u32 lustre_svc,
902                                rawobj_t *rvs_hdl,
903                                rawobj_t *in_token)
904 {
905         struct ptlrpc_reply_state *rs;
906         struct rsc                *rsci = NULL;
907         struct rsi                *rsip = NULL, rsikey;
908         wait_queue_entry_t wait;
909         int                        replen = sizeof(struct ptlrpc_body);
910         struct gss_rep_header     *rephdr;
911         int                        first_check = 1;
912         int                        rc = SECSVC_DROP;
913         struct lnet_nid primary;
914         ENTRY;
915
916         memset(&rsikey, 0, sizeof(rsikey));
917         rsikey.lustre_svc = lustre_svc;
918         /* In case of MR, rq_peer is not the NID from which request is received,
919          * but primary NID of peer.
920          * So we need LNetPrimaryNID(rq_source) to match what the clients uses.
921          */
922         lnet_nid4_to_nid(req->rq_source.nid, &primary);
923         LNetPrimaryNID(&primary);
924         rsikey.nid4 = lnet_nid_to_nid4(&primary);
925         nodemap_test_nid(lnet_nid_to_nid4(&req->rq_peer.nid), rsikey.nm_name,
926                          sizeof(rsikey.nm_name));
927
928         /* duplicate context handle. for INIT it always 0 */
929         if (rawobj_dup(&rsikey.in_handle, &gw->gw_handle)) {
930                 CERROR("fail to dup context handle\n");
931                 GOTO(out, rc);
932         }
933
934         if (rawobj_dup(&rsikey.in_token, in_token)) {
935                 CERROR("can't duplicate token\n");
936                 rawobj_free(&rsikey.in_handle);
937                 GOTO(out, rc);
938         }
939
940         rsip = rsi_lookup(&rsikey);
941         rsi_free(&rsikey);
942         if (!rsip) {
943                 CERROR("error in rsi_lookup.\n");
944
945                 if (!gss_pack_err_notify(req, GSS_S_FAILURE, 0))
946                         rc = SECSVC_COMPLETE;
947
948                 GOTO(out, rc);
949         }
950
951         cache_get(&rsip->h); /* take an extra ref */
952         init_wait(&wait);
953         add_wait_queue(&rsip->waitq, &wait);
954
955 cache_check:
956         /* Note each time cache_check() will drop a reference if return
957          * non-zero. We hold an extra reference on initial rsip, but must
958          * take care of following calls. */
959         rc = cache_check(&rsi_cache, &rsip->h, &cache_upcall_chandle);
960         switch (rc) {
961         case -ETIMEDOUT:
962         case -EAGAIN: {
963                 int valid;
964
965                 if (first_check) {
966                         first_check = 0;
967
968                         cache_read_lock(&rsi_cache);
969                         valid = test_bit(CACHE_VALID, &rsip->h.flags);
970                         if (valid == 0)
971                                 set_current_state(TASK_INTERRUPTIBLE);
972                         cache_read_unlock(&rsi_cache);
973
974                         if (valid == 0) {
975                                 unsigned long timeout;
976
977                                 timeout = cfs_time_seconds(GSS_SVC_UPCALL_TIMEOUT);
978                                 schedule_timeout(timeout);
979                         }
980                         cache_get(&rsip->h);
981                         goto cache_check;
982                 }
983                 CWARN("waited %ds timeout, drop\n", GSS_SVC_UPCALL_TIMEOUT);
984                 break;
985         }
986         case -ENOENT:
987                 CDEBUG(D_SEC, "cache_check return ENOENT, drop\n");
988                 break;
989         case 0:
990                 /* if not the first check, we have to release the extra
991                  * reference we just added on it. */
992                 if (!first_check)
993                         cache_put(&rsip->h, &rsi_cache);
994                 CDEBUG(D_SEC, "cache_check is good\n");
995                 break;
996         }
997
998         remove_wait_queue(&rsip->waitq, &wait);
999         cache_put(&rsip->h, &rsi_cache);
1000
1001         if (rc)
1002                 GOTO(out, rc = SECSVC_DROP);
1003
1004         rc = SECSVC_DROP;
1005         rsci = gss_svc_searchbyctx(&rsip->out_handle);
1006         if (!rsci) {
1007                 CERROR("authentication failed\n");
1008
1009                 /* gss mechanism returned major and minor code so we return
1010                  * those in error message */
1011                 if (!gss_pack_err_notify(req, rsip->major_status,
1012                                          rsip->minor_status))
1013                         rc = SECSVC_COMPLETE;
1014
1015                 GOTO(out, rc);
1016         } else {
1017                 cache_get(&rsci->h);
1018                 grctx->src_ctx = &rsci->ctx;
1019         }
1020
1021         if (gw->gw_flags & LUSTRE_GSS_PACK_KCSUM) {
1022                 grctx->src_ctx->gsc_mechctx->hash_func = gss_digest_hash;
1023         } else if (!strcmp(grctx->src_ctx->gsc_mechctx->mech_type->gm_name,
1024                            "krb5") &&
1025                    !krb5_allow_old_client_csum) {
1026                 CWARN("%s: deny connection from '%s' due to missing 'krb_csum' feature, set 'sptlrpc.gss.krb5_allow_old_client_csum=1' to allow, but recommend client upgrade: rc = %d\n",
1027                       target->obd_name, libcfs_nidstr(&req->rq_peer.nid),
1028                       -EPROTO);
1029                 GOTO(out, rc = SECSVC_DROP);
1030         } else {
1031                 grctx->src_ctx->gsc_mechctx->hash_func =
1032                         gss_digest_hash_compat;
1033         }
1034
1035         if (rawobj_dup(&rsci->ctx.gsc_rvs_hdl, rvs_hdl)) {
1036                 CERROR("failed duplicate reverse handle\n");
1037                 GOTO(out, rc);
1038         }
1039
1040         rsci->target = target;
1041
1042         CDEBUG(D_SEC, "server create rsc %p(%u->%s)\n",
1043                rsci, rsci->ctx.gsc_uid, libcfs_nidstr(&req->rq_peer.nid));
1044
1045         if (rsip->out_handle.len > PTLRPC_GSS_MAX_HANDLE_SIZE) {
1046                 CERROR("handle size %u too large\n", rsip->out_handle.len);
1047                 GOTO(out, rc = SECSVC_DROP);
1048         }
1049
1050         grctx->src_init = 1;
1051         grctx->src_reserve_len = round_up(rsip->out_token.len, 4);
1052
1053         rc = lustre_pack_reply_v2(req, 1, &replen, NULL, 0);
1054         if (rc) {
1055                 CERROR("failed to pack reply: %d\n", rc);
1056                 GOTO(out, rc = SECSVC_DROP);
1057         }
1058
1059         rs = req->rq_reply_state;
1060         LASSERT(rs->rs_repbuf->lm_bufcount == 3);
1061         LASSERT(rs->rs_repbuf->lm_buflens[0] >=
1062                 sizeof(*rephdr) + rsip->out_handle.len);
1063         LASSERT(rs->rs_repbuf->lm_buflens[2] >= rsip->out_token.len);
1064
1065         rephdr = lustre_msg_buf(rs->rs_repbuf, 0, 0);
1066         rephdr->gh_version = PTLRPC_GSS_VERSION;
1067         rephdr->gh_flags = 0;
1068         rephdr->gh_proc = PTLRPC_GSS_PROC_ERR;
1069         rephdr->gh_major = rsip->major_status;
1070         rephdr->gh_minor = rsip->minor_status;
1071         rephdr->gh_seqwin = GSS_SEQ_WIN;
1072         rephdr->gh_handle.len = rsip->out_handle.len;
1073         memcpy(rephdr->gh_handle.data, rsip->out_handle.data,
1074                rsip->out_handle.len);
1075
1076         memcpy(lustre_msg_buf(rs->rs_repbuf, 2, 0), rsip->out_token.data,
1077                rsip->out_token.len);
1078
1079         rs->rs_repdata_len = lustre_shrink_msg(rs->rs_repbuf, 2,
1080                                                rsip->out_token.len, 0);
1081
1082         rc = SECSVC_OK;
1083
1084 out:
1085         /* it looks like here we should put rsip also, but this mess up
1086          * with NFS cache mgmt code... FIXME
1087          * something like:
1088          * if (rsip)
1089          *     rsi_put(&rsip->h, &rsi_cache); */
1090
1091         if (rsci) {
1092                 /* if anything went wrong, we don't keep the context too */
1093                 if (rc != SECSVC_OK)
1094                         set_bit(CACHE_NEGATIVE, &rsci->h.flags);
1095                 else
1096                         CDEBUG(D_SEC, "create rsc with idx %#llx\n",
1097                                gss_handle_to_u64(&rsci->handle));
1098
1099                 COMPAT_RSC_PUT(&rsci->h, &rsc_cache);
1100         }
1101         RETURN(rc);
1102 }
1103
1104 struct gss_svc_ctx *gss_svc_upcall_get_ctx(struct ptlrpc_request *req,
1105                                            struct gss_wire_ctx *gw)
1106 {
1107         struct rsc *rsc;
1108
1109         rsc = gss_svc_searchbyctx(&gw->gw_handle);
1110         if (!rsc) {
1111                 CWARN("Invalid gss ctx idx %#llx from %s\n",
1112                       gss_handle_to_u64(&gw->gw_handle),
1113                       libcfs_nidstr(&req->rq_peer.nid));
1114                 return NULL;
1115         }
1116
1117         return &rsc->ctx;
1118 }
1119
1120 void gss_svc_upcall_put_ctx(struct gss_svc_ctx *ctx)
1121 {
1122         struct rsc *rsc = container_of(ctx, struct rsc, ctx);
1123
1124         COMPAT_RSC_PUT(&rsc->h, &rsc_cache);
1125 }
1126
1127 void gss_svc_upcall_destroy_ctx(struct gss_svc_ctx *ctx)
1128 {
1129         struct rsc *rsc = container_of(ctx, struct rsc, ctx);
1130
1131         /* can't be found */
1132         set_bit(CACHE_NEGATIVE, &rsc->h.flags);
1133         /* to be removed at next scan */
1134         rsc->h.expiry_time = 1;
1135 }
1136
1137 int __init gss_init_svc_upcall(void)
1138 {
1139         int     i, rc;
1140
1141         /*
1142          * this helps reducing context index confliction. after server reboot,
1143          * conflicting request from clients might be filtered out by initial
1144          * sequence number checking, thus no chance to sent error notification
1145          * back to clients.
1146          */
1147         get_random_bytes(&__ctx_index, sizeof(__ctx_index));
1148
1149 #ifdef HAVE_CACHE_HEAD_HLIST
1150         for (i = 0; i < rsi_cache.hash_size; i++)
1151                 INIT_HLIST_HEAD(&rsi_cache.hash_table[i]);
1152 #endif
1153         rc = cache_register_net(&rsi_cache, &init_net);
1154         if (rc != 0)
1155                 return rc;
1156
1157 #ifdef HAVE_CACHE_HEAD_HLIST
1158         for (i = 0; i < rsc_cache.hash_size; i++)
1159                 INIT_HLIST_HEAD(&rsc_cache.hash_table[i]);
1160 #endif
1161         rc = cache_register_net(&rsc_cache, &init_net);
1162         if (rc != 0) {
1163                 cache_unregister_net(&rsi_cache, &init_net);
1164                 return rc;
1165         }
1166
1167         /* FIXME this looks stupid. we intend to give lsvcgssd a chance to open
1168          * the init upcall channel, otherwise there's big chance that the first
1169          * upcall issued before the channel be opened thus nfsv4 cache code will
1170          * drop the request directly, thus lead to unnecessary recovery time.
1171          * Here we wait at minimum 1.5 seconds.
1172          */
1173         for (i = 0; i < 6; i++) {
1174                 if (channel_users(&rsi_cache) > 0)
1175                         break;
1176                 schedule_timeout_uninterruptible(cfs_time_seconds(1) / 4);
1177         }
1178
1179         if (channel_users(&rsi_cache) == 0)
1180                 CDEBUG(D_SEC,
1181                        "Init channel is not opened by lsvcgssd, following request might be dropped until lsvcgssd is active\n");
1182
1183         return 0;
1184 }
1185
1186 void gss_exit_svc_upcall(void)
1187 {
1188         cache_purge(&rsi_cache);
1189         cache_unregister_net(&rsi_cache, &init_net);
1190
1191         cache_purge(&rsc_cache);
1192         cache_unregister_net(&rsc_cache, &init_net);
1193 }