Whamcloud - gitweb
LU-10003 lnet: Update lctl ping to work with large NIDs
[fs/lustre-release.git] / lustre / utils / l_getidentity.c
index df80258..22c93c0 100644 (file)
@@ -27,7 +27,6 @@
  */
 /*
  * This file is part of Lustre, http://www.lustre.org/
- * Lustre is a trademark of Sun Microsystems, Inc.
  */
 
 #include <stdbool.h>
 #include <stddef.h>
 #include <libgen.h>
 #include <syslog.h>
+#include <sys/time.h>
+#include <limits.h>
+#include <ctype.h>
+#include <nss.h>
+#include <dlfcn.h>
 
 #include <libcfs/util/param.h>
 #include <linux/lnet/nidstr.h>
 #include <linux/lustre/lustre_idl.h>
 
 #define PERM_PATHNAME "/etc/lustre/perm.conf"
+#define LUSTRE_PASSWD "/etc/lustre/passwd"
+#define LUSTRE_GROUP  "/etc/lustre/group"
+
+#define L_GETIDENTITY_LOOKUP_CMD "lookup"
+#define NSS_MODULES_MAX_NR 8
+#define NSS_MODULE_NAME_SIZE 32
+#define NSS_SYMBOL_NAME_LEN_MAX 256
+
+static int nss_pw_buf_len;
+static void *nss_pw_buf;
+static int nss_grent_buf_len;
+static void *nss_grent_buf;
+static int g_n_nss_modules;
+static int grent_mod_no = -1;
+
+struct nss_module {
+       char name[NSS_MODULE_NAME_SIZE];
+       int (*getpwuid)(struct nss_module *mod, uid_t, struct passwd *pwd);
+       int (*getgrent)(struct nss_module *mod, struct group *result);
+       void (*endgrent)(struct nss_module *mod);
+       void (*fini)(struct nss_module *mod);
+
+       union {
+               struct {
+                       void *l_ptr;
+                       int  (*l_getpwuid)(uid_t, struct passwd *pwd,
+                                          char *buffer, size_t buflen,
+                                          int *errnop);
+                       int  (*l_getgrent)(struct group *result, char *buffer,
+                                          size_t buflen, int *errnop);
+                       int  (*l_endgrent)(void);
+               } lib;
+               struct {
+                       FILE *f_passwd;
+                       FILE *f_group;
+               } files;
+       } u;
+};
+
+static struct nss_module g_nss_modules[NSS_MODULES_MAX_NR];
+
+#define NSS_LIB_NAME_PATTERN "libnss_%s.so.2"
 
 /*
  * permission file format is like this:
@@ -75,15 +121,11 @@ static void usage(void)
                "\nusage: %s {-d|mdtname} {uid}\n"
                "Normally invoked as an upcall from Lustre, set via:\n"
                "lctl set_param mdt.${mdtname}.identity_upcall={path to upcall}\n"
-               "\t-d: debug, print values to stdout instead of Lustre\n",
+               "\t-d: debug, print values to stdout instead of Lustre\n"
+               "\tNSS support enabled\n",
                progname);
 }
 
-static int compare_u32(const void *v1, const void *v2)
-{
-       return *(__u32 *)v1 - *(__u32 *)v2;
-}
-
 static void errlog(const char *fmt, ...)
 {
        va_list args;
@@ -97,6 +139,335 @@ static void errlog(const char *fmt, ...)
        closelog();
 }
 
+static int compare_gids(const void *v1, const void *v2)
+{
+       return (*(gid_t *)v1 - *(gid_t *)v2);
+}
+
+/** getpwuid() replacement */
+static struct passwd *getpwuid_nss(uid_t uid)
+{
+       static struct passwd pw;
+       int i;
+
+       for (i = 0; i < g_n_nss_modules; i++) {
+               struct nss_module *mod = g_nss_modules + i;
+
+               if (mod->getpwuid(mod, uid, &pw) == 0)
+                       return &pw;
+       }
+       return NULL;
+}
+
+/**
+ * getgrent() replacement.
+ *  simulate getgrent(3) across nss modules
+ */
+static struct group *getgrent_nss(void)
+{
+       static struct group grp;
+
+       if (grent_mod_no < 0)
+               grent_mod_no = 0;
+
+       while (grent_mod_no < g_n_nss_modules) {
+               struct nss_module *mod = g_nss_modules + grent_mod_no;
+
+               if (mod->getgrent(mod, &grp) == 0)
+                       return &grp;
+               mod->endgrent(mod);
+               grent_mod_no++;
+       }
+       return NULL;
+}
+
+/** endgrent() replacement */
+static void endgrent_nss(void)
+{
+       if (grent_mod_no < g_n_nss_modules
+               && grent_mod_no >= 0) {
+               struct nss_module *mod = g_nss_modules+grent_mod_no;
+
+               mod->endgrent(mod);
+       }
+       grent_mod_no = -1;
+}
+
+/** lookup symbol in dynamically loaded nss module */
+static void *get_nss_sym(struct nss_module *mod, const char *op)
+{
+       void *res;
+       int bytes;
+       char symbuf[NSS_SYMBOL_NAME_LEN_MAX];
+
+       bytes = snprintf(symbuf, NSS_SYMBOL_NAME_LEN_MAX - 1, "_nss_%s_%s",
+                       mod->name, op);
+       if (bytes >= NSS_SYMBOL_NAME_LEN_MAX - 1) {
+               errlog("symbol name too long\n");
+               return NULL;
+       }
+       res = dlsym(mod->u.lib.l_ptr, symbuf);
+       if (res == NULL)
+               errlog("cannot find symbol %s in nss module \"%s\": %s\n",
+                       symbuf, mod->name, dlerror());
+       return res;
+}
+
+/** allocate bigger buffer */
+static void enlarge_nss_buffer(void **buf, int *bufsize)
+{
+       free(*buf);
+       *bufsize = *bufsize * 2;
+       *buf = malloc(*bufsize);
+       if (*buf == NULL) {
+               errlog("no memory to allocate bigger buffer of %d bytes\n",
+                       *bufsize);
+               exit(-1);
+       }
+}
+
+static int getpwuid_nss_lib(struct nss_module *nss, uid_t uid,
+                           struct passwd *pw)
+{
+       int tmp_errno, err;
+
+       while (1) {
+               err = nss->u.lib.l_getpwuid(uid, pw, nss_pw_buf,
+                               nss_pw_buf_len, &tmp_errno);
+               if (err == NSS_STATUS_TRYAGAIN) {
+                       if (tmp_errno == ERANGE) {
+                               /* buffer too small */
+                               enlarge_nss_buffer(&nss_pw_buf,
+                                               &nss_pw_buf_len);
+                       }
+                       continue;
+               }
+               break;
+       }
+       if (err == NSS_STATUS_SUCCESS)
+               return 0;
+       return -ENOENT;
+}
+
+static int getgrent_nss_lib(struct nss_module *nss, struct group *gr)
+{
+       int tmp_errno, err;
+
+       while (1) {
+               err = nss->u.lib.l_getgrent(gr, nss_grent_buf,
+                               nss_grent_buf_len, &tmp_errno);
+               if (err == NSS_STATUS_TRYAGAIN) {
+                       if (tmp_errno == ERANGE) {
+                               /* buffer too small */
+                               enlarge_nss_buffer(&nss_grent_buf,
+                                               &nss_grent_buf_len);
+                       }
+                       continue;
+               }
+               break;
+       }
+       if (err == NSS_STATUS_SUCCESS)
+               return 0;
+       return -ENOENT;
+}
+
+static void endgrent_nss_lib(struct nss_module *mod)
+{
+       mod->u.lib.l_endgrent();
+}
+
+/** destroy a "shared lib" nss module */
+static void fini_nss_lib_module(struct nss_module *mod)
+{
+       if (mod->u.lib.l_ptr)
+               dlclose(mod->u.lib.l_ptr);
+}
+
+/** load and initialize a "shared lib" nss module */
+static int init_nss_lib_module(struct nss_module *mod, char *name)
+{
+       char lib_file_name[sizeof(NSS_LIB_NAME_PATTERN) + sizeof(mod->name)];
+
+       if (strlen(name) >= sizeof(mod->name)) {
+               errlog("module name (%s) too long\n", name);
+               exit(1);
+       }
+
+       strncpy(mod->name, name, sizeof(mod->name));
+       mod->name[sizeof(mod->name) - 1] = '\0';
+
+       snprintf(lib_file_name, sizeof(lib_file_name), NSS_LIB_NAME_PATTERN,
+                name);
+
+       mod->getpwuid = getpwuid_nss_lib;
+       mod->getgrent = getgrent_nss_lib;
+       mod->endgrent = endgrent_nss_lib;
+       mod->fini = fini_nss_lib_module;
+
+       mod->u.lib.l_ptr = dlopen(lib_file_name, RTLD_NOW);
+       if (mod->u.lib.l_ptr == NULL) {
+               errlog("dl error %s\n", dlerror());
+               exit(1);
+       }
+       mod->u.lib.l_getpwuid = get_nss_sym(mod, "getpwuid_r");
+       if (mod->getpwuid == NULL)
+               exit(1);
+
+       mod->u.lib.l_getgrent = get_nss_sym(mod, "getgrent_r");
+       if (mod->getgrent == NULL)
+               exit(1);
+
+       mod->u.lib.l_endgrent = get_nss_sym(mod, "endgrent");
+       if (mod->endgrent == NULL)
+               exit(1);
+
+       return 0;
+}
+
+static void fini_lustre_nss_module(struct nss_module *mod)
+{
+       if (mod->u.files.f_passwd)
+               fclose(mod->u.files.f_passwd);
+       if (mod->u.files.f_group)
+               fclose(mod->u.files.f_group);
+}
+
+static int getpwuid_lustre_nss(struct nss_module *mod, uid_t uid,
+                             struct passwd *pw)
+{
+       struct passwd *pos;
+
+       while ((pos = fgetpwent(mod->u.files.f_passwd)) != NULL) {
+               if (pos->pw_uid == uid) {
+                       *pw = *pos;
+                       return 0;
+               }
+       }
+       return -1;
+}
+
+static int getgrent_lustre_nss(struct nss_module *mod, struct group *gr)
+{
+       struct group *pos;
+
+       pos = fgetgrent(mod->u.files.f_group);
+       if (pos) {
+               *gr = *pos;
+               return 0;
+       }
+       return 1;
+}
+
+static void endgrent_lustre_nss(struct nss_module *mod)
+{
+}
+
+/** initialize module to access local /etc/lustre/passwd,group files */
+static int init_lustre_module(struct nss_module *mod)
+{
+       mod->fini = fini_lustre_nss_module;
+       mod->getpwuid = getpwuid_lustre_nss;
+       mod->getgrent = getgrent_lustre_nss;
+       mod->endgrent = endgrent_lustre_nss;
+
+       mod->u.files.f_passwd = fopen(LUSTRE_PASSWD, "r");
+       if (mod->u.files.f_passwd == NULL)
+               exit(1);
+
+       mod->u.files.f_group = fopen(LUSTRE_GROUP, "r");
+       if (mod->u.files.f_group == NULL)
+               exit(1);
+
+       snprintf(mod->name, sizeof(mod->name), "lustre");
+       return 0;
+}
+
+/** load and initialize the "nss" system */
+static void init_nss(void)
+{
+       nss_pw_buf_len = sysconf(_SC_GETPW_R_SIZE_MAX);
+       if (nss_pw_buf_len == -1) {
+               perror("sysconf");
+               exit(1);
+       }
+       nss_pw_buf = malloc(nss_pw_buf_len);
+       if (nss_pw_buf == NULL) {
+               perror("pw buffer allocation");
+               exit(1);
+       }
+
+       nss_grent_buf_len = sysconf(_SC_GETGR_R_SIZE_MAX);
+       if (nss_grent_buf_len == -1) {
+               perror("sysconf");
+               exit(1);
+       }
+       nss_grent_buf = malloc(nss_grent_buf_len);
+       if (nss_grent_buf == NULL) {
+               perror("grent buffer allocation");
+               exit(1);
+       }
+}
+
+/** unload "nss" */
+static void fini_nss(void)
+{
+       int i;
+
+       for (i = 0; i < g_n_nss_modules; i++) {
+               struct nss_module *mod = g_nss_modules + i;
+
+               mod->fini(mod);
+       }
+
+       free(nss_pw_buf);
+       free(nss_grent_buf);
+}
+
+/** get supplementary group info and fill downcall data */
+static int get_groups_nss(struct identity_downcall_data *data,
+                         unsigned int maxgroups)
+{
+       struct passwd *pw;
+       struct group *gr;
+       gid_t *groups;
+       unsigned int ngroups = 0;
+       char *pw_name;
+       int i;
+
+       pw = getpwuid_nss(data->idd_uid);
+       if (pw == NULL) {
+               data->idd_err = errno ? errno : EIDRM;
+               errlog("no such user %u\n", data->idd_uid);
+               return -1;
+       }
+
+       data->idd_gid = pw->pw_gid;
+       pw_name = strdup(pw->pw_name);
+       groups = data->idd_groups;
+
+       while ((gr = getgrent_nss()) != NULL && ngroups < maxgroups) {
+               if (gr->gr_gid == pw->pw_gid)
+                       continue;
+               if (!gr->gr_mem)
+                       continue;
+               for (i = 0; gr->gr_mem[i]; i++) {
+                       if (!strcmp(gr->gr_mem[i], pw_name)) {
+                               groups[ngroups++] = gr->gr_gid;
+                               break;
+                       }
+               }
+       }
+
+       endgrent_nss();
+
+       if (ngroups > 0)
+               qsort(groups, ngroups, sizeof(*groups), compare_gids);
+       data->idd_ngroups = ngroups;
+
+       free(pw_name);
+       return 0;
+}
+
 int get_groups_local(struct identity_downcall_data *data,
                     unsigned int maxgroups)
 {
@@ -144,13 +515,21 @@ int get_groups_local(struct identity_downcall_data *data,
                        groups[ngroups++] = groups_tmp[i];
 
        if (ngroups > 0)
-               qsort(groups, ngroups, sizeof(*groups), compare_u32);
+               qsort(groups, ngroups, sizeof(*groups), compare_gids);
        data->idd_ngroups = ngroups;
 
        free(groups_tmp);
        return 0;
 }
 
+int get_groups_common(struct identity_downcall_data *data,
+                     unsigned int maxgroups)
+{
+       if (g_n_nss_modules)
+               return get_groups_nss(data, maxgroups);
+       return get_groups_local(data, maxgroups);
+}
+
 static inline int comment_line(char *line)
 {
        char *p = line;
@@ -177,12 +556,12 @@ static inline int match_uid(uid_t uid, const char *str)
        return (uid == uid2);
 }
 
-typedef struct {
+struct perm_type {
        char *name;
        __u32 bit;
-} perm_type_t;
+};
 
-static perm_type_t perm_types[] = {
+static struct perm_type perm_types[] = {
        { "setuid", CFS_SETUID_PERM },
        { "setgid", CFS_SETGID_PERM },
        { "setgrp", CFS_SETGRP_PERM },
@@ -191,7 +570,7 @@ static perm_type_t perm_types[] = {
        { 0 }
 };
 
-static perm_type_t noperm_types[] = {
+static struct perm_type noperm_types[] = {
        { "nosetuid", CFS_SETUID_PERM },
        { "nosetgid", CFS_SETGID_PERM },
        { "nosetgrp", CFS_SETGRP_PERM },
@@ -204,7 +583,7 @@ int parse_perm(__u32 *perm, __u32 *noperm, char *str)
 {
        char *start, *end;
        char name[64];
-       perm_type_t *pt;
+       struct perm_type *pt;
 
        *perm = 0;
        *noperm = 0;
@@ -344,10 +723,89 @@ parse_perm_line(struct identity_downcall_data *data, char *line, size_t size)
        return 0;
 }
 
+static char *striml(char *s)
+{
+       while (isspace(*s))
+               s++;
+       return s;
+}
+
+static void check_new_nss_module(struct nss_module *mod)
+{
+       int i;
+
+       for (i = 0; i < g_n_nss_modules; i++) {
+               struct nss_module *pos = g_nss_modules + i;
+
+               if (!strcmp(mod->name, pos->name)) {
+                       errlog("attempt to initialize \"%s\" module twice\n",
+                               pos->name);
+                       exit(-1);
+               }
+       }
+}
+
+/**
+ * Check and parse lookup db config line.
+ * File should start with 'lookup' followed by the modules
+ * to be loaded, for example:
+ *
+ *  [/etc/lustre/perm.conf]
+ *  lookup lustre ldap
+ *
+ * Should search, in order, first found wins:
+ *    lustre [/etc/lustre/passwd and /etc/lustre/group]
+ *    ldap
+ *
+ * Other common nss modules: nis sss db files
+ * Since historically 'files' has been used exclusively
+ *  to mean 'lustre auth files' and disabled using local auth
+ *  via libnss_files users must select 'nss_files' to explicitly
+ *  enable libnss_files, which is an uncommon configuration.
+ */
+static int lookup_db_line_nss(char *line)
+{
+       char *p, *tok;
+       int ret = 0;
+
+       p = striml(line);
+       if (strncmp(p, L_GETIDENTITY_LOOKUP_CMD,
+           sizeof(L_GETIDENTITY_LOOKUP_CMD) - 1))
+               return -EAGAIN;
+
+       tok = strtok(p, " \t");
+       if (tok == NULL || strcmp(tok, L_GETIDENTITY_LOOKUP_CMD))
+               return -EIO;
+
+       while ((tok = strtok(NULL, " \t\n")) != NULL) {
+               struct nss_module *newmod = NULL;
+
+               if (g_n_nss_modules < NSS_MODULES_MAX_NR)
+                       newmod = &g_nss_modules[g_n_nss_modules];
+               else
+                       return -ERANGE;
+
+               if (!strcmp(tok, "files") || !strcmp(tok, "lustre"))
+                       ret = init_lustre_module(newmod);
+               else if (!strcmp(tok, "nss_files"))
+                       ret = init_nss_lib_module(newmod, "files");
+               else
+                       ret = init_nss_lib_module(newmod, tok);
+
+               if (ret)
+                       break;
+               check_new_nss_module(newmod);
+               g_n_nss_modules++;
+       }
+
+       return ret;
+}
+
 int get_perms(struct identity_downcall_data *data)
 {
        FILE *fp;
        char line[PATH_MAX];
+       int ret;
 
        fp = fopen(PERM_PATHNAME, "r");
        if (!fp) {
@@ -362,7 +820,9 @@ int get_perms(struct identity_downcall_data *data)
        while (fgets(line, sizeof(line), fp)) {
                if (comment_line(line))
                        continue;
-
+               ret = lookup_db_line_nss(line); /* lookup parsed */
+               if (ret == 0)
+                       continue;
                if (parse_perm_line(data, line, sizeof(line))) {
                        errlog("parse line %s failed!\n", line);
                        data->idd_err = EINVAL;
@@ -402,55 +862,78 @@ static void show_result(struct identity_downcall_data *data)
        printf("\n");
 }
 
+#define difftime(a, b)                                 \
+       ((a).tv_sec - (b).tv_sec +                      \
+        ((a).tv_usec - (b).tv_usec) / 1000000.0)
+
 int main(int argc, char **argv)
 {
        char *end;
        struct identity_downcall_data *data = NULL;
        glob_t path;
        unsigned long uid;
+       struct timeval start, idgot, fini;
        int fd, rc = -EINVAL, size, maxgroups;
+       bool alreadyfailed = false;
 
        progname = basename(argv[0]);
        if (argc != 3) {
                usage();
-               goto out;
+               goto out_no_nss;
        }
 
+       errno = 0;
        uid = strtoul(argv[2], &end, 0);
-       if (*end) {
+       if (*end != '\0' || end == argv[2] || errno != 0) {
                errlog("%s: invalid uid '%s'\n", progname, argv[2]);
-               goto out;
+               goto out_no_nss;
        }
+       gettimeofday(&start, NULL);
 
        maxgroups = sysconf(_SC_NGROUPS_MAX);
        if (maxgroups > NGROUPS_MAX)
                maxgroups = NGROUPS_MAX;
        if (maxgroups == -1) {
                rc = -EINVAL;
-               goto out;
+               goto out_no_nss;
        }
 
+retry:
        size = offsetof(struct identity_downcall_data, idd_groups[maxgroups]);
        data = malloc(size);
        if (!data) {
                errlog("malloc identity downcall data(%d) failed!\n", size);
+               if (!alreadyfailed) {
+                       alreadyfailed = true;
+                       goto retry;
+               }
                rc = -ENOMEM;
-               goto out;
+               goto out_no_nss;
        }
 
        memset(data, 0, size);
        data->idd_magic = IDENTITY_DOWNCALL_MAGIC;
        data->idd_uid = uid;
+
+       init_nss();
+
+       /* read permission database and/or load nss modules
+        * rc is -1 only when file exists and is not readable or
+        * content has format / syntax errors
+        */
+       rc = get_perms(data);
+       if (rc)
+               goto downcall;
+
        /* get groups for uid */
-       rc = get_groups_local(data, maxgroups);
+       rc = get_groups_common(data, maxgroups);
        if (rc)
                goto downcall;
 
        size = offsetof(struct identity_downcall_data,
                        idd_groups[data->idd_ngroups]);
-       /* read permission database */
-       rc = get_perms(data);
 
+       gettimeofday(&idgot, NULL);
 downcall:
        if (strcmp(argv[1], "-d") == 0 || getenv("L_GETIDENTITY_TEST")) {
                show_result(data);
@@ -473,17 +956,32 @@ downcall:
        }
 
        rc = write(fd, data, size);
+       gettimeofday(&fini, NULL);
        close(fd);
        if (rc != size) {
                errlog("partial write ret %d: %s\n", rc, strerror(errno));
+               if (!alreadyfailed) {
+                       alreadyfailed = true;
+                       cfs_free_param_data(&path);
+                       if (data)
+                               free(data);
+                       goto retry;
+               }
                rc = -1;
        } else {
                rc = 0;
        }
+       /* log if it takes more than 20 second to avoid rate limite */
+       if (rc || difftime(fini, start) > 20)
+               errlog("get identity for uid %lu start time %ld.%06ld got time %ld.%06ld end time %ld.%06ld: rc = %d\n",
+                      uid, start.tv_sec, start.tv_usec, idgot.tv_sec,
+                      idgot.tv_usec, fini.tv_sec, fini.tv_usec, rc);
 
 out_params:
        cfs_free_param_data(&path);
 out:
+       fini_nss();
+out_no_nss:
        if (data)
                free(data);
        return rc;