Whamcloud - gitweb
LU-17015 gss: remove legacy sunrpc-cache based gss caches
[fs/lustre-release.git] / lustre / ptlrpc / gss / gss_svc_upcall.c
1 /*
2  * Modifications for Lustre
3  *
4  * Copyright (c) 2007, 2010, Oracle and/or its affiliates. All rights reserved.
5  *
6  * Copyright (c) 2012, 2014, Intel Corporation.
7  *
8  * Author: Eric Mei <ericm@clusterfs.com>
9  */
10
11 /*
12  * Neil Brown <neilb@cse.unsw.edu.au>
13  * J. Bruce Fields <bfields@umich.edu>
14  * Andy Adamson <andros@umich.edu>
15  * Dug Song <dugsong@monkey.org>
16  *
17  * RPCSEC_GSS server authentication.
18  * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078
19  * (gssapi)
20  *
21  * The RPCSEC_GSS involves three stages:
22  *  1/ context creation
23  *  2/ data exchange
24  *  3/ context destruction
25  *
26  * Context creation is handled largely by upcalls to user-space.
27  *  In particular, GSS_Accept_sec_context is handled by an upcall
28  * Data exchange is handled entirely within the kernel
29  *  In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel.
30  * Context destruction is handled in-kernel
31  *  GSS_Delete_sec_context is in-kernel
32  *
33  * Context creation is initiated by a RPCSEC_GSS_INIT request arriving.
34  * The context handle and gss_token are used as a key into the rpcsec_init cache.
35  * The content of this cache includes some of the outputs of GSS_Accept_sec_context,
36  * being major_status, minor_status, context_handle, reply_token.
37  * These are sent back to the client.
38  * Sequence window management is handled by the kernel.  The window size if currently
39  * a compile time constant.
40  *
41  * When user-space is happy that a context is established, it places an entry
42  * in the rpcsec_context cache. The key for this cache is the context_handle.
43  * The content includes:
44  *   uid/gidlist - for determining access rights
45  *   mechanism type
46  *   mechanism specific information, such as a key
47  *
48  */
49
50 #define DEBUG_SUBSYSTEM S_SEC
51 #include <linux/types.h>
52 #include <linux/init.h>
53 #include <linux/module.h>
54 #include <linux/random.h>
55 #include <linux/slab.h>
56 #include <linux/mutex.h>
57 #include <linux/binfmts.h>
58 #include <net/sock.h>
59 #include <linux/un.h>
60
61 #include <obd.h>
62 #include <obd_class.h>
63 #include <obd_support.h>
64 #include <lustre_import.h>
65 #include <lustre_net.h>
66 #include <lustre_nodemap.h>
67 #include <lustre_sec.h>
68 #include <libcfs/linux/linux-hash.h>
69
70 #include "gss_err.h"
71 #include "gss_internal.h"
72 #include "gss_api.h"
73 #include "gss_crypto.h"
74
75 static DEFINE_SPINLOCK(__ctx_index_lock);
76 static __u64 __ctx_index;
77
78 unsigned int krb5_allow_old_client_csum;
79
80 __u64 gss_get_next_ctx_index(void)
81 {
82         __u64 idx;
83
84         spin_lock(&__ctx_index_lock);
85         idx = __ctx_index++;
86         spin_unlock(&__ctx_index_lock);
87
88         return idx;
89 }
90
91 static inline unsigned long hash_mem(char *buf, int length, int bits)
92 {
93         unsigned long hash = 0;
94         unsigned long l = 0;
95         int len = 0;
96         unsigned char c;
97
98         do {
99                 if (len == length) {
100                         c = (char) len;
101                         len = -1;
102                 } else
103                         c = *buf++;
104
105                 l = (l << 8) | c;
106                 len++;
107
108                 if ((len & (BITS_PER_LONG/8-1)) == 0)
109                         hash = cfs_hash_long(hash^l, BITS_PER_LONG);
110         } while (len);
111
112         return hash >> (BITS_PER_LONG - bits);
113 }
114
115 /* This is a little bit of a concern but we need to make our own hash64 function
116  * as the one from the kernel seems to be buggy by returning a u32:
117  * static __always_inline u32 hash_64_generic(u64 val, unsigned int bits)
118  */
119 #if BITS_PER_LONG == 64
120 static __always_inline __u64 gss_hash_64(__u64 val, unsigned int bits)
121 {
122         __u64 hash = val;
123         /*  Sigh, gcc can't optimise this alone like it does for 32 bits. */
124         __u64 n = hash;
125
126         n <<= 18;
127         hash -= n;
128         n <<= 33;
129         hash -= n;
130         n <<= 3;
131         hash += n;
132         n <<= 3;
133         hash -= n;
134         n <<= 4;
135         hash += n;
136         n <<= 2;
137         hash += n;
138
139         /* High bits are more random, so use them. */
140         return hash >> (64 - bits);
141 }
142
143 static inline unsigned long hash_mem_64(char *buf, int length, int bits)
144 {
145         unsigned long hash = 0;
146         unsigned long l = 0;
147         int len = 0;
148         unsigned char c;
149
150         do {
151                 if (len == length) {
152                         c = (char) len;
153                         len = -1;
154                 } else
155                         c = *buf++;
156
157                 l = (l << 8) | c;
158                 len++;
159
160                 if ((len & (BITS_PER_LONG/8-1)) == 0)
161                         hash = gss_hash_64(hash^l, BITS_PER_LONG);
162         } while (len);
163
164         return hash >> (BITS_PER_LONG - bits);
165 }
166 #endif /* BITS_PER_LONG == 64 */
167
168 /****************************************
169  * rpc sec init (rsi) cache             *
170  ****************************************/
171
172 #define RSI_HASHBITS    (6)
173
174 static void rsi_entry_init(struct upcall_cache_entry *entry,
175                            void *args)
176 {
177         struct gss_rsi *rsi = &entry->u.rsi;
178         struct gss_rsi *tmp = args;
179
180         rsi->si_uc_entry = entry;
181         rawobj_dup(&rsi->si_in_handle, &tmp->si_in_handle);
182         rawobj_dup(&rsi->si_in_token, &tmp->si_in_token);
183         rsi->si_out_handle = RAWOBJ_EMPTY;
184         rsi->si_out_token = RAWOBJ_EMPTY;
185
186         rsi->si_lustre_svc = tmp->si_lustre_svc;
187         rsi->si_nid4 = tmp->si_nid4;
188         memcpy(rsi->si_nm_name, tmp->si_nm_name, sizeof(tmp->si_nm_name));
189 }
190
191 static void __rsi_free(struct gss_rsi *rsi)
192 {
193         rawobj_free(&rsi->si_in_handle);
194         rawobj_free(&rsi->si_in_token);
195         rawobj_free(&rsi->si_out_handle);
196         rawobj_free(&rsi->si_out_token);
197 }
198
199 static void rsi_entry_free(struct upcall_cache *cache,
200                            struct upcall_cache_entry *entry)
201 {
202         struct gss_rsi *rsi = &entry->u.rsi;
203
204         __rsi_free(rsi);
205 }
206
207 static inline int rsi_entry_hash(struct gss_rsi *rsi)
208 {
209 #if BITS_PER_LONG == 64
210         return hash_mem_64((char *)rsi->si_in_handle.data,
211                            rsi->si_in_handle.len, RSI_HASHBITS) ^
212                 hash_mem_64((char *)rsi->si_in_token.data,
213                             rsi->si_in_token.len, RSI_HASHBITS);
214 #else
215         return hash_mem((char *)rsi->si_in_handle.data, rsi->si_in_handle.len,
216                         RSI_HASHBITS) ^
217                 hash_mem((char *)rsi->si_in_token.data, rsi->si_in_token.len,
218                          RSI_HASHBITS);
219 #endif
220 }
221
222 static inline int __rsi_entry_match(rawobj_t *h1, rawobj_t *h2,
223                                     rawobj_t *t1, rawobj_t *t2)
224 {
225         return !(rawobj_equal(h1, h2) && rawobj_equal(t1, t2));
226 }
227
228 static inline int rsi_entry_match(struct gss_rsi *rsi, struct gss_rsi *tmp)
229 {
230         return __rsi_entry_match(&rsi->si_in_handle, &tmp->si_in_handle,
231                                  &rsi->si_in_token, &tmp->si_in_token);
232 }
233
234 /* Returns 0 to tell this is a match */
235 static inline int rsi_upcall_compare(struct upcall_cache *cache,
236                                      struct upcall_cache_entry *entry,
237                                      __u64 key, void *args)
238 {
239         struct gss_rsi *rsi1 = &entry->u.rsi;
240         struct gss_rsi *rsi2 = args;
241
242         return rsi_entry_match(rsi1, rsi2);
243 }
244
245 /* See handle_channel_request() userspace for where the upcall data is read */
246 static int rsi_do_upcall(struct upcall_cache *cache,
247                          struct upcall_cache_entry *entry)
248 {
249         int size, len, *blen;
250         char *buffer, *bp, **bpp;
251         char *argv[] = {
252                 [0] = cache->uc_upcall,
253                 [1] = "-c",
254                 [2] = cache->uc_name,
255                 [3] = "-r",
256                 [4] = NULL,
257                 [5] = NULL
258         };
259         char *envp[] = {
260                 [0] = "HOME=/",
261                 [1] = "PATH=/sbin:/usr/sbin",
262                 [2] = NULL
263         };
264         ktime_t start, end;
265         struct gss_rsi *rsi = &entry->u.rsi;
266         __u64 index = 0;
267         int rc;
268
269         ENTRY;
270         CDEBUG(D_SEC, "rsi upcall '%s' on '%s'\n",
271                cache->uc_upcall, cache->uc_name);
272
273         size = 24 + 1 + /* ue_key is uint64_t */
274                 12 + 1 + /* si_lustre_svc is __u32*/
275                 18 + 1 + /* si_nid4 is lnet_nid_t, hex with leading 0x */
276                 18 + 1 + /* index is __u64, hex with leading 0x */
277                 strlen(rsi->si_nm_name) + 1 +
278                 BASE64URL_CHARS(rsi->si_in_handle.len) + 1 +
279                 BASE64URL_CHARS(rsi->si_in_token.len) + 1 +
280                 1 + 1; /* eol */
281         if (size > MAX_ARG_STRLEN)
282                 RETURN(-E2BIG);
283         OBD_ALLOC_LARGE(buffer, size);
284         if (!buffer)
285                 RETURN(-ENOMEM);
286
287         bp = buffer;
288         bpp = &bp;
289         len = size;
290         blen = &len;
291
292         /* if in_handle is null, provide kernel suggestion */
293         if (rsi->si_in_handle.len == 0)
294                 index = gss_get_next_ctx_index();
295
296         /* entry->ue_key is put into args sent via upcall, so that it can be
297          * returned by userspace. This will help find cache entry at downcall,
298          * without unnecessary recomputation of the hash.
299          */
300         gss_u64_write_string(bpp, blen, entry->ue_key);
301         gss_u64_write_string(bpp, blen, rsi->si_lustre_svc);
302         gss_u64_write_hex_string(bpp, blen, rsi->si_nid4);
303         gss_u64_write_hex_string(bpp, blen, index);
304         gss_string_write(bpp, blen, (char *) rsi->si_nm_name);
305         gss_base64url_encode(bpp, blen, rsi->si_in_handle.data,
306                              rsi->si_in_handle.len);
307         gss_base64url_encode(bpp, blen, rsi->si_in_token.data,
308                              rsi->si_in_token.len);
309         (*bpp)[-1] = '\n';
310         (*bpp)[0] = '\0';
311
312         argv[4] = buffer;
313         down_read(&cache->uc_upcall_rwsem);
314         start = ktime_get();
315         rc = call_usermodehelper(argv[0], argv, envp, UMH_WAIT_EXEC);
316         end = ktime_get();
317         up_read(&cache->uc_upcall_rwsem);
318         if (rc < 0) {
319                 CERROR("%s: error invoking upcall %s %s (time %ldus): rc = %d\n",
320                        cache->uc_name, argv[0], argv[2],
321                        (long)ktime_us_delta(end, start), rc);
322         } else {
323                 CDEBUG(D_SEC, "%s: invoked upcall %s %s (time %ldus)\n",
324                        cache->uc_name, argv[0], argv[2],
325                        (long)ktime_us_delta(end, start));
326                 rc = 0;
327         }
328
329         OBD_FREE_LARGE(buffer, size);
330         RETURN(rc);
331 }
332
333 static inline int rsi_downcall_compare(struct upcall_cache *cache,
334                                        struct upcall_cache_entry *entry,
335                                        __u64 key, void *args)
336 {
337         struct gss_rsi *rsi = &entry->u.rsi;
338         struct rsi_downcall_data *sid = args;
339         char *mesg = sid->sid_val;
340         rawobj_t handle, token;
341         char *p = mesg;
342         int len;
343
344         /* sid_val starts with handle and token */
345
346         /* First, handle */
347         len = gss_buffer_get(&mesg, &handle.len, &handle.data);
348         sid->sid_offset = mesg - p;
349         p = mesg;
350
351         /* Second, token */
352         len = gss_buffer_get(&mesg, &token.len, &token.data);
353         sid->sid_offset += mesg - p;
354
355         return __rsi_entry_match(&rsi->si_in_handle, &handle,
356                                  &rsi->si_in_token, &token);
357 }
358
359 static int rsi_parse_downcall(struct upcall_cache *cache,
360                               struct upcall_cache_entry *entry,
361                               void *args)
362 {
363         struct gss_rsi *rsi = &entry->u.rsi;
364         struct rsi_downcall_data *sid = args;
365         int mlen = sid->sid_len;
366         char *mesg = sid->sid_val + sid->sid_offset;
367         char *buf = sid->sid_val;
368         int status = -EINVAL;
369         int len;
370
371         ENTRY;
372
373         if (mlen <= 0)
374                 goto out;
375
376         rsi->si_major_status = sid->sid_maj_stat;
377         rsi->si_minor_status = sid->sid_min_stat;
378
379         /* in_handle and in_token have already been consumed in
380          * rsi_downcall_compare(). sid_offset gives next field.
381          */
382
383         /* out_handle */
384         len = gss_buffer_read(&mesg, buf, mlen);
385         if (len < 0)
386                 goto out;
387         if (rawobj_alloc(&rsi->si_out_handle, buf, len)) {
388                 status = -ENOMEM;
389                 goto out;
390         }
391
392         /* out_token */
393         len = gss_buffer_read(&mesg, buf, mlen);
394         if (len < 0)
395                 goto out;
396         if (rawobj_alloc(&rsi->si_out_token, buf, len)) {
397                 status = -ENOMEM;
398                 goto out;
399         }
400
401         entry->ue_expire = 0;
402         status = 0;
403
404 out:
405         CDEBUG(D_OTHER, "rsi parse %p: %d\n", rsi, status);
406         RETURN(status);
407 }
408
409 struct gss_rsi *rsi_entry_get(struct upcall_cache *cache, struct gss_rsi *rsi)
410 {
411         struct upcall_cache_entry *entry;
412         int hash = rsi_entry_hash(rsi);
413
414         if (!cache)
415                 return ERR_PTR(-ENOENT);
416
417         entry = upcall_cache_get_entry(cache, (__u64)hash, rsi);
418         if (unlikely(!entry))
419                 return ERR_PTR(-ENOENT);
420         if (IS_ERR(entry))
421                 return ERR_CAST(entry);
422
423         return &entry->u.rsi;
424 }
425
426 void rsi_entry_put(struct upcall_cache *cache, struct gss_rsi *rsi)
427 {
428         if (!cache || !rsi)
429                 return;
430
431         upcall_cache_put_entry(cache, rsi->si_uc_entry);
432 }
433
434 void rsi_flush(struct upcall_cache *cache, int hash)
435 {
436         if (hash < 0)
437                 upcall_cache_flush_idle(cache);
438         else
439                 upcall_cache_flush_one(cache, (__u64)hash, NULL);
440 }
441
442 struct upcall_cache_ops rsi_upcall_cache_ops = {
443         .init_entry       = rsi_entry_init,
444         .free_entry       = rsi_entry_free,
445         .upcall_compare   = rsi_upcall_compare,
446         .downcall_compare = rsi_downcall_compare,
447         .do_upcall        = rsi_do_upcall,
448         .parse_downcall   = rsi_parse_downcall,
449 };
450
451 struct upcall_cache *rsicache;
452
453 /****************************************
454  * rpc sec context (rsc) cache          *
455  ****************************************/
456
457 #define RSC_HASHBITS    (10)
458
459 static void rsc_entry_init(struct upcall_cache_entry *entry,
460                            void *args)
461 {
462         struct gss_rsc *rsc = &entry->u.rsc;
463         struct gss_rsc *tmp = args;
464
465         rsc->sc_uc_entry = entry;
466         rawobj_dup(&rsc->sc_handle, &tmp->sc_handle);
467
468         rsc->sc_target = NULL;
469         memset(&rsc->sc_ctx, 0, sizeof(rsc->sc_ctx));
470         rsc->sc_ctx.gsc_rvs_hdl = RAWOBJ_EMPTY;
471
472         memset(&rsc->sc_ctx.gsc_seqdata, 0, sizeof(rsc->sc_ctx.gsc_seqdata));
473         spin_lock_init(&rsc->sc_ctx.gsc_seqdata.ssd_lock);
474 }
475
476 void __rsc_free(struct gss_rsc *rsc)
477 {
478         rawobj_free(&rsc->sc_handle);
479         rawobj_free(&rsc->sc_ctx.gsc_rvs_hdl);
480         lgss_delete_sec_context(&rsc->sc_ctx.gsc_mechctx);
481 }
482
483 static void rsc_entry_free(struct upcall_cache *cache,
484                            struct upcall_cache_entry *entry)
485 {
486         struct gss_rsc *rsc = &entry->u.rsc;
487
488         __rsc_free(rsc);
489 }
490
491 static inline int rsc_entry_hash(struct gss_rsc *rsc)
492 {
493 #if BITS_PER_LONG == 64
494         return hash_mem_64((char *)rsc->sc_handle.data,
495                            rsc->sc_handle.len, RSC_HASHBITS);
496 #else
497         return hash_mem((char *)rsc->sc_handle.data,
498                         rsc->sc_handle.len, RSC_HASHBITS);
499 #endif
500 }
501
502 static inline int __rsc_entry_match(rawobj_t *h1, rawobj_t *h2)
503 {
504         return !(rawobj_equal(h1, h2));
505 }
506
507 static inline int rsc_entry_match(struct gss_rsc *rsc, struct gss_rsc *tmp)
508 {
509         return __rsc_entry_match(&rsc->sc_handle, &tmp->sc_handle);
510 }
511
512 /* Returns 0 to tell this is a match */
513 static inline int rsc_upcall_compare(struct upcall_cache *cache,
514                                      struct upcall_cache_entry *entry,
515                                      __u64 key, void *args)
516 {
517         struct gss_rsc *rsc1 = &entry->u.rsc;
518         struct gss_rsc *rsc2 = args;
519
520         return rsc_entry_match(rsc1, rsc2);
521 }
522
523 /* rsc upcall is a no-op, we just need a valid entry */
524 static inline int rsc_do_upcall(struct upcall_cache *cache,
525                                 struct upcall_cache_entry *entry)
526 {
527         upcall_cache_update_entry(cache, entry,
528                                   ktime_get_seconds() + cache->uc_entry_expire,
529                                   0);
530         return 0;
531 }
532
533 static inline int rsc_downcall_compare(struct upcall_cache *cache,
534                                        struct upcall_cache_entry *entry,
535                                        __u64 key, void *args)
536 {
537         struct gss_rsc *rsc = &entry->u.rsc;
538         struct rsc_downcall_data *scd = args;
539         char *mesg = scd->scd_val;
540         rawobj_t handle;
541         int len;
542
543         /* scd_val starts with handle */
544         len = gss_buffer_get(&mesg, &handle.len, &handle.data);
545         scd->scd_offset = mesg - scd->scd_val;
546
547         return __rsc_entry_match(&rsc->sc_handle, &handle);
548 }
549
550 static int rsc_parse_downcall(struct upcall_cache *cache,
551                               struct upcall_cache_entry *entry,
552                               void *args)
553 {
554         struct gss_api_mech *gm = NULL;
555         struct gss_rsc *rsc = &entry->u.rsc;
556         struct rsc_downcall_data *scd = args;
557         int mlen = scd->scd_len;
558         char *mesg = scd->scd_val + scd->scd_offset;
559         char *buf = scd->scd_val;
560         int status = -EINVAL;
561         time64_t ctx_expiry;
562         rawobj_t tmp_buf;
563         int len;
564
565         ENTRY;
566
567         if (mlen <= 0)
568                 goto out;
569
570         rsc->sc_ctx.gsc_remote = !!(scd->scd_flags & RSC_DATA_FLAG_REMOTE);
571         rsc->sc_ctx.gsc_usr_root = !!(scd->scd_flags & RSC_DATA_FLAG_ROOT);
572         rsc->sc_ctx.gsc_usr_mds = !!(scd->scd_flags & RSC_DATA_FLAG_MDS);
573         rsc->sc_ctx.gsc_usr_oss = !!(scd->scd_flags & RSC_DATA_FLAG_OSS);
574         rsc->sc_ctx.gsc_mapped_uid = scd->scd_mapped_uid;
575         rsc->sc_ctx.gsc_uid = scd->scd_uid;
576
577         rsc->sc_ctx.gsc_gid = scd->scd_gid;
578         gm = lgss_name_to_mech(scd->scd_mechname);
579         if (!gm) {
580                 status = -EOPNOTSUPP;
581                 goto out;
582         }
583
584         /* handle has already been consumed in rsc_downcall_compare().
585          * scd_offset gives next field.
586          */
587
588         /* context token */
589         len = gss_buffer_read(&mesg, buf, mlen);
590         if (len < 0)
591                 goto out;
592         tmp_buf.len = len;
593         tmp_buf.data = (unsigned char *)buf;
594         if (lgss_import_sec_context(&tmp_buf, gm,
595                                     &rsc->sc_ctx.gsc_mechctx))
596                 goto out;
597
598         if (lgss_inquire_context(rsc->sc_ctx.gsc_mechctx, &ctx_expiry))
599                 goto out;
600
601         /* ctx_expiry is the number of seconds since Jan 1 1970.
602          * We just want the number of seconds into the future.
603          */
604         entry->ue_expire = ktime_get_seconds() +
605                 (ctx_expiry - ktime_get_real_seconds());
606         status = 0;
607
608 out:
609         if (gm)
610                 lgss_mech_put(gm);
611         CDEBUG(D_OTHER, "rsc parse %p: %d\n", rsc, status);
612         RETURN(status);
613 }
614
615 struct gss_rsc *rsc_entry_get(struct upcall_cache *cache, struct gss_rsc *rsc)
616 {
617         struct upcall_cache_entry *entry;
618         int hash = rsc_entry_hash(rsc);
619
620         if (!cache)
621                 return ERR_PTR(-ENOENT);
622
623         entry = upcall_cache_get_entry(cache, (__u64)hash, rsc);
624         if (unlikely(!entry))
625                 return ERR_PTR(-ENOENT);
626         if (IS_ERR(entry))
627                 return ERR_CAST(entry);
628
629         return &entry->u.rsc;
630 }
631
632 void rsc_entry_put(struct upcall_cache *cache, struct gss_rsc *rsc)
633 {
634         if (!cache || !rsc)
635                 return;
636
637         upcall_cache_put_entry(cache, rsc->sc_uc_entry);
638 }
639
640 void rsc_flush(struct upcall_cache *cache, int hash)
641 {
642         if (hash < 0)
643                 upcall_cache_flush_idle(cache);
644         else
645                 upcall_cache_flush_one(cache, (__u64)hash, NULL);
646 }
647
648 struct upcall_cache_ops rsc_upcall_cache_ops = {
649         .init_entry       = rsc_entry_init,
650         .free_entry       = rsc_entry_free,
651         .upcall_compare   = rsc_upcall_compare,
652         .downcall_compare = rsc_downcall_compare,
653         .do_upcall        = rsc_do_upcall,
654         .parse_downcall   = rsc_parse_downcall,
655 };
656
657 struct upcall_cache *rsccache;
658
659 /****************************************
660  * rsc cache flush                      *
661  ****************************************/
662
663 static struct gss_rsc *gss_svc_searchbyctx(rawobj_t *handle)
664 {
665         struct gss_rsc rsc;
666         struct gss_rsc *found;
667
668         memset(&rsc, 0, sizeof(rsc));
669         if (rawobj_dup(&rsc.sc_handle, handle))
670                 return NULL;
671
672         found = rsc_entry_get(rsccache, &rsc);
673         __rsc_free(&rsc);
674         if (IS_ERR_OR_NULL(found))
675                 return found;
676         if (!found->sc_ctx.gsc_mechctx) {
677                 rsc_entry_put(rsccache, found);
678                 return ERR_PTR(-ENOENT);
679         }
680         return found;
681 }
682
683 int gss_svc_upcall_install_rvs_ctx(struct obd_import *imp,
684                                    struct gss_sec *gsec,
685                                    struct gss_cli_ctx *gctx)
686 {
687         struct gss_rsc rsc, *rscp = NULL;
688         time64_t ctx_expiry;
689         __u32 major;
690         int rc;
691
692         ENTRY;
693         memset(&rsc, 0, sizeof(rsc));
694
695         if (!imp || !imp->imp_obd) {
696                 CERROR("invalid imp, drop\n");
697                 RETURN(-EPROTO);
698         }
699
700         if (rawobj_alloc(&rsc.sc_handle, (char *)&gsec->gs_rvs_hdl,
701                          sizeof(gsec->gs_rvs_hdl)))
702                 GOTO(out, rc = -ENOMEM);
703
704         rscp = rsc_entry_get(rsccache, &rsc);
705         __rsc_free(&rsc);
706         if (IS_ERR_OR_NULL(rscp))
707                 GOTO(out, rc = -ENOMEM);
708
709         major = lgss_copy_reverse_context(gctx->gc_mechctx,
710                                           &rscp->sc_ctx.gsc_mechctx);
711         if (major != GSS_S_COMPLETE)
712                 GOTO(out, rc = -ENOMEM);
713
714         if (lgss_inquire_context(rscp->sc_ctx.gsc_mechctx, &ctx_expiry)) {
715                 CERROR("%s: unable to get expire time, drop\n",
716                        imp->imp_obd->obd_name);
717                 GOTO(out, rc = -EINVAL);
718         }
719         rscp->sc_uc_entry->ue_expire = ktime_get_seconds() +
720                 (ctx_expiry - ktime_get_real_seconds());
721
722         switch (imp->imp_obd->u.cli.cl_sp_to) {
723         case LUSTRE_SP_MDT:
724                 rscp->sc_ctx.gsc_usr_mds = 1;
725                 break;
726         case LUSTRE_SP_OST:
727                 rscp->sc_ctx.gsc_usr_oss = 1;
728                 break;
729         case LUSTRE_SP_CLI:
730                 rscp->sc_ctx.gsc_usr_root = 1;
731                 break;
732         case LUSTRE_SP_MGS:
733                 /* by convention, all 3 set to 1 means MGS */
734                 rscp->sc_ctx.gsc_usr_mds = 1;
735                 rscp->sc_ctx.gsc_usr_oss = 1;
736                 rscp->sc_ctx.gsc_usr_root = 1;
737                 break;
738         default:
739                 break;
740         }
741
742         rscp->sc_target = imp->imp_obd;
743         rawobj_dup(&gctx->gc_svc_handle, &rscp->sc_handle);
744
745         CDEBUG(D_SEC, "%s: create reverse svc ctx %p to %s: idx %#llx\n",
746                imp->imp_obd->obd_name, &rscp->sc_ctx, obd2cli_tgt(imp->imp_obd),
747                gsec->gs_rvs_hdl);
748         rc = 0;
749 out:
750         if (!IS_ERR_OR_NULL(rscp))
751                 rsc_entry_put(rsccache, rscp);
752         if (rc)
753                 CERROR("%s: can't create reverse svc ctx idx %#llx: rc = %d\n",
754                        imp->imp_obd->obd_name, gsec->gs_rvs_hdl, rc);
755         RETURN(rc);
756 }
757
758 int gss_svc_upcall_expire_rvs_ctx(rawobj_t *handle)
759 {
760         const time64_t expire = 20;
761         struct gss_rsc *rscp;
762
763         rscp = gss_svc_searchbyctx(handle);
764         if (!IS_ERR_OR_NULL(rscp)) {
765                 CDEBUG(D_SEC,
766                        "reverse svcctx %p (rsc %p) expire in %lld seconds\n",
767                        &rscp->sc_ctx, rscp, expire);
768
769                 rscp->sc_uc_entry->ue_expire = ktime_get_seconds() + expire;
770                 rsc_entry_put(rsccache, rscp);
771         }
772         return 0;
773 }
774
775 int gss_svc_upcall_dup_handle(rawobj_t *handle, struct gss_svc_ctx *ctx)
776 {
777         struct gss_rsc *rscp = container_of(ctx, struct gss_rsc, sc_ctx);
778
779         return rawobj_dup(handle, &rscp->sc_handle);
780 }
781
782 int gss_svc_upcall_update_sequence(rawobj_t *handle, __u32 seq)
783 {
784         struct gss_rsc *rscp;
785
786         rscp = gss_svc_searchbyctx(handle);
787         if (!IS_ERR_OR_NULL(rscp)) {
788                 CDEBUG(D_SEC, "reverse svcctx %p (rsc %p) update seq to %u\n",
789                        &rscp->sc_ctx, rscp, seq + 1);
790
791                 rscp->sc_ctx.gsc_rvs_seq = seq + 1;
792                 rsc_entry_put(rsccache, rscp);
793         }
794         return 0;
795 }
796
797 int gss_svc_upcall_handle_init(struct ptlrpc_request *req,
798                                struct gss_svc_reqctx *grctx,
799                                struct gss_wire_ctx *gw,
800                                struct obd_device *target,
801                                __u32 lustre_svc,
802                                rawobj_t *rvs_hdl,
803                                rawobj_t *in_token)
804 {
805         struct gss_rsi rsi = { 0 }, *rsip = NULL;
806         struct ptlrpc_reply_state *rs;
807         struct gss_rsc *rscp = NULL;
808         int replen = sizeof(struct ptlrpc_body);
809         struct gss_rep_header *rephdr;
810         int rc, rc2;
811
812         ENTRY;
813
814         rsi.si_lustre_svc = lustre_svc;
815         /* In case of MR, rq_peer is not the NID from which request is received,
816          * but primary NID of peer.
817          * So we need LNetPrimaryNID(rq_source) to match what the clients uses.
818          */
819         LNetPrimaryNID(&req->rq_source.nid);
820         rsi.si_nid4 = lnet_nid_to_nid4(&req->rq_source.nid);
821         nodemap_test_nid(lnet_nid_to_nid4(&req->rq_peer.nid), rsi.si_nm_name,
822                          sizeof(rsi.si_nm_name));
823
824         /* Note that context handle is always 0 for for INIT. */
825         rc2 = rawobj_dup(&rsi.si_in_handle, &gw->gw_handle);
826         if (rc2) {
827                 CERROR("%s: failed to duplicate context handle: rc = %d\n",
828                        target->obd_name, rc2);
829                 GOTO(out, rc = SECSVC_DROP);
830         }
831
832         rc2 = rawobj_dup(&rsi.si_in_token, in_token);
833         if (rc2) {
834                 CERROR("%s: failed to duplicate token: rc = %d\n",
835                        target->obd_name, rc2);
836                 rawobj_free(&rsi.si_in_handle);
837                 GOTO(out, rc = SECSVC_DROP);
838         }
839
840         rsip = rsi_entry_get(rsicache, &rsi);
841         __rsi_free(&rsi);
842         if (IS_ERR_OR_NULL(rsip)) {
843                 if (IS_ERR(rsip))
844                         rc2 = PTR_ERR(rsip);
845                 else
846                         rc2 = -EINVAL;
847                 CERROR("%s: failed to get entry from rsi cache (nid %s): rc = %d\n",
848                        target->obd_name,
849                        libcfs_nid2str(lnet_nid_to_nid4(&req->rq_source.nid)),
850                        rc2);
851
852                 if (!gss_pack_err_notify(req, GSS_S_FAILURE, 0))
853                         rc = SECSVC_COMPLETE;
854                 else
855                         rc = SECSVC_DROP;
856
857                 GOTO(out, rc);
858         }
859
860         rscp = gss_svc_searchbyctx(&rsip->si_out_handle);
861         if (IS_ERR_OR_NULL(rscp)) {
862                 /* gss mechanism returned major and minor code so we return
863                  * those in error message */
864
865                 if (!gss_pack_err_notify(req, rsip->si_major_status,
866                                          rsip->si_minor_status))
867                         rc = SECSVC_COMPLETE;
868                 else
869                         rc = SECSVC_DROP;
870
871                 CERROR("%s: authentication failed: rc = %d\n",
872                        target->obd_name, rc);
873                 GOTO(out, rc);
874         } else {
875                 /* we need to take an extra ref on the cache entry,
876                  * as a pointer to sc_ctx is stored in grctx
877                  */
878                 upcall_cache_get_entry_raw(rscp->sc_uc_entry);
879                 grctx->src_ctx = &rscp->sc_ctx;
880         }
881
882         if (gw->gw_flags & LUSTRE_GSS_PACK_KCSUM) {
883                 grctx->src_ctx->gsc_mechctx->hash_func = gss_digest_hash;
884         } else if (!strcmp(grctx->src_ctx->gsc_mechctx->mech_type->gm_name,
885                            "krb5") &&
886                    !krb5_allow_old_client_csum) {
887                 CWARN("%s: deny connection from '%s' due to missing 'krb_csum' feature, set 'sptlrpc.gss.krb5_allow_old_client_csum=1' to allow, but recommend client upgrade: rc = %d\n",
888                       target->obd_name, libcfs_nidstr(&req->rq_peer.nid),
889                       -EPROTO);
890                 GOTO(out, rc = SECSVC_DROP);
891         } else {
892                 grctx->src_ctx->gsc_mechctx->hash_func =
893                         gss_digest_hash_compat;
894         }
895
896         if (rawobj_dup(&rscp->sc_ctx.gsc_rvs_hdl, rvs_hdl)) {
897                 CERROR("%s: failed duplicate reverse handle\n",
898                        target->obd_name);
899                 GOTO(out, rc = SECSVC_DROP);
900         }
901
902         rscp->sc_target = target;
903
904         CDEBUG(D_SEC, "%s: server create rsc %p(%u->%s)\n",
905                target->obd_name, rscp, rscp->sc_ctx.gsc_uid,
906                libcfs_nidstr(&req->rq_peer.nid));
907
908         if (rsip->si_out_handle.len > PTLRPC_GSS_MAX_HANDLE_SIZE) {
909                 CERROR("%s: handle size %u too large\n",
910                        target->obd_name, rsip->si_out_handle.len);
911                 GOTO(out, rc = SECSVC_DROP);
912         }
913
914         grctx->src_init = 1;
915         grctx->src_reserve_len = round_up(rsip->si_out_token.len, 4);
916
917         rc = lustre_pack_reply_v2(req, 1, &replen, NULL, 0);
918         if (rc) {
919                 CERROR("%s: failed to pack reply: rc = %d\n",
920                        target->obd_name, rc);
921                 GOTO(out, rc = SECSVC_DROP);
922         }
923
924         rs = req->rq_reply_state;
925         LASSERT(rs->rs_repbuf->lm_bufcount == 3);
926         LASSERT(rs->rs_repbuf->lm_buflens[0] >=
927                 sizeof(*rephdr) + rsip->si_out_handle.len);
928         LASSERT(rs->rs_repbuf->lm_buflens[2] >= rsip->si_out_token.len);
929
930         rephdr = lustre_msg_buf(rs->rs_repbuf, 0, 0);
931         rephdr->gh_version = PTLRPC_GSS_VERSION;
932         rephdr->gh_flags = 0;
933         rephdr->gh_proc = PTLRPC_GSS_PROC_ERR;
934         rephdr->gh_major = rsip->si_major_status;
935         rephdr->gh_minor = rsip->si_minor_status;
936         rephdr->gh_seqwin = GSS_SEQ_WIN;
937         rephdr->gh_handle.len = rsip->si_out_handle.len;
938         memcpy(rephdr->gh_handle.data, rsip->si_out_handle.data,
939                rsip->si_out_handle.len);
940
941         memcpy(lustre_msg_buf(rs->rs_repbuf, 2, 0), rsip->si_out_token.data,
942                rsip->si_out_token.len);
943
944         rs->rs_repdata_len = lustre_shrink_msg(rs->rs_repbuf, 2,
945                                                rsip->si_out_token.len, 0);
946
947         rc = SECSVC_OK;
948
949 out:
950         if (!IS_ERR_OR_NULL(rsip))
951                 rsi_entry_put(rsicache, rsip);
952         if (!IS_ERR_OR_NULL(rscp)) {
953                 /* if anything went wrong, we don't keep the context too */
954                 if (rc != SECSVC_OK)
955                         UC_CACHE_SET_INVALID(rscp->sc_uc_entry);
956                 else
957                         CDEBUG(D_SEC, "%s: create rsc with idx %#llx\n",
958                                target->obd_name,
959                                gss_handle_to_u64(&rscp->sc_handle));
960
961                 rsc_entry_put(rsccache, rscp);
962         }
963         RETURN(rc);
964 }
965
966 struct gss_svc_ctx *gss_svc_upcall_get_ctx(struct ptlrpc_request *req,
967                                            struct gss_wire_ctx *gw)
968 {
969         struct gss_rsc *rscp;
970
971         rscp = gss_svc_searchbyctx(&gw->gw_handle);
972         if (IS_ERR_OR_NULL(rscp)) {
973                 CWARN("Invalid gss ctx idx %#llx from %s\n",
974                       gss_handle_to_u64(&gw->gw_handle),
975                       libcfs_nidstr(&req->rq_peer.nid));
976                 return NULL;
977         }
978
979         return &rscp->sc_ctx;
980 }
981
982 void gss_svc_upcall_put_ctx(struct gss_svc_ctx *ctx)
983 {
984         struct gss_rsc *rscp = container_of(ctx, struct gss_rsc, sc_ctx);
985
986         rsc_entry_put(rsccache, rscp);
987 }
988
989 void gss_svc_upcall_destroy_ctx(struct gss_svc_ctx *ctx)
990 {
991         struct gss_rsc *rscp = container_of(ctx, struct gss_rsc, sc_ctx);
992
993         UC_CACHE_SET_INVALID(rscp->sc_uc_entry);
994         rscp->sc_uc_entry->ue_expire = 1;
995 }
996
997 /* Wait for userspace daemon to open socket, approx 1.5 s.
998  * If socket is not open, upcall requests might fail.
999  */
1000 static int check_gssd_socket(void)
1001 {
1002         struct sockaddr_un *sun;
1003         struct socket *sock;
1004         int tries = 0;
1005         int err;
1006
1007 #ifdef HAVE_SOCK_CREATE_KERN_USE_NET
1008         err = sock_create_kern(current->nsproxy->net_ns,
1009                                AF_UNIX, SOCK_STREAM, 0, &sock);
1010 #else
1011         err = sock_create_kern(AF_UNIX, SOCK_STREAM, 0, &sock);
1012 #endif
1013         if (err < 0) {
1014                 CDEBUG(D_SEC, "Failed to create socket: %d\n", err);
1015                 return err;
1016         }
1017
1018         OBD_ALLOC(sun, sizeof(*sun));
1019         if (!sun) {
1020                 sock_release(sock);
1021                 return -ENOMEM;
1022         }
1023         memset(sun, 0, sizeof(*sun));
1024         sun->sun_family = AF_UNIX;
1025         strncpy(sun->sun_path, GSS_SOCKET_PATH, sizeof(sun->sun_path));
1026
1027         /* Try to connect to the socket */
1028         while (tries++ < 6) {
1029                 err = kernel_connect(sock, (struct sockaddr *)sun,
1030                                      sizeof(*sun), 0);
1031                 if (!err)
1032                         break;
1033                 schedule_timeout_uninterruptible(cfs_time_seconds(1) / 4);
1034         }
1035         if (err < 0)
1036                 CDEBUG(D_SEC, "Failed to connect to socket: %d\n", err);
1037         else
1038                 kernel_sock_shutdown(sock, SHUT_RDWR);
1039
1040         sock_release(sock);
1041         OBD_FREE(sun, sizeof(*sun));
1042         return err;
1043 }
1044
1045 int __init gss_init_svc_upcall(void)
1046 {
1047         int rc;
1048
1049         /*
1050          * this helps reducing context index confliction. after server reboot,
1051          * conflicting request from clients might be filtered out by initial
1052          * sequence number checking, thus no chance to sent error notification
1053          * back to clients.
1054          */
1055         get_random_bytes(&__ctx_index, sizeof(__ctx_index));
1056
1057         rsicache = upcall_cache_init(RSI_CACHE_NAME, RSI_UPCALL_PATH,
1058                                      UC_RSICACHE_HASH_SIZE,
1059                                      3600, /* entry expire: 1 h */
1060                                      30, /* acquire expire: 30 s */
1061                                      false, /* can't replay acquire */
1062                                      &rsi_upcall_cache_ops);
1063         if (IS_ERR(rsicache)) {
1064                 rc = PTR_ERR(rsicache);
1065                 rsicache = NULL;
1066                 return rc;
1067         }
1068         rsccache = upcall_cache_init(RSC_CACHE_NAME, RSC_UPCALL_PATH,
1069                                      UC_RSCCACHE_HASH_SIZE,
1070                                      3600, /* replaced with one from mech */
1071                                      100, /* arbitrary, not used */
1072                                      false, /* can't replay acquire */
1073                                      &rsc_upcall_cache_ops);
1074         if (IS_ERR(rsccache)) {
1075                 upcall_cache_cleanup(rsicache);
1076                 rsicache = NULL;
1077                 rc = PTR_ERR(rsccache);
1078                 rsccache = NULL;
1079                 return rc;
1080         }
1081
1082         if (check_gssd_socket())
1083                 CDEBUG(D_SEC,
1084                        "Init channel not opened by lsvcgssd, GSS might not work on server side until daemon is active\n");
1085
1086         return 0;
1087 }
1088
1089 void gss_exit_svc_upcall(void)
1090 {
1091         upcall_cache_cleanup(rsicache);
1092         upcall_cache_cleanup(rsccache);
1093 }