Whamcloud - gitweb
LU-16977 utils: access_log_reader accesses beyond batch array
[fs/lustre-release.git] / lustre / utils / gss / svcgssd_proc.c
1 /*
2  * svc_in_gssd_proc.c
3  *
4  * Copyright (c) 2000 The Regents of the University of Michigan.
5  * All rights reserved.
6  *
7  * Copyright (c) 2002 Bruce Fields <bfields@UMICH.EDU>
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * 1. Redistributions of source code must retain the above copyright
14  *    notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  * 3. Neither the name of the University nor the names of its
19  *    contributors may be used to endorse or promote products derived
20  *    from this software without specific prior written permission.
21  *
22  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
23  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
24  * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25  * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
26  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
27  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
28  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
29  * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
31  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33  */
34
35 #include <sys/param.h>
36 #include <sys/stat.h>
37
38 #include <inttypes.h>
39 #include <pwd.h>
40 #include <stdio.h>
41 #include <unistd.h>
42 #include <ctype.h>
43 #include <string.h>
44 #include <fcntl.h>
45 #include <errno.h>
46 #ifdef HAVE_NETDB_H
47 # include <netdb.h>
48 #endif
49
50 #include <stdbool.h>
51
52 #include "svcgssd.h"
53 #include "gss_util.h"
54 #include "err_util.h"
55 #include "context.h"
56 #include "cacheio.h"
57 #include "lsupport.h"
58 #include "gss_oids.h"
59 #include <time.h>
60 #include <linux/lustre/lustre_idl.h>
61 #include "sk_utils.h"
62 #include <sys/time.h>
63 #include <gssapi/gssapi_krb5.h>
64
/* Kernel channel to which do_svc_downcall() writes an established context. */
65 #define SVCGSSD_CONTEXT_CHANNEL "/proc/net/rpc/auth.sptlrpc.context/channel"
/* Kernel channel to which send_response() writes a negotiation reply. */
66 #define SVCGSSD_INIT_CHANNEL    "/proc/net/rpc/auth.sptlrpc.init/channel"
67
/*
 * Size of the buffer that holds an incoming token; send_response() sizes its
 * reply buffer at twice this value.
 */
68 #define TOKEN_BUF_SIZE          8192
69
/*
 * Identity and privilege information attached to an accepted context.
 * do_svc_downcall() writes every field except none: all seven are passed
 * to the kernel, one word each.
 */
70 struct svc_cred {
        /* set to 1 by get_ids() when the client realm is not mds_local_realm */
71         uint32_t cr_remote;
        /* 1 when the peer is treated as root (root/host principal or ROOT flag) */
72         uint32_t cr_usr_root;
        /* 1 when the peer is an MDS service principal (or MDT flag is set) */
73         uint32_t cr_usr_mds;
        /* 1 when the peer is an OSS service principal (or OST flag is set) */
74         uint32_t cr_usr_oss;
        /* local uid; get_ids() starts it at -1 and sets 0 or a passwd uid */
75         uid_t    cr_uid;
        /* uid found by lookup_id(), or -1 when no mapping exists */
76         uid_t    cr_mapped_uid;
        /* get_ids() initialises this to -1; no visible code sets another value */
77         uid_t    cr_gid;
78 };
79
/*
 * State of one context negotiation request, shared between
 * handle_channel_request() and the per-mechanism handlers
 * (handle_krb(), handle_sk(), handle_null()).
 */
80 struct svc_nego_data {
81         /* kernel data*/
        /* target service; used to pick the service credential and by get_ids() */
82         uint32_t        lustre_svc;
        /* NID of the requesting peer */
83         lnet_nid_t      nid;
        /* copied into out_handle once a context is complete */
84         uint64_t        handle_seq;
        /* nodemap name; handle_sk() uses it to look up and verify the key */
85         char            nm_name[LUSTRE_NODEMAP_NAME_LENGTH + 1];
        /* token received from the peer */
86         gss_buffer_desc in_tok;
        /* token to return to the peer */
87         gss_buffer_desc out_tok;
        /* handle received with the request; echoed back by send_response() */
88         gss_buffer_desc in_handle;
        /* handle returned with the reply and used as downcall key */
89         gss_buffer_desc out_handle;
        /* GSS major/minor status reported in the reply */
90         uint32_t        maj_stat;
91         uint32_t        min_stat;
92
93         /* userspace data */
        /* mechanism OID handed to do_svc_downcall() by the sk and null handlers */
94         gss_OID                 mech;
        /* GSS context being established by handle_krb() */
95         gss_ctx_id_t            ctx;
        /* serialized context passed to the kernel by do_svc_downcall() */
96         gss_buffer_desc         ctx_token;
97 };
98
99 static int
100 do_svc_downcall(gss_buffer_desc *out_handle, struct svc_cred *cred,
101                 gss_OID mechoid, gss_buffer_desc *context_token)
102 {
103         FILE *f;
104         const char *mechname;
105         int err;
106
107         printerr(LL_INFO, "doing downcall\n");
108         mechname = gss_OID_mech_name(mechoid);
109         if (mechname == NULL)
110                 goto out_err;
111         f = fopen(SVCGSSD_CONTEXT_CHANNEL, "w");
112         if (f == NULL) {
113                 printerr(LL_ERR, "ERROR: unable to open downcall channel "
114                              "%s: %s\n",
115                              SVCGSSD_CONTEXT_CHANNEL, strerror(errno));
116                 goto out_err;
117         }
118         qword_printhex(f, out_handle->value, out_handle->length);
119         /* XXX are types OK for the rest of this? */
120         qword_printint(f, time(NULL) + 3600);   /* 1 hour should be ok */
121         qword_printint(f, cred->cr_remote);
122         qword_printint(f, cred->cr_usr_root);
123         qword_printint(f, cred->cr_usr_mds);
124         qword_printint(f, cred->cr_usr_oss);
125         qword_printint(f, cred->cr_mapped_uid);
126         qword_printint(f, cred->cr_uid);
127         qword_printint(f, cred->cr_gid);
128         qword_print(f, mechname);
129         qword_printhex(f, context_token->value, context_token->length);
130         err = qword_eol(f);
131         fclose(f);
132         return err;
133 out_err:
134         printerr(LL_ERR, "ERROR: downcall failed\n");
135         return -1;
136 }
137
/*
 * A verifier flavour together with its opaque body.
 * NOTE(review): not referenced by any function in the visible part of this
 * file.
 */
