/* modprobe.c: insert a module into the kernel, intelligently.
    Copyright (C) 2001  Rusty Russell.
    Copyright (C) 2002  Rusty Russell, IBM Corporation.

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/
#include <sys/utsname.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <string.h>
#include <errno.h>
#include <unistd.h>
#include <dirent.h>
#include <limits.h>
#include <elf.h>
#include <getopt.h>
#include <fnmatch.h>
#include <asm/unistd.h>

#include "backwards_compat.c"

#include "mod_types.h"

/* Directory searched for modules, as a printf format: %s is replaced by
   the running kernel's release string (from uname()).  The trailing '/'
   matters: module file names are appended to it directly. */
#define MODULE_DIR "/lib/modules/%s/kernel/"

/* ELF-class-specific operations (32- or 64-bit).  We decide when we hit
   the first module whether we are 32 or 64-bit: add_module() sets this
   from the EI_CLASS byte of the first module whose ELF identifier it can
   read and recognise.  NULL until then. */
static struct mod_ops *mops;

static void fatal(const char *fmt, ...)
__attribute__ ((noreturn, format (printf, 1, 2)));

/* Report an unrecoverable error: write "FATAL: " and the printf-style
   message to stderr, then terminate the process with exit status 1.
   Never returns. */
static void fatal(const char *fmt, ...)
{
	va_list ap;

	fputs("FATAL: ", stderr);
	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
	exit(1);
}

static void warn(const char *fmt, ...)
__attribute__ ((format (printf, 1, 2)));

/* Report a recoverable problem: write "WARNING: " and the printf-style
   message to stderr and return so the caller can carry on. */
static void warn(const char *fmt, ...)
{
	va_list ap;

	fputs("WARNING: ", stderr);
	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
}

/* Forward declarations of functions defined further down this file.
   They sit ahead of the #include of mod32.c/mod64.c below, presumably
   because that ELF-class-specific code calls them (TODO confirm against
   mod32.c/mod64.c). */

/* Push each string in MOD's ".modalias" ELF section onto LAST as an
   alias of MOD; returns the new list head. */
static struct alias *add_aliases(int fd,
				 unsigned long shdroff,
				 unsigned int num_secs,
				 unsigned int secnamesec,
				 struct module *mod,
				 struct alias *last);

/* Find the module in MODULES exporting symbol NAME and set its ->order
   to ORDER; returns 1 if exactly one module exports NAME, 0 if none or
   several do.  MODNAME, if non-NULL, is used in a progress message. */
static int need_symbol(unsigned int order,
		       const char *name,
		       struct module *modules,
		       const char *modname);

#include "mod32.c"
#include "mod64.c"

/* Show the one-line usage summary on stderr, naming the program as
   PROGNAME (normally argv[0]), and exit with status 1.  Never returns. */
static void print_usage(const char *progname)
{
	fprintf(stderr, "Usage: %s [--verbose|--version|--config] filename options\n",
		progname);
	exit(1);
}

/* Return 1 if NAME ends with the suffix EXT (e.g. ends_in("foo.o", ".o")),
   0 otherwise.  An empty EXT matches every NAME. */
static int ends_in(const char *name, const char *ext)
{
	size_t namelen = strlen(name);
	size_t extlen = strlen(ext);

	if (namelen < extlen)
		return 0;

	/* Compare the last EXTLEN characters of NAME, including the first
	   character of EXT, against EXT. */
	return strcmp(name + namelen - extlen, ext) == 0;
}

/* Look NAME up in the alias list ALIASES.  Each alias name is treated as
   an fnmatch(3) shell pattern; the first one matching NAME wins.  Its
   ->alias links are then followed to the end of the chain and that final
   entry is returned.  Returns NULL if no pattern matches.
   FIXME: Loop detect. */
static struct alias *find_alias(const char *name, struct alias *aliases)
{
	struct alias *hit = aliases;

	while (hit != NULL && fnmatch(hit->name, name, 0) != 0)
		hit = hit->next;

	if (hit == NULL)
		return NULL;

	/* Chase down alias to aliases */
	while (hit->alias != NULL)
		hit = hit->alias;
	return hit;
}

