#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <coff.h>
#include <limits.h>

#include "dll.h"

/* Set to a non-zero value to compile in the printf() tracing of headers,
   symbols and relocations found inside the #if PRINT sections below. */
#define PRINT 0

/* Pointer to a function called with no arguments whose result is ignored.
   Used in this file for .ctor/.dtor table entries and for the
   _dll_loadfunc / _dll_unloadfunc hooks of a loaded module. */
typedef void (*CDTOR)();

/* A growable name -> address table.  Every Symtab links itself into the
   global list headed by Symtab::symtabs when constructed and removes itself
   when destroyed, so lookup() can search all live tables. */
struct Symtab {
  Symtab *next, *prev;      /* neighbours in the global list of tables */
  static Symtab *symtabs;   /* head of the global list (most recently created first) */
  char **name;              /* name[i] is a strdup()'d copy owned by the table */
  void **value;             /* value[i] is the address stored for name[i] */
  int num;                  /* number of entries currently in use */
  int max;                  /* number of entries allocated in name[] / value[] */
  Symtab();                 /* creates an empty table and links it into the list */
  ~Symtab();                /* frees the entries and unlinks from the list */
  void add(char *name, void *value);   /* appends an entry; the name is copied */
  void *get(char *name);               /* searches this table only; 0 if absent */
  static void *lookup(char *name);     /* searches every table on the list; 0 if absent */
};

/* Loader-side record for one loaded module.  Pointers to this structure are
   handed to callers cast to the opaque `struct DLL *` type. */
struct dll_s {
  dll_s *next, *prev;   /* neighbours in the list of loaded modules */
  char *loadpath;       /* strdup()'d copy of the filename passed to dll_force_load() */
  int valid;            /* set to 1 once loading has finished; 0 while loading/unloading */
  static dll_s *top;    /* head of the list of loaded modules */
  int load_count;       /* reference count: dll_load() increments, dll_unload() decrements */
  int num_sections;     /* f_nscns from the COFF file header */
  char *bytes;          /* one buffer holding all sections at their s_vaddr offsets */
  char *dtor_section;   /* address of the ".dtor" section inside bytes, if present */
  int dtor_count;       /* number of function pointers in the ".dtor" section */
  Symtab symtab;        /* symbols defined by this module with storage class 2 */
  CDTOR uninit_func;    /* address of _dll_unloadfunc if the module defines it, else 0 */
};

/* Head of the list of loaded modules; empty until the first load. */
dll_s *dll_s::top = 0;

/* Head of the list of all live symbol tables; Symtab's constructor pushes
   each new table here. */
Symtab *Symtab::symtabs = 0;

/*
 * Create an empty symbol table with room for 10 entries and push it onto
 * the front of the global Symtab::symtabs list.
 */
Symtab::Symtab()
{
  next = symtabs;
  prev = 0;
  /* Back-link the previous head to us.  Without this every `prev` stays 0,
     and destroying a table that is not the head would reset the list head to
     that table's successor, silently dropping all newer tables from
     lookup(). */
  if (next)
    next->prev = this;
  symtabs = this;
  num = 0;
  max = 10;
  name = (char **)malloc(max*sizeof(char *));
  value = (void **)malloc(max*sizeof(void *));
  if (name == 0 || value == 0)
  {
    /* Start with zero capacity instead of leaving a null array behind a
       non-zero `max`; add() grows the arrays with realloc(), which accepts
       null pointers. */
    free(name);
    free(value);
    name = 0;
    value = 0;
    max = 0;
  }
}

/*
 * Release every stored name plus both arrays, and take this table out of
 * the global list of symbol tables.
 */
Symtab::~Symtab()
{
  /* The names were strdup()'d by add(), so each one is freed individually. */
  int idx = 0;
  while (idx < num)
  {
    free(name[idx]);
    idx++;
  }
  free(name);
  free(value);

  /* Unlink: fix the successor's back pointer, then either the
     predecessor's forward pointer or the list head. */
  if (next != 0)
  {
    next->prev = prev;
  }
  if (prev == 0)
  {
    symtabs = next;
  }
  else
  {
    prev->next = next;
  }
}

/*
 * Append a (name, value) pair to the table.
 *
 * Pname  - symbol name; a private copy is stored, the caller keeps ownership
 *          of the original.
 * Pvalue - address to associate with the name.
 *
 * The arrays grow by 10 entries when full.  If memory cannot be obtained the
 * symbol is not added, a message is written to stderr, and the table is left
 * in a consistent state.
 */