138 struct gss_verifier {
139         u_int32_t       flav;
140         gss_buffer_desc body;
141 };
142
/*
 * NOTE(review): presumably the RPCSEC_GSS sequence window size; it is not
 * referenced in the visible part of this file - confirm before relying on it.
 */
143 #define RPCSEC_GSS_SEQ_WIN      5
144
145 static int
146 send_response(FILE *f, gss_buffer_desc *in_handle, gss_buffer_desc *in_token,
147               u_int32_t maj_stat, u_int32_t min_stat,
148               gss_buffer_desc *out_handle, gss_buffer_desc *out_token)
149 {
150         char buf[2 * TOKEN_BUF_SIZE];
151         char *bp = buf;
152         int blen = sizeof(buf);
153         /* XXXARG: */
154         int g;
155
156         printerr(LL_INFO, "sending reply\n");
157         qword_addhex(&bp, &blen, in_handle->value, in_handle->length);
158         qword_addhex(&bp, &blen, in_token->value, in_token->length);
159         qword_addint(&bp, &blen, time(NULL) + 3600);   /* 1 hour should be ok */
160         qword_adduint(&bp, &blen, maj_stat);
161         qword_adduint(&bp, &blen, min_stat);
162         qword_addhex(&bp, &blen, out_handle->value, out_handle->length);
163         qword_addhex(&bp, &blen, out_token->value, out_token->length);
164         qword_addeol(&bp, &blen);
165         if (blen <= 0) {
166                 printerr(LL_ERR, "ERROR: %s: message too long\n", __func__);
167                 return -1;
168         }
169         g = open(SVCGSSD_INIT_CHANNEL, O_WRONLY);
170         if (g == -1) {
171                 printerr(LL_ERR, "ERROR: %s: open %s failed: %s\n",
172                          __func__, SVCGSSD_INIT_CHANNEL, strerror(errno));
173                 return -1;
174         }
175         *bp = '\0';
176         printerr(LL_DEBUG, "writing message: %s", buf);
177         if (write(g, buf, bp - buf) == -1) {
178                 printerr(LL_ERR, "ERROR: %s: failed to write message\n",
179                          __func__);
180                 close(g);
181                 return -1;
182         }
183         close(g);
184         return 0;
185 }
186
/*
 * RPC authentication status codes.
 * NOTE(review): the values look like the ONC RPC auth_stat codes plus the
 * two RPCSEC_GSS additions (13, 14) - confirm.  None of them is referenced in
 * the visible part of this file.
 */