/* Allocate a new alias called ALIASNAME and push it onto the front of
   the list LAST, returning the new list head.  The entry points either
   directly at module MOD or at another alias ALIAS (the callers in this
   file pass one and leave the other NULL).  LINENUM is the config-file
   line it came from, 0 for aliases read out of a module.  Exits via
   fatal() if memory runs out. */
static struct alias *add_alias(struct alias *last,
			       const char *aliasname,
			       struct module *mod,
			       struct alias *alias,
			       int linenum)
{
	struct alias *newalias;

	/* The name is copied into storage allocated right after the struct. */
	newalias = malloc(sizeof *newalias + strlen(aliasname) + 1);
	if (!newalias)
		fatal("Out of memory adding alias %s\n", aliasname);
	strcpy(newalias->name, aliasname);
	newalias->config_line = linenum;
	newalias->module = mod;
	newalias->alias = alias;

	newalias->next = last;
	return newalias;
}

/* Read MOD's ".modalias" ELF section through mops->load_section() and
   push every NUL-separated string in it onto LAST as an alias of MOD.
   Returns the new list head.  If loading fails ((void *)-1), a warning
   is printed and LAST is returned unchanged.  A NULL result (presumably
   "no such section" -- TODO confirm in mod32.c/mod64.c) also leaves LAST
   unchanged. */
static struct alias *add_aliases(int fd,
				 unsigned long shdroff,
				 unsigned int num_secs,
				 unsigned int secnamesec,
				 struct module *mod,
				 struct alias *last)
{
	char *aliases;
	unsigned long size, i;

	aliases = mops->load_section(fd, shdroff, num_secs, secnamesec,
				     ".modalias",  &size);
	if (aliases == (void *)-1) {
		warn("Error loading aliases from module %s\n", mod->name);
		return last;
	}

	if (aliases) {
		i = 0;
		while (i < size) {
			/* The section comes straight from the module file:
			   don't trust it to be NUL-terminated, or strlen()
			   would run off the end of the buffer. */
			const char *end = memchr(aliases + i, '\0', size - i);

			if (!end) {
				warn("Unterminated alias in module %s\n",
				     mod->name);
				break;
			}
			last = add_alias(last, aliases + i, mod, NULL, 0);
			i = (unsigned long)(end - aliases) + 1;
		}
		free(aliases);
	}
	return last;
}

/* Create a struct module for the file ENTRY in directory DIRNAME.
   DIRNAME must end in '/'.  ENTRY must end in MODULE_EXTENSION, which is
   stripped from the stored name.  The open file is handed to
   mops->add_module() together with LAST and ALIASES.  Presumably that
   links the module in front of LAST and returns the new list head
   (TODO confirm in mod32.c/mod64.c).  The first module whose ELF
   identifier can be read and recognised chooses between the 32- and
   64-bit operations.  On any failure before that hand-off, a warning is
   printed, nothing is leaked, and LAST is returned unchanged. */
static struct module *add_module(const char *dirname, const char *entry,
				 struct module *last, struct alias **aliases)
{
	int fd;
	struct module *new;
	char pathname[strlen(dirname) + strlen(entry) + 1];

	new = malloc(sizeof(*new) + strlen(entry) + 1);
	if (!new)
		fatal("Out of memory loading module %s\n", entry);
	strcpy(new->name, entry);
	/* Truncate extension */
	new->name[strlen(new->name) - strlen(MODULE_EXTENSION)] = '\0';
	new->order = 0;
	new->next = last;

	sprintf(pathname, "%s%s", dirname, entry);
	fd = open(pathname, O_RDONLY);
	if (fd < 0) {
		warn("Can't read module %s: %s\n", pathname, strerror(errno));
		free(new);
		return last;
	}