void Symtab::add(char *Pname, void *Pvalue)
{
  if (num >= max)
  {
    int new_max = max + 10;

    /* Grow through temporaries so the existing arrays are not lost (and
       `max` is not bumped) when realloc() fails. */
    char **new_name = (char **)realloc(name, new_max * sizeof(char *));
    if (new_name == 0)
    {
      fprintf(stderr, "Error: out of memory adding symbol %s\n", Pname);
      return;
    }
    name = new_name;

    void **new_value = (void **)realloc(value, new_max * sizeof(void *));
    if (new_value == 0)
    {
      /* name[] is now larger than `max` claims, which is harmless. */
      fprintf(stderr, "Error: out of memory adding symbol %s\n", Pname);
      return;
    }
    value = new_value;
    max = new_max;
  }

  /* A null name in the table would crash strcmp() in get(), so refuse to
     store the entry if the copy cannot be made. */
  char *copy = strdup(Pname);
  if (copy == 0)
  {
    fprintf(stderr, "Error: out of memory adding symbol %s\n", Pname);
    return;
  }
  name[num] = copy;
  value[num] = Pvalue;
  num++;
}

/*
 * Search this table only for Pname.
 * Returns the value stored with the first matching entry, or 0 when no
 * entry has that name.
 */
void *Symtab::get(char *Pname)
{
  int idx = 0;
  while (idx < num)
  {
    if (!strcmp(Pname, name[idx]))
    {
      return value[idx];
    }
    idx++;
  }
  return 0;
}

void *Symtab::lookup(char *Pname)
{
  Symtab *s;
  for (s=symtabs; s; s=s->next)
  {
    void *v = s->get(Pname);
    if (v)
    {
      return v;
    }
  }
  return 0;
}

/* Printable names for the COFF file-header flag bits.  In this file the
   table is referenced only inside #if PRINT tracing.  The all-zero row
   terminates the list. */
static struct {
  int val;
  char *name;
} flags[] = {
  { F_RELFLG, "REL" },
  { F_EXEC,   "EXEC" },
  { F_LNNO,   "LNNO" },
  { F_LSYMS,  "LSYMS" },
  { 0, 0 }
};

/* Printable names for the COFF section-header flag bits.  In this file the
   table is referenced only inside #if PRINT tracing.  The all-zero row
   terminates the list. */
static struct {
  int val;
  char *name;
} sflags[] = {
  { STYP_TEXT, "text" },
  { STYP_DATA, "data" },
  { STYP_BSS,  "bss" },
  { 0, 0 }
};

/* Copy of the argv0 string given to dll_init().  A null value also serves
   as the "not yet initialised" marker tested by dll_load(). */
static char *dll_argv0 = 0;
/* Symbols registered by the host program through dll_register(). */
static Symtab *local_symtab = 0;
/* Common symbols; dll_force_load() allocates storage for each one the first
   time it is seen and records it here. */
static Symtab *common_symtab = 0;

static void dll_exitfunc(void)
{
  while (dll_s::top)
    dll_unload((struct DLL *)dll_s::top);
}

/*
 * Make a host-program symbol available to modules loaded later.
 *
 * Psymbol  - symbol name as it appears in the module's symbol table.
 * Paddress - address the symbol resolves to.
 *
 * The first call also creates the two file-level tables (local_symtab for
 * registered symbols, common_symtab for common storage).
 */
void dll_register(char *Psymbol, void *Paddress)
{
  if (local_symtab == 0)
    local_symtab = new Symtab;
  if (common_symtab == 0)
    common_symtab = new Symtab;
  /* The former get() call after add() discarded its result; it only added a
     linear scan per registration, so it has been dropped. */
  local_symtab->add(Psymbol, Paddress);
}

/* gccdll.h is included twice with different MKDLL definitions.  Here each
   MKDLL(a,b) line becomes an extern "C" declaration of b; inside dll_init()
   the same lines become dll_register(a, b) calls.
   NOTE(review): gccdll.h is not visible here -- presumably it is just a list
   of MKDLL("name", symbol) lines; confirm. */
#define MKDLL(a,b) extern void b();
extern "C" {
#include "gccdll.h"
};
#undef MKDLL

/*
 * Initialise the loader, or update the recorded program path.
 *
 * argv0 - path of the running program; find_file() looks for modules in its
 *         directory.  A null pointer is treated as "".
 *
 * The first call installs the exit handler that unloads every module and
 * registers the built-in symbols (_dll_load, _dll_unload, _dll_lookup and
 * everything listed in gccdll.h).  Later calls only replace the stored path.
 */