187 #define rpc_auth_ok                     0
188 #define rpc_autherr_badcred             1
189 #define rpc_autherr_rejectedcred        2
190 #define rpc_autherr_badverf             3
191 #define rpc_autherr_rejectedverf        4
192 #define rpc_autherr_tooweak             5
193 #define rpcsec_gsserr_credproblem       13
194 #define rpcsec_gsserr_ctxproblem        14
195
196 static int lookup_localname(gss_name_t client_name, char *princ, lnet_nid_t nid,
197                             uid_t *uid)
198 {
199         u_int32_t maj_stat, min_stat;
200         gss_buffer_desc localname;
201         char *sname;
202         int rc = -1;
203
204         *uid = -1;
205         maj_stat = gss_localname(&min_stat, client_name, GSS_C_NO_OID,
206                                  &localname);
207         if (maj_stat != GSS_S_COMPLETE) {
208                 printerr(LL_INFO, "no local name for %s/%#Lx\n", princ, nid);
209                 return rc;
210         }
211
212         sname = calloc(localname.length + 1, 1);
213         if (!sname) {
214                 printerr(LL_ERR, "%s: error allocating %zu bytes\n",
215                          __func__, localname.length + 1);
216                 goto free;
217         }
218         memcpy(sname, localname.value, localname.length);
219         sname[localname.length] = '\0';
220
221         *uid = parse_uid(sname);
222         free(sname);
223         printerr(LL_WARN, "found local uid: %s ==> %d\n", princ, *uid);
224         rc = 0;
225
226 free:
227         gss_release_buffer(&min_stat, &localname);
228         return rc;
229 }
230
231 static int lookup_id(gss_name_t client_name, char *princ, lnet_nid_t nid,
232                      uid_t *uid)
233 {
234         if (!mapping_empty())
235                 return lookup_mapping(princ, nid, uid);
236
237         return lookup_localname(client_name, princ, nid, uid);
238 }
239
240 static int
241 get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred,
242         lnet_nid_t nid, uint32_t lustre_svc)
243 {
244         u_int32_t       maj_stat, min_stat;
245         gss_buffer_desc name;
246         char            *sname, *host, *realm;
247         const int       namebuf_size = 512;
248         char            namebuf[namebuf_size];
249         int             res = -1;
250         gss_OID         name_type = GSS_C_NO_OID;
251         struct passwd   *pw;
252
253         cred->cr_remote = 0;
254         cred->cr_usr_root = cred->cr_usr_mds = cred->cr_usr_oss = 0;
255         cred->cr_uid = cred->cr_mapped_uid = cred->cr_gid = -1;
256
257         maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
258         if (maj_stat != GSS_S_COMPLETE) {
259                 pgsserr("get_ids: gss_display_name",
260                         maj_stat, min_stat, mech);
261                 return -1;
262         }
263         /* be certain name.length+1 doesn't overflow */
264         if (name.length >= 0xffff ||
265             !(sname = calloc(name.length + 1, 1))) {
266                 printerr(LL_ERR,
267                          "ERROR: %s: error allocating %zu bytes for sname\n",
268                          __func__, name.length + 1);
269                 gss_release_buffer(&min_stat, &name);
270                 return -1;
271         }
272         memcpy(sname, name.value, name.length);
273         sname[name.length] = '\0';
274         gss_release_buffer(&min_stat, &name);
275
276         if (lustre_svc == LUSTRE_GSS_SVC_MDS &&
277             lookup_id(client_name, sname, nid, &cred->cr_mapped_uid))
278                 printerr(LL_DEBUG, "no id found for %s\n", sname);
279
280         realm = strchr(sname, '@');
281         if (realm) {
282                 *realm++ = '\0';
283         } else {
284                 printerr(LL_ERR, "ERROR: %s has no realm name\n", sname);
285                 goto out_free;
286         }
287
288         host = strchr(sname, '/');
289         if (host)
290                 *host++ = '\0';
291
292         if (strcmp(sname, GSSD_SERVICE_MGS) == 0) {
293                 printerr(LL_ERR, "forbid %s as a user name\n", sname);
294                 goto out_free;
295         }
296
297         /* 1. check host part */
298         if (host) {
299                 if (lnet_nid2hostname(nid, namebuf, namebuf_size)) {
300                         printerr(LL_ERR,
301                                  "ERROR: failed to resolve hostname for %s/%s@%s from %016llx\n",
302                                  sname, host, realm, nid);
303                         goto out_free;
304                 }
305
306                 if (strcasecmp(host, namebuf)) {
307                         printerr(LL_ERR,
308                                  "ERROR: %s/%s@%s claimed hostname doesn't match %s, nid %016llx\n",
309                                  sname, host, realm,
310                                  namebuf, nid);
311                         goto out_free;
312                 }
313         } else {
314                 if (!strcmp(sname, GSSD_SERVICE_MDS) ||
315                     !strcmp(sname, GSSD_SERVICE_OSS)) {
316                         printerr(LL_ERR,
317                                  "ERROR: %s@%s from %016llx doesn't bind with hostname\n",
318                                  sname, realm, nid);
319                         goto out_free;
320                 }
321         }
322
323         /* 2. check realm and user */
324         switch (lustre_svc) {
325         case LUSTRE_GSS_SVC_MDS:
326                 if (strcasecmp(mds_local_realm, realm) != 0) {
327                         /* Remote realm case */
328                         cred->cr_remote = 1;
329
330                         /* Prevent access to unmapped user from remote realm */
331                         if (cred->cr_mapped_uid == -1) {
332                                 printerr(LL_ERR,
333                                          "ERROR: %s%s%s@%s from %016llx is remote but without mapping\n",
334                                          sname, host ? "/" : "",
335                                          host ? host : "", realm, nid);
336                                 break;
337                         }
338                         goto valid;
339                 }
340
341                 /* Now we know we are dealing with a local realm */
342
343                 if (!strcmp(sname, LUSTRE_ROOT_NAME) ||
344                     !strcmp(sname, GSSD_SERVICE_HOST)) {
345                         cred->cr_uid = 0;
346                         cred->cr_usr_root = 1;
347                         goto valid;
348                 }
349                 if (!strcmp(sname, GSSD_SERVICE_MDS)) {
350                         cred->cr_uid = 0;
351                         cred->cr_usr_mds = 1;
352                         goto valid;
353                 }
354                 if (!strcmp(sname, GSSD_SERVICE_OSS)) {
355                         cred->cr_uid = 0;
356                         cred->cr_usr_oss = 1;
357                         goto valid;
358                 }
359                 if (cred->cr_mapped_uid != -1) {
360                         printerr(LL_INFO,
361                                  "user %s from %016llx is mapped to %u\n",
362                                  sname, nid,
363                                  cred->cr_mapped_uid);
364                         goto valid;
365                 }
366                 pw = getpwnam(sname);
367                 if (pw != NULL) {
368                         cred->cr_uid = pw->pw_uid;
369                         printerr(LL_INFO, "%s resolve to uid %u\n",
370                                  sname, cred->cr_uid);
371                         goto valid;
372                 }
373                 printerr(LL_ERR, "ERROR: invalid user, %s/%s@%s from %016llx\n",
374                          sname, host, realm, nid);
375                 break;
376
377 valid:
378                 res = 0;
379                 break;
380         case LUSTRE_GSS_SVC_MGS:
381                 if (!strcmp(sname, GSSD_SERVICE_OSS)) {
382                         cred->cr_uid = 0;
383                         cred->cr_usr_oss = 1;
384                 }
385                 fallthrough;
386         case LUSTRE_GSS_SVC_OSS:
387                 if (!strcmp(sname, LUSTRE_ROOT_NAME) ||
388                     !strcmp(sname, GSSD_SERVICE_HOST)) {
389                         cred->cr_uid = 0;
390                         cred->cr_usr_root = 1;
391                 } else if (!strcmp(sname, GSSD_SERVICE_MDS)) {
392                         cred->cr_uid = 0;
393                         cred->cr_usr_mds = 1;
394                 }
395                 if (cred->cr_uid == -1) {
396                         printerr(LL_ERR,
397                                  "ERROR: svc %d doesn't accept user %s from %016llx\n",
398                                  lustre_svc, sname, nid);
399                         break;
400                 }
401                 res = 0;
402                 break;
403         default:
404                 assert(0);
405         }
406
407 out_free:
408         if (!res)
409                 printerr(LL_WARN, "%s: authenticated %s%s%s@%s from %016llx\n",
410                          lustre_svc_name[lustre_svc], sname,
411                          host ? "/" : "", host ? host : "", realm, nid);
412         free(sname);
413         return res;
414 }
415
/*
 * A mechanism OID paired with that mechanism's own context handle.
 * NOTE(review): presumably mirrors the GSS mechglue's internal "union
 * context" layout; it is not used by any function in the visible part of
 * this file - confirm before relying on it.
 */