	/* First call initializes this. */
	if (!mops) {
		/* "\177ELF" <byte> where byte = 001 for 32-bit, 002 for 64 */
		unsigned char ident[EI_NIDENT];
		ssize_t got = read(fd, ident, EI_NIDENT);

		if (got != EI_NIDENT) {
			/* errno only means something if read() failed. */
			warn("Can't read module %s elf identifier: %s\n",
			     pathname,
			     got < 0 ? strerror(errno) : "file too short");
			goto fail;
		}
		switch (ident[EI_CLASS]) {
		case ELFCLASS32:
			mops = &mod32_ops;
			break;
		case ELFCLASS64:
			mops = &mod64_ops;
			break;
		default:
			warn("Module %s has elf unknown identifier %i\n",
			     pathname, ident[EI_CLASS]);
			goto fail;
		}
		lseek(fd, 0, SEEK_SET);
	}

	return mops->add_module(fd, new, last, aliases);

fail:
	/* Error paths after open(): release the descriptor as well. */
	close(fd);
	free(new);
	return last;
}

/* Build the module list from every directory entry in DIRNAME whose name
   ends in MODULE_EXTENSION, passing ALIASES through to add_module().
   Returns the list head.  This is NULL if DIRNAME cannot be opened or
   yields no usable modules. */
static struct module *load_all_modules(const char *dirname,
				       struct alias **aliases)
{
	struct module *list = NULL;
	struct dirent *ent;
	DIR *d = opendir(dirname);

	if (d == NULL)
		return NULL;

	while ((ent = readdir(d)) != NULL) {
		/* Only files carrying the module extension are modules. */
		if (!ends_in(ent->d_name, MODULE_EXTENSION))
			continue;
		list = add_module(dirname, ent->d_name, list, aliases);
	}
	closedir(d);
	return list;
}

/* Find the module in MODULES that exports symbol NAME and mark it for
   loading by setting its ->order to ORDER.  If MODNAME is non-NULL, the
   match is reported on stdout as "MODNAME needs NAME: found in ...".
   Returns 1 if exactly one module exports NAME.  Returns 0 if none does,
   or if several do; in the latter case a warning is printed and no
   module is marked. */
static int need_symbol(unsigned int order,
		       const char *name,
		       struct module *modules,
		       const char *modname)
{
	struct module *m;
	struct module *found = NULL;

	for (m = modules; m; m = m->next) {
		unsigned int i;
		for (i = 0; i < m->num_exports; i++) {
			if (mops->export_name_cmp(m, i, name) != 0)
				continue;
			if (found) {
				warn("%s supplied by %s and %s:"
				     " picking neither\n",
				     name, m->name, found->name);
				/* No one chosen: nothing has been marked
				   yet, so nothing needs undoing. */
				return 0;
			}
			found = m;
		}
	}
	if (!found)
		return 0;

	if (modname)
		printf("%s needs %s: found in %s\n",
		       modname, name, found->name);
	/* Only mark once the provider is known to be unique.  A later,
	   deeper request overrides an earlier order; load() inserts
	   modules from the highest order down. */
	found->order = order;
	return 1;
}

/* Turn the errno left by a failed init_module call into a message.
   Error numbers are used in a loose translation: ENOEXEC and ENOENT get
   module-specific wording, and anything else falls back to strerror(). */
static const char *moderror(int err)
{
	if (err == ENOEXEC)
		return "Invalid module format";
	if (err == ENOENT)
		return "Unknown symbol in module";
	return strerror(err);
}

/* Read the whole of /proc/modules into a malloc'd, NUL-terminated buffer
   whose first byte is '\n'.  That way every module name in it, including
   the first, directly follows a newline.  Returns NULL if /proc/modules
   cannot be opened; exits via fatal() on a read error. */
static char *read_proc_modules(void)
{
	char *buf;
	unsigned int fill, size = 1024;
	int fd, ret;

	fd = open("/proc/modules", O_RDONLY);
	if (fd < 0)
		return NULL;

	/* +1 always leaves room for the terminating NUL. */
	buf = malloc(size + 1);
	if (!buf)
		fatal("Out of memory reading /proc/modules\n");
	buf[0] = '\n';
	fill = 1;

	while ((ret = read(fd, buf + fill, size - fill)) > 0) {
		fill += ret;
		if (fill == size) {
			char *bigger;

			size *= 2;
			bigger = realloc(buf, size + 1);
			if (!bigger)
				fatal("Out of memory reading /proc/modules\n");
			buf = bigger;
		}
	}
	if (ret < 0)
		fatal("Error reading /proc/modules: %s\n", strerror(errno));
	close(fd);

	/* Terminate right after the data, not one byte beyond it. */
	buf[fill] = '\0';
	return buf;
}