void dll_init(char *argv0)
{
  if (dll_argv0)
    free(dll_argv0);
  else
  {
    /* One-time setup.  atexit() lives here so that repeated dll_init()
       calls do not register the handler again and use up atexit slots. */
    atexit(dll_exitfunc);
    dll_register("_dll_load", dll_load);
    dll_register("_dll_unload", dll_unload);
    dll_register("_dll_lookup", dll_lookup);
#define MKDLL(a,b) dll_register(a, b);
#include "gccdll.h"
#undef MKDLL
  }
  /* strdup(NULL) is undefined, so fall back to an empty path. */
  dll_argv0 = strdup(argv0 ? argv0 : "");
}

/*
 * Work out which path to open for module `fn`.
 *
 * Search order:
 *   1. fn itself if it contains ':', '\\' or '/' (no searching is done);
 *   2. fn itself if it is accessible as given;
 *   3. fn in the directory part of the path recorded by dll_init();
 *   4. fn in each ';'-separated directory of the PATH environment variable.
 * If nothing is found fn is returned unchanged.
 *
 * Returns either fn or a pointer to a static buffer that is overwritten by
 * the next call.  Candidates that would not fit in the buffer are skipped.
 */
char *find_file(char *fn)
{
  static char buf[PATH_MAX];
  size_t fnlen, dirlen, k;
  const char *path, *end;

  if (strpbrk(fn, ":\\/"))
    return fn;

//  printf("find: try `%s'\n", fn);
  if (access(fn,0) == 0)
    return fn;

  fnlen = strlen(fn);

  if (dll_argv0)
  {
    /* The directory part is everything up to and including the last
       ':', '\\' or '/'; it is empty when there is no separator. */
    dirlen = 0;
    for (k = 0; dll_argv0[k]; k++)
      if (strchr(":\\/", dll_argv0[k]))
        dirlen = k + 1;
    if (dirlen + fnlen < sizeof(buf))
    {
      memcpy(buf, dll_argv0, dirlen);
      strcpy(buf + dirlen, fn);
//    printf("find: try `%s'\n", buf);
      if (access(buf, 0) == 0)
        return buf;
    }
  }

  /* getenv() returns a null pointer when PATH is not set. */
  path = getenv("PATH");
  if (path == 0)
    return fn;

  while (*path)
  {
    end = path;
    while (*end && *end != ';')
      end++;
    dirlen = (size_t)(end - path);
    /* directory + '/' + fn + terminating NUL must fit in buf */
    if (dirlen + 1 + fnlen < sizeof(buf))
    {
      memcpy(buf, path, dirlen);
      buf[dirlen] = '/';
      strcpy(buf + dirlen + 1, fn);
//    printf("find: try `%s'\n", buf);
      if (access(buf, 0) == 0)
        return buf;
    }
    if (*end == 0)
      break;
    path = end + 1;
  }
//  printf("find: default `%s'\n", fn);
  return fn;
}

struct DLL *dll_load(char *filename)
{
  if (dll_argv0 == 0)
    dll_init("");

  dll_s *dll;
  for (dll=dll_s::top; dll; dll=dll->next)
    if (strcmp(dll->loadpath, filename) == 0)
    {
      dll->load_count ++;
      return (DLL *)dll;
    }

  return dll_force_load(filename);
}

struct DLL *dll_force_load(char *filename)
{
  int i, s;
  int error = 0;
  dll_s *dll;

  char *loadpath = find_file(filename);
 
//  printf("load: `%s'\n", loadpath);
  FILE *file = fopen(loadpath, "rb");
  if (file == 0)
  {
    fprintf(stderr, "Error: unable to load %s\n", filename);
    perror("The error was");
    return 0;
  }

  dll = new dll_s;
  dll->valid = 0;
  dll->loadpath = strdup(filename);
  dll->num_sections = 0;
  dll->next = dll_s::top;
  if (dll->next)
    dll->next->prev = dll;
  dll_s::top = dll;
  dll->prev = 0;
  dll->load_count = 1;
  dll->dtor_count = 0;
  dll->uninit_func = 0;
  CDTOR init_func = 0;
  
  FILHDR filhdr;
  fread(&filhdr, 1, sizeof(filhdr), file);
#if PRINT
  printf("file: %s, magic=%#x\n", filename, filhdr.f_magic);
#endif
  if (filhdr.f_magic != 0x14c)
  {
    fprintf(stderr, "Not a COFF file\n");
    return 0;
  }
#if PRINT
  printf("nscns=%d, nsyms=%d, symptr=%#x\n", filhdr.f_nscns, filhdr.f_nsyms,
    filhdr.f_symptr);
  printf("flags: ");
  for (i=0; flags[i].val; i++)
    if (filhdr.f_flags & flags[i].val)
      printf(" %s", flags[i].name);
  printf("\n");
#endif