416 typedef struct gss_union_ctx_id_t {
417         gss_OID         mech_type;
418         gss_ctx_id_t    internal_ctx_id;
419 } gss_union_ctx_id_desc, *gss_union_ctx_id_t;
420
421 int handle_sk(struct svc_nego_data *snd)
422 {
423 #ifdef HAVE_OPENSSL_SSK
424         struct sk_cred *skc = NULL;
425         struct svc_cred cred;
426         gss_buffer_desc bufs[SK_INIT_BUFFERS];
427         gss_buffer_desc remote_pub_key = GSS_C_EMPTY_BUFFER;
428         char *target;
429         uint32_t rc = GSS_S_DEFECTIVE_TOKEN;
430         uint32_t version;
431         uint32_t flags;
432         int i;
433         int attempts = 0;
434
435         printerr(LL_DEBUG, "Handling sk request\n");
436         memset(bufs, 0, sizeof(gss_buffer_desc) * SK_INIT_BUFFERS);
437
438         /* See lgss_sk_using_cred() for client side token formation.
439          * Decoding initiator buffers */
440         i = sk_decode_netstring(bufs, SK_INIT_BUFFERS, &snd->in_tok);
441         if (i < SK_INIT_BUFFERS) {
442                 printerr(LL_ERR,
443                          "Invalid netstring token received from peer\n");
444                 goto cleanup_buffers;
445         }
446
447         /* Allowing for a larger length first buffer in the future */
448         if (bufs[SK_INIT_VERSION].length < sizeof(version)) {
449                 printerr(LL_ERR, "Invalid version received (wrong size)\n");
450                 goto cleanup_buffers;
451         }
452         memcpy(&version, bufs[SK_INIT_VERSION].value, sizeof(version));
453         version = be32toh(version);
454         if (version != SK_MSG_VERSION) {
455                 printerr(LL_ERR, "Invalid version received: %d\n", version);
456                 goto cleanup_buffers;
457         }
458
459         rc = GSS_S_FAILURE;
460
461         /* target must be a null terminated string */
462         i = bufs[SK_INIT_TARGET].length - 1;
463         target = bufs[SK_INIT_TARGET].value;
464         if (i >= 0 && target[i] != '\0') {
465                 printerr(LL_ERR, "Invalid target from netstring\n");
466                 goto cleanup_buffers;
467         }
468
469         if (bufs[SK_INIT_FLAGS].length != sizeof(flags)) {
470                 printerr(LL_ERR, "Invalid flags from netstring\n");
471                 goto cleanup_buffers;
472         }
473         memcpy(&flags, bufs[SK_INIT_FLAGS].value, sizeof(flags));
474
475         skc = sk_create_cred(target, snd->nm_name, be32toh(flags));
476         if (!skc) {
477                 printerr(LL_ERR, "Failed to create sk credentials\n");
478                 goto cleanup_buffers;
479         }
480
481         /* Verify that the peer has used a prime size greater or equal to
482          * the size specified in the key file which may contain only zero
483          * fill but the size specifies the mimimum supported size on
484          * servers */
485         if (skc->sc_flags & LGSS_SVC_PRIV &&
486             bufs[SK_INIT_P].length < skc->sc_p.length) {
487                 printerr(LL_ERR,
488                          "Peer DHKE prime does not meet the size required by keyfile: %zd bits\n",
489                          skc->sc_p.length * 8);
490                 goto cleanup_buffers;
491         }
492
493         /* Throw out the p from the server and use the wire data */
494         free(skc->sc_p.value);
495         skc->sc_p.value = NULL;
496         skc->sc_p.length = 0;
497
498         /* Take control of all the allocated buffers from decoding */
499         if (bufs[SK_INIT_RANDOM].length !=
500             sizeof(skc->sc_kctx.skc_peer_random)) {
501                 printerr(LL_ERR, "Invalid size for client random\n");
502                 goto cleanup_buffers;
503         }
504
505         memcpy(&skc->sc_kctx.skc_peer_random, bufs[SK_INIT_RANDOM].value,
506                sizeof(skc->sc_kctx.skc_peer_random));
507         skc->sc_p = bufs[SK_INIT_P];
508         remote_pub_key = bufs[SK_INIT_PUB_KEY];
509         skc->sc_nodemap_hash = bufs[SK_INIT_NODEMAP];
510         skc->sc_hmac = bufs[SK_INIT_HMAC];
511
512         /* Verify HMAC from peer.  Ideally this would happen before anything
513          * else but we don't have enough information to lookup key without the
514          * token (fsname and cluster_hash) so it's done after. */
515         rc = sk_verify_hmac(skc, bufs, SK_INIT_BUFFERS - 1, EVP_sha256(),
516                             &skc->sc_hmac);
517         if (rc != GSS_S_COMPLETE) {
518                 printerr(LL_ERR, "HMAC verification error: 0x%x from peer %s\n",
519                          rc, libcfs_nid2str((lnet_nid_t)snd->nid));
520                 goto cleanup_partial;
521         }
522
523         /* Check that the cluster hash matches the hash of nodemap name */
524         rc = sk_verify_hash(snd->nm_name, EVP_sha256(), &skc->sc_nodemap_hash);
525         if (rc != GSS_S_COMPLETE) {
526                 printerr(LL_ERR, "Cluster hash failed validation: 0x%x\n", rc);
527                 goto cleanup_partial;
528         }
529
530 redo:
531         rc = sk_gen_params(skc, sk_dh_checks);
532         if (rc != GSS_S_COMPLETE) {
533                 printerr(LL_ERR,
534                          "Failed to generate DH params for responder\n");
535                 goto cleanup_partial;
536         }
537         rc = sk_compute_dh_key(skc, &remote_pub_key);
538         if (rc == GSS_S_BAD_QOP && attempts < 2) {
539                 /* GSS_S_BAD_QOP means the generated shared key was shorter
540                  * than expected. Just retry twice before giving up.
541                  */
542                 attempts++;
543                 if (skc->sc_params) {
544                         EVP_PKEY_free(skc->sc_params);
545                         skc->sc_params = NULL;
546                 }
547                 if (skc->sc_pub_key.value) {
548                         free(skc->sc_pub_key.value);
549                         skc->sc_pub_key.value = NULL;
550                 }
551                 skc->sc_pub_key.length = 0;
552                 if (skc->sc_dh_shared_key.value) {
553                         /* erase secret key before freeing memory */
554                         memset(skc->sc_dh_shared_key.value, 0,
555                                skc->sc_dh_shared_key.length);
556                         free(skc->sc_dh_shared_key.value);
557                         skc->sc_dh_shared_key.value = NULL;
558                 }
559                 skc->sc_dh_shared_key.length = 0;
560                 goto redo;
561         } else if (rc != GSS_S_COMPLETE) {
562                 printerr(LL_ERR,
563                          "Failed to compute session key from DH params\n");
564                 goto cleanup_partial;
565         }
566
567         /* Cleanup init buffers we have copied or don't need anymore */
568         free(bufs[SK_INIT_VERSION].value);
569         free(bufs[SK_INIT_RANDOM].value);
570         free(bufs[SK_INIT_TARGET].value);
571         free(bufs[SK_INIT_FLAGS].value);
572
573         /* Server reply contains the servers public key, random,  and HMAC */
574         version = htobe32(SK_MSG_VERSION);
575         bufs[SK_RESP_VERSION].value = &version;
576         bufs[SK_RESP_VERSION].length = sizeof(version);
577         bufs[SK_RESP_RANDOM].value = &skc->sc_kctx.skc_host_random;
578         bufs[SK_RESP_RANDOM].length = sizeof(skc->sc_kctx.skc_host_random);
579         bufs[SK_RESP_PUB_KEY] = skc->sc_pub_key;
580         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs,
581                          SK_RESP_BUFFERS - 1, EVP_sha256(),
582                          &skc->sc_hmac)) {
583                 printerr(LL_ERR, "Failed to sign parameters\n");
584                 goto out_err;
585         }
586         bufs[SK_RESP_HMAC] = skc->sc_hmac;
587         if (sk_encode_netstring(bufs, SK_RESP_BUFFERS, &snd->out_tok)) {
588                 printerr(LL_ERR, "Failed to encode netstring for token\n");
589                 goto out_err;
590         }
591         printerr(LL_INFO, "Created netstring of %zd bytes\n",
592                  snd->out_tok.length);
593
594         if (sk_session_kdf(skc, snd->nid, &snd->in_tok, &snd->out_tok)) {
595                 printerr(LL_ERR, "Failed to calculate derived session key\n");
596                 goto out_err;
597         }
598         if (sk_compute_keys(skc)) {
599                 printerr(LL_ERR,
600                          "Failed to compute HMAC and encryption keys\n");
601                 goto out_err;
602         }
603         if (sk_serialize_kctx(skc, &snd->ctx_token)) {
604                 printerr(LL_ERR, "Failed to serialize context for kernel\n");
605                 goto out_err;
606         }
607
608         snd->out_handle.length = sizeof(snd->handle_seq);
609         memcpy(snd->out_handle.value, &snd->handle_seq,
610                sizeof(snd->handle_seq));
611         snd->maj_stat = GSS_S_COMPLETE;
612
613         /* fix credentials */
614         memset(&cred, 0, sizeof(cred));
615         cred.cr_mapped_uid = -1;
616
617         if (skc->sc_flags & LGSS_ROOT_CRED_ROOT)
618                 cred.cr_usr_root = 1;
619         if (skc->sc_flags & LGSS_ROOT_CRED_MDT)
620                 cred.cr_usr_mds = 1;
621         if (skc->sc_flags & LGSS_ROOT_CRED_OST)
622                 cred.cr_usr_oss = 1;
623
624         do_svc_downcall(&snd->out_handle, &cred, snd->mech, &snd->ctx_token);
625
626         /* cleanup ctx_token, out_tok is cleaned up in handle_channel_req */
627         free(remote_pub_key.value);
628         free(snd->ctx_token.value);
629         snd->ctx_token.length = 0;
630
631         printerr(LL_DEBUG, "sk returning success\n");
632         return 0;
633
634 cleanup_buffers:
635         for (i = 0; i < SK_INIT_BUFFERS; i++)
636                 free(bufs[i].value);
637         sk_free_cred(skc);
638         snd->maj_stat = rc;
639         return -1;
640
641 cleanup_partial:
642         free(bufs[SK_INIT_VERSION].value);
643         free(bufs[SK_INIT_RANDOM].value);
644         free(bufs[SK_INIT_TARGET].value);
645         free(bufs[SK_INIT_FLAGS].value);
646         free(remote_pub_key.value);
647         sk_free_cred(skc);
648         snd->maj_stat = rc;
649         return -1;
650
651 out_err:
652         snd->maj_stat = rc;
653         if (snd->ctx_token.value) {
654                 free(snd->ctx_token.value);
655                 snd->ctx_token.value = 0;
656                 snd->ctx_token.length = 0;
657         }
658         free(remote_pub_key.value);
659         sk_free_cred(skc);
660         printerr(LL_DEBUG, "sk returning failure\n");
661 #else /* !HAVE_OPENSSL_SSK */
662         printerr(LL_ERR, "ERROR: shared key subflavour is not enabled\n");
663 #endif /* HAVE_OPENSSL_SSK */
664         return -1;
665 }
666
667 int handle_null(struct svc_nego_data *snd)
668 {
669         struct svc_cred cred;
670         uint64_t tmp;
671         uint32_t flags;
672
673         /* null just uses the same token as the return token and for
674          * for sending to the kernel.  It is a single uint64_t. */
675         if (snd->in_tok.length != sizeof(uint64_t)) {
676                 snd->maj_stat = GSS_S_DEFECTIVE_TOKEN;
677                 printerr(LL_ERR, "Invalid token size (%zd) received\n",
678                          snd->in_tok.length);
679                 return -1;
680         }
681         snd->out_tok.length = snd->in_tok.length;
682         snd->out_tok.value = malloc(snd->out_tok.length);
683         if (!snd->out_tok.value) {
684                 snd->maj_stat = GSS_S_FAILURE;
685                 printerr(LL_ERR, "Failed to allocate out_tok\n");
686                 return -1;
687         }
688
689         snd->ctx_token.length = snd->in_tok.length;
690         snd->ctx_token.value = malloc(snd->ctx_token.length);
691         if (!snd->ctx_token.value) {
692                 snd->maj_stat = GSS_S_FAILURE;
693                 printerr(LL_ERR, "Failed to allocate ctx_token\n");
694                 return -1;
695         }
696
697         snd->out_handle.length = sizeof(snd->handle_seq);
698         memcpy(snd->out_handle.value, &snd->handle_seq,
699                sizeof(snd->handle_seq));
700         snd->maj_stat = GSS_S_COMPLETE;
701
702         memcpy(&tmp, snd->in_tok.value, sizeof(tmp));
703         tmp = be64toh(tmp);
704         flags = (uint32_t)(tmp & 0x00000000ffffffff);
705         memset(&cred, 0, sizeof(cred));
706         cred.cr_mapped_uid = -1;
707
708         if (flags & LGSS_ROOT_CRED_ROOT)
709                 cred.cr_usr_root = 1;
710         if (flags & LGSS_ROOT_CRED_MDT)
711                 cred.cr_usr_mds = 1;
712         if (flags & LGSS_ROOT_CRED_OST)
713                 cred.cr_usr_oss = 1;
714
715         do_svc_downcall(&snd->out_handle, &cred, snd->mech, &snd->ctx_token);
716
717         /* cleanup ctx_token, out_tok is cleaned up in handle_channel_req */
718         free(snd->ctx_token.value);
719         snd->ctx_token.length = 0;
720
721         return 0;
722 }
723
724 static int handle_krb(struct svc_nego_data *snd)
725 {
726         u_int32_t               ret_flags;
727         gss_name_t              client_name;
728         gss_buffer_desc         ignore_out_tok = {.value = NULL};
729         gss_OID                 mech = GSS_C_NO_OID;
730         gss_cred_id_t           svc_cred;
731         u_int32_t               ignore_min_stat;
732         struct svc_cred         cred;
733
734         svc_cred = gssd_select_svc_cred(snd->lustre_svc);
735         if (!svc_cred) {
736                 printerr(LL_ERR, "no service credential for svc %u\n",
737                          snd->lustre_svc);
738                 goto out_err;
739         }
740
741         snd->maj_stat = gss_accept_sec_context(&snd->min_stat, &snd->ctx,
742                                                svc_cred, &snd->in_tok,
743                                                GSS_C_NO_CHANNEL_BINDINGS,
744                                                &client_name, &mech,
745                                                &snd->out_tok, &ret_flags, NULL,
746                                                NULL);
747
748         if (snd->maj_stat == GSS_S_CONTINUE_NEEDED) {
749                 printerr(LL_WARN,
750                          "gss_accept_sec_context GSS_S_CONTINUE_NEEDED\n");
751
752                 /* Save the context handle for future calls */
753                 snd->out_handle.length = sizeof(snd->ctx);
754                 memcpy(snd->out_handle.value, &snd->ctx, sizeof(snd->ctx));
755                 return 0;
756         } else if (snd->maj_stat != GSS_S_COMPLETE) {
757                 printerr(LL_ERR, "ERROR: gss_accept_sec_context failed\n");
758                 pgsserr("handle_krb: gss_accept_sec_context",
759                         snd->maj_stat, snd->min_stat, mech);
760                 goto out_err;
761         }
762
763         if (get_ids(client_name, mech, &cred, snd->nid, snd->lustre_svc)) {
764                 /* get_ids() prints error msg */
765                 snd->maj_stat = GSS_S_BAD_NAME; /* XXX ? */
766                 gss_release_name(&ignore_min_stat, &client_name);
767                 goto out_err;
768         }
769         gss_release_name(&ignore_min_stat, &client_name);
770
771         /* Context complete. Pass handle_seq in out_handle to use
772          * for context lookup in the kernel. */
773         snd->out_handle.length = sizeof(snd->handle_seq);
774         memcpy(snd->out_handle.value, &snd->handle_seq,
775                sizeof(snd->handle_seq));
776
777         /* kernel needs ctx to calculate verifier on null response, so
778          * must give it context before doing null call: */
779         if (serialize_context_for_kernel(snd->ctx, &snd->ctx_token, mech)) {
780                 printerr(LL_ERR,
781                          "ERROR: %s: serialize_context_for_kernel failed\n",
782                         __func__);
783                 snd->maj_stat = GSS_S_FAILURE;
784                 goto out_err;
785         }
786         /* We no longer need the gss context */
787         gss_delete_sec_context(&ignore_min_stat, &snd->ctx, &ignore_out_tok);
788         do_svc_downcall(&snd->out_handle, &cred, mech, &snd->ctx_token);
789
790         return 0;
791
792 out_err:
793         if (snd->ctx != GSS_C_NO_CONTEXT)
794                 gss_delete_sec_context(&ignore_min_stat, &snd->ctx,
795                                        &ignore_out_tok);
796
797         return 1;
798 }
799
800 /*
801  * return -1 only if we detect error during reading from upcall channel,
802  * all other cases return 0.
803  */
804 int handle_channel_request(FILE *f)
805 {
806         char                    in_tok_buf[TOKEN_BUF_SIZE];
807         char                    in_handle_buf[15];
808         char                    out_handle_buf[15];
809         gss_buffer_desc         ctx_token      = {.value = NULL},
810                                 null_token     = {.value = NULL};
811         uint32_t                lustre_mech;
812         static char             *lbuf;
813         static int              lbuflen;
814         static char             *cp;
815         int                     get_len;
816         int                     rc = 1;
817         u_int32_t               ignore_min_stat;
818         struct svc_nego_data    snd = {
819                 .in_tok.value           = in_tok_buf,
820                 .in_handle.value        = in_handle_buf,
821                 .out_handle.value       = out_handle_buf,
822                 .maj_stat               = GSS_S_FAILURE,
823                 .ctx                    = GSS_C_NO_CONTEXT,
824         };
825
826         printerr(LL_INFO, "handling request\n");
827         if (readline(fileno(f), &lbuf, &lbuflen) != 1) {
828                 printerr(LL_ERR, "ERROR: failed reading request\n");
829                 return -1;
830         }
831
832         cp = lbuf;
833
834         /* see rsi_request() for the format of data being input here */
835         qword_get(&cp, (char *)&snd.lustre_svc, sizeof(snd.lustre_svc));
836
837         /* lustre_svc is the svc and gss subflavor */
838         lustre_mech = (snd.lustre_svc & LUSTRE_GSS_MECH_MASK) >>
839                       LUSTRE_GSS_MECH_SHIFT;
840         snd.lustre_svc = snd.lustre_svc & LUSTRE_GSS_SVC_MASK;
841         switch (lustre_mech) {
842         case LGSS_MECH_KRB5:
843                 if (!krb_enabled) {
844                         static time_t next_krb;
845
846                         if (time(NULL) > next_krb) {
847                                 printerr(LL_WARN,
848                                          "warning: Request for kerberos but service support not enabled\n");
849                                 next_krb = time(NULL) + 3600;
850                         }
851                         goto ignore;
852                 }
853                 snd.mech = &krb5oid;
854                 break;
855         case LGSS_MECH_NULL:
856                 if (!null_enabled) {
857                         static time_t next_null;
858
859                         if (time(NULL) > next_null) {
860                                 printerr(LL_WARN,
861                                          "warning: Request for gssnull but service support not enabled\n");
862                                 next_null = time(NULL) + 3600;
863                         }
864                         goto ignore;
865                 }
866                 snd.mech = &nulloid;
867                 break;
868         case LGSS_MECH_SK:
869                 if (!sk_enabled) {
870                         static time_t next_ssk;
871
872                         if (time(NULL) > next_ssk) {
873                                 printerr(LL_WARN,
874                                          "warning: Request for SSK but service support not %s\n",
875 #ifdef HAVE_OPENSSL_SSK
876                                          "enabled"
877 #else
878                                          "included"
879 #endif
880                                         );
881                                 next_ssk = time(NULL) + 3600;
882                         }
883
884                         goto ignore;
885                 }
886                 snd.mech = &skoid;
887                 break;
888         default:
889                 printerr(LL_ERR, "WARNING: invalid mechanism recevied: %d\n",
890                          lustre_mech);
891                 goto out_err;
892                 break;
893         }
894
895         qword_get(&cp, (char *)&snd.nid, sizeof(snd.nid));
896         qword_get(&cp, (char *)&snd.handle_seq, sizeof(snd.handle_seq));
897         qword_get(&cp, snd.nm_name, sizeof(snd.nm_name));
898         printerr(LL_INFO,
899                  "handling req: svc %u, nid %016llx, idx %"PRIx64" nodemap %s\n",
900                  snd.lustre_svc, snd.nid, snd.handle_seq, snd.nm_name);
901
902         get_len = qword_get(&cp, snd.in_handle.value, sizeof(in_handle_buf));
903         if (get_len < 0) {
904                 printerr(LL_ERR, "ERROR: failed parsing request\n");
905                 goto out_err;
906         }
907         snd.in_handle.length = (size_t)get_len;
908
909         printerr(LL_DEBUG, "in_handle:\n");
910         print_hexl(3, snd.in_handle.value, snd.in_handle.length);
911
912         get_len = qword_get(&cp, snd.in_tok.value, sizeof(in_tok_buf));
913         if (get_len < 0) {
914                 printerr(LL_ERR, "ERROR: failed parsing request\n");
915                 goto out_err;
916         }
917         snd.in_tok.length = (size_t)get_len;
918
919         printerr(LL_DEBUG, "in_tok:\n");
920         print_hexl(3, snd.in_tok.value, snd.in_tok.length);
921
922         if (snd.in_handle.length != 0) { /* CONTINUE_INIT case */
923                 if (snd.in_handle.length != sizeof(snd.ctx)) {
924                         printerr(LL_ERR,
925                                  "ERROR: input handle has unexpected length %zu\n",
926                                  snd.in_handle.length);
927                         goto out_err;
928                 }
929                 /* in_handle is the context id stored in the out_handle
930                  * for the GSS_S_CONTINUE_NEEDED case below.  */
931                 memcpy(&snd.ctx, snd.in_handle.value, snd.in_handle.length);
932         }
933
934         if (lustre_mech == LGSS_MECH_KRB5)
935                 rc = handle_krb(&snd);
936         else if (lustre_mech == LGSS_MECH_SK)
937                 rc = handle_sk(&snd);
938         else if (lustre_mech == LGSS_MECH_NULL)
939                 rc = handle_null(&snd);
940         else
941                 printerr(LL_ERR,
942                          "ERROR: Received or request for subflavor that is not enabled: %d\n",
943                          lustre_mech);
944
945 out_err:
946         /* Failures send a null token */
947         if (rc == 0)
948                 send_response(f, &snd.in_handle, &snd.in_tok, snd.maj_stat,
949                               snd.min_stat, &snd.out_handle, &snd.out_tok);
950         else
951                 send_response(f, &snd.in_handle, &snd.in_tok, snd.maj_stat,
952                               snd.min_stat, &null_token, &null_token);
953
954         /* cleanup buffers */
955         if (snd.ctx_token.value != NULL)
956                 free(ctx_token.value);
957         if (snd.out_tok.value != NULL)
958                 gss_release_buffer(&ignore_min_stat, &snd.out_tok);
959
960         /* For junk wire data just ignore */
961 ignore:
962         return 0;
963 }