/* Return nonzero if a module called MODNAME appears at the start of a
   line in /proc/modules, followed by whitespace.  Dashes in MODNAME are
   compared as underscores (presumably matching how the kernel lists
   names).  If /proc/modules cannot be opened, a warning is printed and
   the module is reported as not loaded. */
static int module_in_kernel(const char *modname)
{
	char *buf, *ptr;
	char name_with_ret[strlen(modname) + 2];
	size_t i, namelen;
	int found = 0;

	buf = read_proc_modules();
	if (!buf) {
		warn("Cannot open /proc/modules:"
		     " assuming no modules loaded.\n");
		return 0;
	}

	/* Must appear at start of line. */
	name_with_ret[0] = '\n';
	strcpy(name_with_ret + 1, modname);

	/* Convert to underscores */
	for (i = 1; name_with_ret[i]; i++)
		if (name_with_ret[i] == '-')
			name_with_ret[i] = '_';
	namelen = strlen(name_with_ret);

	/* Must also be followed by whitespace, so "foo" doesn't match
	   "foobar". */
	for (ptr = buf; (ptr = strstr(ptr, name_with_ret)) != NULL; ptr++) {
		if (isspace((unsigned char)ptr[namelen])) {
			found = 1;
			break;
		}
	}
	free(buf);
	return found;
}

/* Actually do the insert.
   Insert DIRNAME FILENAME MODULE_EXTENSION (concatenated) into the
   kernel with the parameter string OPTIONS, via the init_module system
   call.  If the module is already listed in /proc/modules, nothing is
   inserted; that is fatal when DONT_FAIL is set and silently skipped
   otherwise.  A failing init_module is likewise fatal with DONT_FAIL and
   only a warning without it.  Failure to open, stat or map the file is
   always fatal. */
static void insmod(const char *dirname,
		   const char *filename,
		   const char *options,
		   int dont_fail)
{
	int fd, ret;
	struct stat st;
	unsigned long len;
	void *map;
	const char *kname;
	char modpath[strlen(dirname) + strlen(filename)
		    + sizeof(MODULE_EXTENSION)];

	/* FIXME: Look in module for name. --RR */
	sprintf(modpath, "%s%s%s", dirname, filename, MODULE_EXTENSION);

	/* Now, it may already be loaded: check /proc/modules.  Look it up
	   by bare name (no directory, no MODULE_EXTENSION).  Module names
	   there carry no file extension, so searching for "name.ext" could
	   never match. */
	kname = strrchr(filename, '/');
	kname = kname ? kname + 1 : filename;
	if (module_in_kernel(kname)) {
		/* Found: don't try to load again */
		if (dont_fail)
			fatal("Module %s already loaded\n", kname);
		return;
	}

	fd = open(modpath, O_RDONLY);
	if (fd < 0)
		fatal("Could not open `%s': %s\n", modpath, strerror(errno));

	if (fstat(fd, &st) != 0)
		fatal("Can't stat `%s': %s\n", modpath, strerror(errno));
	len = st.st_size;
	map = mmap(NULL, len, PROT_READ, MAP_SHARED, fd, 0);
	if (map == MAP_FAILED)
		fatal("Can't map `%s': %s\n", modpath, strerror(errno));

	ret = syscall(__NR_init_module, map, len, options);
	if (ret != 0) {
		if (dont_fail)
			fatal("Error inserting %s: %s\n",
			      modpath, moderror(errno));
		else
			warn("Error inserting %s: %s\n",
			     modpath, moderror(errno));
	}
	/* Release the mapping and descriptor (each exactly once). */
	munmap(map, len);
	close(fd);
}

