Whamcloud - gitweb
add some error message, try to catch the occasional lsd upcall failure.
[fs/lustre-release.git] / lustre / utils / lsd_upcall.c
1 /* -*- mode: c; c-basic-offset: 8; indent-tabs-mode: nil; -*-
2  * vim:expandtab:shiftwidth=8:tabstop=8:
3  *
4  *  Copyright (C) 2004 Cluster File Systems, Inc.
5  *
6  *   This file is part of Lustre, http://www.lustre.org.
7  *
8  *   Lustre is free software; you can redistribute it and/or
9  *   modify it under the terms of version 2 of the GNU General Public
10  *   License as published by the Free Software Foundation.
11  *
12  *   Lustre is distributed in the hope that it will be useful,
13  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
14  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  *   GNU General Public License for more details.
16  *
17  *   You should have received a copy of the GNU General Public License
18  *   along with Lustre; if not, write to the Free Software
19  *   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
20  *
21  */
22
23 #include <stdlib.h>
24 #include <stdint.h>
25 #include <stdio.h>
26 #include <unistd.h>
27 #include <errno.h>
28 #include <string.h>
29 #include <fcntl.h>
30 #include <pwd.h>
31 #include <grp.h>
32 #include <syslog.h>
33
34 #include <liblustre.h>
35 #include <linux/lustre_idl.h>
36 #include <linux/obd.h>
37 #include <linux/lustre_mds.h>
38
39 #include <portals/types.h>
40 #include <portals/ptlctl.h>
41
42 /*
43  * return:
44  *  0:      fail to insert (found identical)
45  *  1:      inserted
46  */
47 int insert_sort(gid_t *groups, int size, gid_t grp)
48 {
49         int i;
50         gid_t save;
51
52         for (i = 0; i < size; i++) {
53                 if (groups[i] == grp)
54                         return 0;
55                 if (groups[i] > grp)
56                         break;
57         }
58
59         for (; i <= size; i++) {
60                 save = groups[i];
61                 groups[i] = grp;
62                 grp = save;
63         }
64         return 1;
65 }
66
67 int get_groups_local(uid_t uid, gid_t *gid, int *ngroups, gid_t **groups)
68 {
69         int     maxgroups;
70         int     i, size = 0;
71         struct passwd *pw;
72         struct group  *gr;
73
74         *ngroups = 0;
75         *groups = NULL;
76         maxgroups = sysconf(_SC_NGROUPS_MAX);
77         *groups = malloc(maxgroups * sizeof(gid_t));
78         if (!*groups)
79                 return -ENOMEM;
80
81         pw = getpwuid(uid);
82         if (!pw)
83                 return -ENOENT;
84
85         *gid = pw->pw_gid;
86
87         while ((gr = getgrent())) {
88                 if (!gr->gr_mem)
89                         continue;
90                 for (i = 0; gr->gr_mem[i]; i++) {
91                         if (strcmp(gr->gr_mem[i], pw->pw_name))
92                                 continue;
93                         size += insert_sort(*groups, size, gr->gr_gid);
94                         break;
95                 }
96                 if (size == maxgroups)
97                         break;
98         }
99         endgrent();
100         *ngroups = size;
101         return 0;
102 }
103
104 #define LINEBUF_SIZE    (1024)
105 static char linebuf[LINEBUF_SIZE];
106
107 int readline(FILE *fp, char *buf, int bufsize)
108 {
109         char *p = buf;
110         int i = 0;
111
112         if (fgets(buf, bufsize, fp) == NULL)
113                 return -1;
114
115         while (*p) {
116                 if (*p == '#') {
117                         *p = '\0';
118                         break;
119                 }
120                 if (*p == '\n') {
121                         *p = '\0';
122                         break;
123                 }
124                 i++;
125                 p++;
126         }
127
128         return i;
129 }
130
131 #define IS_SPACE(c) ((c) == ' ' || (c) == '\t')
132
133 void remove_space_head(char **buf)
134 {
135         char *p = *buf;
136
137         while (IS_SPACE(*p))
138                 p++;
139
140         *buf = p;
141 }
142
143 void remove_space_tail(char **buf)
144 {
145         char *p = *buf;
146         char *spc = NULL;
147
148         while (*p) {
149                 if (!IS_SPACE(*p)) {
150                         if (spc) spc = NULL;
151                 } else
152                         if (!spc) spc = p;
153                 p++;
154         }
155
156         if (spc)
157                 *spc = '\0';
158 }
159
160 int get_next_uid_range(char **buf, uid_t *uid_range)
161 {
162         char *p = *buf;
163         char *comma, *sub;
164
165         remove_space_head(&p);
166         if (strlen(p) == 0)
167                 return -1;
168
169         comma = strchr(p, ',');
170         if (comma) {
171                 *comma = '\0';
172                 *buf = comma + 1;
173         } else
174                 *buf = p + strlen(p);
175
176         sub = strchr(p, '-');
177         if (!sub) {
178                 uid_range[0] = uid_range[1] = atoi(p);
179         } else {
180                 *sub++ = '\0';
181                 uid_range[0] = atoi(p);
182                 uid_range[1] = atoi(sub);
183         }
184
185         return 0;
186 }
187
188 /*
189  * return 0: ok
190  */
191 int remove_bracket(char **buf)
192 {
193         char *p = *buf;
194         char *p2;
195
196         if (*p++ != '[')
197                 return -1;
198
199         p2 = strchr(p, ']');
200         if (!p2)
201                 return -1;
202
203         *p2++ = '\0';
204         while (*p2) {
205                 if (*p2 != ' ' && *p2 != '\t')
206                         return -1;
207                 p2++;
208         }
209
210         remove_space_tail(&p);
211         *buf = p;
212         return 0;
213 }
214
215 /* return 0: found a match */
216 int search_uid(FILE *fp, uid_t uid)
217 {
218         char *p;
219         uid_t uid_range[2];
220         int rc;
221
222         while (1) {
223                 rc = readline(fp, linebuf, LINEBUF_SIZE);
224                 if (rc < 0)
225                         return rc;
226                 if (rc == 0)
227                         continue;
228
229                 p = linebuf;
230                 if (remove_bracket(&p))
231                         continue;
232
233                 while (get_next_uid_range(&p, uid_range) == 0) {
234                         if (uid >= uid_range[0] && uid <= uid_range[1]) {
235                                 return 0;
236                         }
237                 }
238                 continue;
239         }
240 }
241
242 static struct {
243         char   *name;
244         __u32   bit;
245 } perm_types[] =  {
246         {"setuid",      LSD_PERM_SETUID},
247         {"setgid",      LSD_PERM_SETGID},
248         {"setgrp",      LSD_PERM_SETGRP},
249 };
250 #define N_PERM_TYPES    (3)
251
252 int parse_perm(__u32 *perm, char *str)
253 {
254         char *p = str;
255         char *comma;
256         int i;
257
258         *perm = 0;
259
260         while (1) {
261                 p = str;
262                 comma = strchr(str, ',');
263                 if (comma) {
264                         *comma = '\0';
265                         str = comma + 1;
266                 }
267
268                 for (i = 0; i < N_PERM_TYPES; i++) {
269                         if (!strcasecmp(p, perm_types[i].name)) {
270                                 *perm |= perm_types[i].bit;
271                                 break;
272                         }
273                 }
274
275                 if (i >= N_PERM_TYPES) {
276                         printf("unkown perm type: %s\n", p);
277                         return -1;
278                 }
279
280                 if (!comma)
281                         break;
282         }
283         return 0;
284 }
285
286 int parse_nid(ptl_nid_t *nidp, char *nid_str)
287 {
288         if (!strcmp(nid_str, "*")) {
289                 *nidp = PTL_NID_ANY;
290                 return 0;
291         }
292
293         return ptl_parse_nid(nidp, nid_str);
294 }
295
296 int get_one_perm(FILE *fp, struct lsd_permission *perm)
297 {
298         char nid_str[256], perm_str[256];
299         int rc;
300
301 again:
302         rc = readline(fp, linebuf, LINEBUF_SIZE);
303         if (rc < 0)
304                 return rc;
305         if (rc == 0)
306                 goto again;
307
308         rc = sscanf(linebuf, "%s %s", nid_str, perm_str);
309         if (rc != 2)
310                 return -1;
311
312         if (parse_nid(&perm->nid, nid_str))
313                 return -1;
314
315         if (parse_perm(&perm->perm, perm_str))
316                 return -1;
317
318         perm->netid = 0;
319         return 0;
320 }
321
322 #define MAX_PERMS       (50)
323
324 int get_perms(FILE *fp, uid_t uid, int *nperms, struct lsd_permission **perms)
325 {
326         static struct lsd_permission _perms[MAX_PERMS];
327
328         if (search_uid(fp, uid))
329                 return -1;
330
331         *nperms = 0;
332         while (*nperms < MAX_PERMS) {
333                 if (get_one_perm(fp, &_perms[*nperms]))
334                         break;
335                 (*nperms)++;
336         }
337         *perms = _perms;
338         return 0;
339 }
340
341 void show_result(struct lsd_downcall_args *dc)
342 {
343         int i;
344
345         printf("err: %d, uid %u, gid %d\n"
346                "ngroups: %d\n",
347                dc->err, dc->uid, dc->gid, dc->ngroups);
348         for (i = 0; i < dc->ngroups; i++)
349                 printf("\t%d\n", dc->groups[i]);
350
351         printf("nperms: %d\n", dc->nperms);
352         for (i = 0; i < dc->nperms; i++)
353                 printf("\t: netid %u, nid "LPX64", bits %x\n", i,
354                         dc->perms[i].nid, dc->perms[i].perm);
355 }
356
357 #define log_msg(testing, fmt, args...)                  \
358         {                                               \
359                 if (testing)                            \
360                         printf(fmt, ## args);           \
361                 else                                    \
362                         syslog(LOG_ERR, fmt, ## args);  \
363         }
364
365 void usage(char *prog)
366 {
367         printf("Usage: %s [-t] uid\n", prog);
368         exit(1);
369 }
370
371 int main (int argc, char **argv)
372 {
373         char   *dc_name = "/proc/fs/lustre/mds/lsd_downcall";
374         int     dc_fd;
375         char   *conf_name = "/etc/lustre/lsd.conf";
376         FILE   *conf_fp;
377         struct lsd_downcall_args ioc_data;
378         extern char *optarg;
379         int     opt, testing = 0, rc;
380
381         while ((opt = getopt(argc, argv, "t")) != -1) {
382                 switch (opt) {
383                 case 't':
384                         testing = 1;
385                         break;
386                 default:
387                         usage(argv[0]);
388                 }
389         }
390
391         if (optind >= argc)
392                 usage(argv[0]);
393
394         memset(&ioc_data, 0, sizeof(ioc_data));
395         ioc_data.uid = atoi(argv[optind]);
396
397         /* read user/group database */
398         ioc_data.err = get_groups_local(ioc_data.uid, &ioc_data.gid,
399                                         (int *)&ioc_data.ngroups,
400                                         &ioc_data.groups);
401         if (ioc_data.err)
402                 goto do_downcall;
403
404         /* read lsd config database */
405         conf_fp = fopen(conf_name, "r");
406         if (conf_fp) {
407                 get_perms(conf_fp, ioc_data.uid,
408                           (int *)&ioc_data.nperms,
409                           &ioc_data.perms);
410                 fclose(conf_fp);
411         }
412
413
414 do_downcall:
415         if (testing) {
416                 show_result(&ioc_data);
417                 return 0;
418         } else {
419                 dc_fd = open(dc_name, O_WRONLY);
420                 if (dc_fd < 0) {
421                         log_msg(testing, "can't open device %s: %s\n",
422                                 dc_name, strerror(errno));
423
424                         return -errno;
425                 }
426
427                 rc = write(dc_fd, &ioc_data, sizeof(ioc_data));
428                 if (rc != sizeof(ioc_data)) {
429                         log_msg(testing, "partial write ret %d: %s\n",
430                                 rc, strerror(errno));
431                 }
432
433                 return (rc != sizeof(ioc_data));
434         }
435 }