  if (filhdr.f_opthdr)
    fseek(file, filhdr.f_opthdr, 1);

  SCNHDR *section;
  section = new SCNHDR[filhdr.f_nscns];
  dll->num_sections = filhdr.f_nscns;
  fread(section, sizeof(SCNHDR), filhdr.f_nscns, file);

  int max_bytes = 0;
  for (s=0; s<filhdr.f_nscns; s++)
    if (max_bytes < section[s].s_vaddr + section[s].s_size)
      max_bytes = section[s].s_vaddr + section[s].s_size;

  dll->bytes = new char[max_bytes];

  for (s=0; s<filhdr.f_nscns; s++)
  {
    if (section[s].s_scnptr)
    {
#if PRINT
      printf("section %d from file at 0x%x\n", s, dll->bytes + section[s].s_vaddr);
#endif
      if (section[s].s_size)
      {
        fseek(file, section[s].s_scnptr, 0);
        fread(dll->bytes+section[s].s_vaddr, 1, section[s].s_size, file);
      }
    }
    else
    {
#if PRINT
      printf("section %d zeroed %d bytes at 0x%x\n", s, section[s].s_size, dll->bytes+section[s].s_vaddr);
#endif
      if (section[s].s_size)
        memset(dll->bytes+section[s].s_vaddr, 0, section[s].s_size);
    }
  }
  
  SYMENT *syment = new SYMENT[filhdr.f_nsyms];
  unsigned long *symaddr = new unsigned long [filhdr.f_nsyms];
  fseek(file, filhdr.f_symptr, 0);
  fread(syment, filhdr.f_nsyms, SYMESZ, file);
  unsigned long strsize = 4;
  fread(&strsize, 1, sizeof(unsigned long), file);
  char *strings = new char[strsize];
  strings[0] = 0;
  if (strsize > 4)
    fread(strings+4, strsize-4, 1, file);

  for (i=0; i<filhdr.f_nsyms; i++)
  {
    char snameb[9], *sname, *scname;
#if PRINT
    printf("[0x%08x] ", syment[i].e_value);
#endif
    if (syment[i].e.e.e_zeroes)
    {
      sprintf(snameb, "%.8s", syment[i].e.e_name);
      sname = snameb;
    }
    else
      sname = strings + syment[i].e.e.e_offset;

    if (syment[i].e_scnum > 0)
    {
      symaddr[i] = syment[i].e_value + (long)(dll->bytes);
      if (syment[i].e_sclass == 2)
        dll->symtab.add(sname, (void *)symaddr[i]);
      if (strcmp(sname, "_dll_unloadfunc") == 0)
        dll->uninit_func = (CDTOR)symaddr[i];
      if (strcmp(sname, "_dll_loadfunc") == 0)
        init_func = (CDTOR)symaddr[i];
    }
    else if (syment[i].e_scnum == N_UNDEF)
    {
      if (syment[i].e_value)
      {
        void *stv = common_symtab->get(sname);
        if (stv)
          symaddr[i] = (long)stv;
        else
        {
          stv = calloc(syment[i].e_value,1);
          common_symtab->add(sname, stv);
          symaddr[i] = (long)stv;
        }
      }
      else
      {
        symaddr[i] = (long)Symtab::lookup(sname);
        if (symaddr[i] == 0)
        {
          fprintf(stderr, "Undefined symbol %s referenced from %s\n",
            sname, filename);
          error = 1;
        }
      }
    }

#if PRINT
    if (syment[i].e_scnum >= 1)
      scname = section[syment[i].e_scnum-1].s_name;
    else
      scname = "N/A";
    printf("[%2d] 0x%08x %2d %-8.8s %04x %02x %d %s\n", i,
      symaddr[i],
      syment[i].e_scnum,
      scname,
      syment[i].e_type,
      syment[i].e_sclass,
      syment[i].e_numaux,
      sname);
    for (int a=0; a<syment[i].e_numaux; a++)
    {
      i++;
#if 0
      unsigned char *ap = (unsigned char *)(syment+i);
      printf("\033[0m");
      for (int b=0; b<SYMESZ; b++)
        printf(" %02x\033[32m%c\033[37m", ap[b], ap[b]>' '?ap[b]:' ');
      printf("\033[1m\n");
#endif
    }
#else
    i += syment[i].e_numaux;
#endif
  }