/* Read one line into the buffer.
   Return the next line of F in a freshly malloc'd buffer with its
   trailing newline removed.  The buffer is doubled as needed, so lines
   of any length are handled.  A last line lacking a newline is returned
   as-is.  Returns NULL at end of file when nothing was read.  A genuine
   read error is reported with a warning.  The caller frees the result. */
static char *read_line(FILE *f)
{
	int size = 80;
	char *result = malloc(size);

	if (!result)
		fatal("Out of memory reading config file\n");
	result[0] = '\0';
	while (fgets(result + strlen(result), size - strlen(result), f)) {
		char *nl = strchr(result, '\n');
		char *bigger;

		if (nl) {
			*nl = '\0';
			return result;
		}
		/* Line didn't fit: grow and keep reading onto the end. */
		size *= 2;
		bigger = realloc(result, size);
		if (!bigger)
			fatal("Out of memory reading config file\n");
		result = bigger;
	}
	/* fgets() also returns NULL at a plain end of file, e.g. after a
	   final line with no newline: only that isn't an error. */
	if (ferror(f))
		warn("Unexpected error reading config file: %s\n",
		     strerror(errno));
	if (result[0] != '\0')
		return result;
	free(result);
	return NULL;
}

/* Return the first module in the list MODULES whose name is exactly NAME
   (strcmp equality), or NULL if there is none. */
static struct module *find_module(const char *name, struct module *modules)
{
	while (modules != NULL && strcmp(name, modules->name) != 0)
		modules = modules->next;
	return modules;
}

/* Skip spaces and tabs at *LINE and return a malloc'd copy of the word
   that follows: the run of characters up to the next space, tab or end
   of string.  *LINE is advanced past the word.  If no word remains,
   returns NULL and leaves *LINE at the terminating NUL.  Also returns
   NULL if the copy cannot be allocated.  The caller frees the result. */
static char *get_word(const char **line)
{
	size_t len;
	char *word = NULL;

	*line += strspn(*line, "\t ");
	len = strcspn(*line, "\t ");
	if (len) {
		/* Copy just the word, not the whole rest of the line. */
		word = malloc(len + 1);
		if (word) {
			memcpy(word, *line, len);
			word[len] = '\0';
		}
	}
	*line += len;

	return word;
}

/* alias newname oldname
   Parse the arguments of an "alias" config line; LINE points just past
   the keyword.  OLDNAME is resolved first through the existing alias
   list LAST, then as a module name in MODULES.  NEWNAME is pushed onto
   LAST, pointing at whichever matched.  A missing argument or an unknown
   OLDNAME produces a warning and leaves LAST unchanged.  Words after
   OLDNAME are ignored.  Returns the (possibly new) list head. */
static struct alias *parse_alias(struct alias *last, const char *line,
				 int linenum,
				 struct module *modules)
{
	char *newname = get_word(&line);
	char *oldname;
	struct alias *target_alias;
	struct module *target_mod = NULL;

	if (newname == NULL) {
		warn("Config line %i missing first arg\n", linenum);
		return last;
	}

	oldname = get_word(&line);
	if (oldname == NULL) {
		warn("Config line %i missing second arg\n", linenum);
		free(newname);
		return last;
	}

	/* Only fall back to a module lookup when no alias matches. */
	target_alias = find_alias(oldname, last);
	if (target_alias == NULL)
		target_mod = find_module(oldname, modules);

	if (target_alias != NULL || target_mod != NULL)
		last = add_alias(last, newname, target_mod, target_alias,
				 linenum);
	else
		warn("Config line %i aliases to unknown module %s\n",
		     linenum, oldname);

	free(oldname);
	free(newname);
	return last;
}

/* Return 1 if the first space/tab-delimited word of LINE is exactly the
   string KEYWORD, 0 otherwise. */
static int keyword(const char *line, const char *keyword)
{
	size_t wordlen = strcspn(line, "\t ");

	return wordlen == strlen(keyword)
		&& memcmp(line, keyword, wordlen) == 0;
}

/* Simple format, ignore lines starting with #, one command per line.
   Read config file FILENAME and add each "alias NEWNAME OLDNAME" line to
   *ALIASES via parse_alias().  Blank lines and lines whose first
   non-blank character is '#' are skipped; any other line is fatal.
   Failure to open the file is fatal when MUSTLOAD is set and silently
   ignored otherwise. */
static void load_config_file(const char *filename, int mustload,
			     struct module *modules,
			     struct alias **aliases)
{
	FILE *cfile;
	char *line;
	int linenum = 1;

	cfile = fopen(filename, "r");
	if (!cfile) {
		if (mustload)
			fatal("Failed to open config file %s: %s\n",
			      filename, strerror(errno));
		return;
	}

	while ((line = read_line(cfile)) != NULL) {
		const char *cmd = line + strspn(line, "\t ");

		/* Comment or blank lines carry no command. */
		if (*cmd != '#' && *cmd != '\0') {
			if (keyword(cmd, "alias"))
				*aliases = parse_alias(*aliases,
						       cmd + strlen("alias"),
						       linenum, modules);
			else
				fatal("Unknown line %s in config file\n",
				      line);
		}
		free(line);
		linenum++;
	}
	/* Previously leaked: release the stream once it is exhausted. */
	fclose(cfile);
}

/* Map NAME to the name of the module that should be loaded.  The alias
   patterns in ALIASES are tried first; failing that, NAME is taken as a
   literal module name from MODULES.  Exits via fatal() when neither
   matches. */
static char *resolve_alias(const char *name,
			   struct module *modules,
			   struct alias *aliases)
{
	struct alias *hit = find_alias(name, aliases);
	struct module *mod;

	if (hit != NULL)
		return hit->module->name;

	mod = find_module(name, modules);
	if (mod == NULL)
		fatal("Could not find a module or alias named %s\n",
		      name);
	return mod->name;
}

/* Handle a "symbol:NAME" request.  Mark the module in MODULES exporting
   SYMNAME (exiting via fatal() if there is no single such module), then
   insert every module whose ->order is non-zero with OPTIONS; insertion
   failures are fatal.  Unlike load(), this does not call
   mops->get_deps(), so the chosen module's own dependencies are not
   resolved here. */
static void load_symbol(const char *dirname,
			const char *symname,
			struct module *modules,
			const char *options,
			int verbose)
{
	const char *requester = verbose ? "symbol request" : NULL;
	struct module *mod;

	if (need_symbol(1, symname, modules, requester) == 0)
		fatal("Could not find module with symbol %s\n", symname);

	for (mod = modules; mod != NULL; mod = mod->next) {
		if (mod->order == 0)
			continue;
		if (verbose)
			printf("Loading %s\n", mod->name);
		insmod(dirname, mod->name, options, 1);
	}
}

/* Insert MODNAME (a module name or alias) along with the modules it
   depends on.
   Dependencies are gathered level by level with mops->get_deps().  It
   is assumed to set ->order = ORDER on each module the named module
   needs and to return nonzero if it marked any (TODO confirm in
   mod32.c/mod64.c).  Marked modules are then inserted from the highest
   order down, each with empty options and failures only warned about.
   Finally MODNAME itself is inserted with OPTIONS; failure there is
   fatal.  A circular dependency is detected and reported via fatal()
   instead of looping forever. */
static void load(const char *dirname, const char *modname,
		 struct module *modules,
		 struct alias *aliases,
		 const char *options,
		 int verbose)
{
	unsigned int order, num_modules = 0;
	struct module *i;
	char *realname;

	realname = resolve_alias(modname, modules, aliases);

	for (i = modules; i; i = i->next)
		num_modules++;

	order = 1;
	if (mops->get_deps(order, dirname, realname, modules, verbose)) {
		/* We need some other modules. */
		int more_needed;

		do {
			more_needed = 0;
			for (i = modules; i; i = i->next) {
				if (i->order == order
				    && mops->get_deps(order + 1, dirname,
						      i->name,
						      modules, verbose))
					more_needed = 1;
			}
			order++;
			/* Each level is one step further along a dependency
			   chain.  Without a cycle, no chain can be longer
			   than the number of modules, so going deeper means
			   this loop would never terminate. */
			if (more_needed && order > num_modules)
				fatal("Circular dependency loading %s\n",
				      realname);
		} while (more_needed);
	}

	/* Now, walk back through orders, loading deepest first */
	for (; order > 0; order--) {
		for (i = modules; i; i = i->next) {
			if (i->order == order) {
				if (verbose) printf("Loading %s\n", i->name);
				insmod(dirname, i->name, "", 0);
			}
		}
	}
	if (verbose)
		printf("Loading %s%s%s\n", dirname, realname,MODULE_EXTENSION);
	insmod(dirname, realname, options, 1);
}