  for (s=0; s<filhdr.f_nscns; s++)
  {
#if PRINT
    printf("\nS[%d] `%-8s' pa=%#x va=%#x s=%#x ptr=%#x\n",
      s, section[s].s_name, section[s].s_paddr,section[s].s_vaddr,
      section[s].s_size, section[s].s_scnptr);
    printf("  rel=%#x nrel=%#x  flags: ",
      section[s].s_relptr, section[s].s_nreloc);
    for (i=0; sflags[i].val; i++)
      if (section[s].s_flags & sflags[i].val)
        printf(" %s", sflags[i].name);
    printf("\n");
#endif

    if (section[s].s_nreloc)
    {
      fseek(file, section[s].s_relptr, 0);
      RELOC *r = new RELOC[section[s].s_nreloc];
      fread(r, RELSZ, section[s].s_nreloc, file);
      for (i=0; i<section[s].s_nreloc; i++)
      {
        long *ptr = (long *)(dll->bytes + r[i].r_vaddr);
        long old_value = *ptr;
#if PRINT
        printf("  [%02d]  0x%08x(0x%08x)  %2d  0x%04x 0%02o (was 0x%08x",
          i, r[i].r_vaddr, ptr, r[i].r_symndx, r[i].r_type, r[i].r_type, old_value);
#endif
        switch (r[i].r_type)
        {
          case 0x06:
            old_value -= syment[r[i].r_symndx].e_value;
            old_value += symaddr[r[i].r_symndx];
            break;
          case 0x14:
            if (syment[r[i].r_symndx].e_scnum == 0)
            {
              old_value -= (long)(dll->bytes);
              old_value += symaddr[r[i].r_symndx];
            }
            break;
          default:
            fprintf(stderr, "Error: unexpected relocation type %#x\n",
              r[i].r_type);
            error = 1;
        }
        *ptr = old_value;
#if PRINT
        printf(", now 0x%08x)\n", old_value);
#endif
      }
      delete r;
    }
  }
  for (s=0; s<filhdr.f_nscns; s++)
  {
    if (strcmp(section[s].s_name, ".ctor") == 0)
    {
      for (i=0; i<section[s].s_size/4; i++)
      {
        CDTOR f;
        void **fv = (void **)(dll->bytes+section[s].s_vaddr);
        f = (CDTOR)(fv[i]);
        f();
      }
    }
    if (strcmp(section[s].s_name, ".dtor") == 0)
    {
      dll->dtor_section = dll->bytes + section[s].s_vaddr;
      dll->dtor_count = section[s].s_size/4;
    }
  }
  if (init_func)
    init_func();

  dll->valid = 1;
  fclose(file);
  delete syment;
  delete symaddr;
  delete strings;
  delete section;

  if (error)
    return 0;
  return (struct DLL *)dll;
}

/*
 * Drop one reference to a module and free it when the count reaches zero.
 *
 * Pdll - handle returned by dll_load() / dll_force_load(); a null handle is
 *        ignored.
 *
 * On the final release of a module that finished loading, _dll_unloadfunc
 * (if the module defined it) is called first, followed by every entry of
 * its ".dtor" section.  The module is then unlinked from the module list
 * and all memory owned by the record is released.
 */
void dll_unload(struct DLL *Pdll)
{
  int i;
  CDTOR f;
  if (Pdll == 0)
    return;
  dll_s *dll = (dll_s *)Pdll;
  if (--dll->load_count)
    return;
//  printf("unload: `%s'\n", dll->loadpath);
  if (dll->valid)
  {
    if (dll->uninit_func)
      dll->uninit_func();
    for (i=0; i<dll->dtor_count; i++)
    {
      void **fv = (void **)(dll->dtor_section);
      f = (CDTOR)(fv[i]);
      f();
    }
  }
  if (dll->next)
    dll->next->prev = dll->prev;
  if (dll->prev)
    dll->prev->next = dll->next;
  else
    dll_s::top = dll->next;
  /* bytes comes from new char[], so it needs the array form of delete. */
  if (dll->bytes)
    delete[] dll->bytes;
  /* loadpath was strdup()'d when the module was loaded. */
  free(dll->loadpath);
  dll->valid = 0;
  delete dll;
}

/*
 * Find a symbol exported by one particular module.
 *
 * Pdll - module handle; a null handle (for example the result of a failed
 *        load) yields 0 instead of being dereferenced.
 * name - symbol name as stored in the module's symbol table.
 *
 * Returns the symbol's address, or 0 if the module does not export it.
 * Only the module's own table is searched.
 */
void *dll_lookup(struct DLL *Pdll, char *name)
{
  if (Pdll == 0)
    return 0;
  dll_s *dll = (dll_s *)Pdll;
  return dll->symtab.get(name);
}