/* Long options for getopt_long() in main().  They map onto the short
   options -v (verbose), -V (print version) and -C, which takes one
   argument: the config file to read. */
static struct option options[] = { { "verbose", 0, NULL, 'v' },
				   { "version", 0, NULL, 'V' },
				   { "config", 1, NULL, 'C' },
				   { NULL, 0, NULL, 0 } };

/* Config file read when --config is not given.  Unlike a file named
   with --config, it is not an error for this one to be missing. */
#define DEFAULT_CONFIG "/etc/modprobe.conf"

/* modprobe [--verbose|--version|--config FILE] MODULE [module options...]
   Argument handling:
   - Arguments after MODULE are joined into one option string, separated
     by spaces; an argument containing a space is wrapped in double
     quotes.
   - Every module under MODULE_DIR for the running kernel is scanned.
   - The config file is read.
   Then MODULE and its dependencies are inserted; for "symbol:NAME", the
   module exporting NAME is inserted instead. */
int main(int argc, char *argv[])
{
	struct utsname buf;
	int opt;
	int verbose = 0;
	struct alias *aliases = NULL;
	struct module *modules;
	const char *config = NULL;
	char *dirname, *optstring = strdup("");

	if (!optstring)
		fatal("Out of memory\n");

	try_old_version("modprobe", argv);

	while ((opt = getopt_long(argc, argv, "vVC:", options, NULL)) != -1) {
		switch (opt) {
		case 'v':
			verbose = 1;
			break;
		case 'V':
			printf("0.6\n");
			exit(0);
		case 'C':
			config = optarg;
			break;
		default:
			/* getopt_long() has already reported the bad option.
			   argv[optind] is not necessarily that option, and is
			   NULL when it was the last argument, so don't print
			   it here. */
			print_usage(argv[0]);
		}
	}

	if (optind >= argc)
		print_usage(argv[0]);

	/* Rest is options */
	for (opt = optind + 1; opt < argc; opt++) {
		/* Spaces handled by "" pairs, but no way of escaping
		   quotes */
		int quote = (strchr(argv[opt], ' ') != NULL);
		/* +4: two quotes, the separating space and the NUL. */
		char *bigger = realloc(optstring, strlen(optstring)
				       + strlen(argv[opt]) + 4);

		if (!bigger)
			fatal("Out of memory\n");
		optstring = bigger;
		if (quote)
			strcat(optstring, "\"");
		strcat(optstring, argv[opt]);
		if (quote)
			strcat(optstring, "\"");
		strcat(optstring, " ");
	}

	if (uname(&buf) != 0)
		fatal("Cannot determine kernel release: %s\n",
		      strerror(errno));
	/* sizeof includes the "%s" placeholder, so there is room for NUL. */
	dirname = malloc(strlen(buf.release) + sizeof(MODULE_DIR));
	if (!dirname)
		fatal("Out of memory\n");
	sprintf(dirname, MODULE_DIR, buf.release);

	/* Suck up the modules */
	modules = load_all_modules(dirname, &aliases);

	/* No config file specified?  Don't worry if it doesn't exist. */
	load_config_file(config ?: DEFAULT_CONFIG, config ? 1 : 0,
			 modules, &aliases);

	/* Special case for "symbol:" */
	if (strncmp(argv[optind], "symbol:", strlen("symbol:")) == 0)
		load_symbol(dirname, argv[optind] + strlen("symbol:"),
			    modules, optstring, verbose);
	else
		load(dirname, argv[optind], modules, aliases, optstring,
		     verbose);
	exit(0);
}
