//===------------------------- AddressSpace.hpp ---------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is dual licensed under the MIT and the University of Illinois Open
// Source Licenses. See LICENSE.TXT for details.
//
//
// Abstracts accessing local vs remote address spaces.
//
//===----------------------------------------------------------------------===//

#ifndef __ADDRESSSPACE_HPP__
#define __ADDRESSSPACE_HPP__

#include <sys/rbtree.h>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <dlfcn.h>
#include <elf.h>
#include <link.h>
#include <pthread.h>

#include "dwarf2.h"

namespace _Unwind {

static int rangeCmp(void *, const void *, const void *);
static int rangeCmpKey(void *, const void *, const void *);
static int dsoTableCmp(void *, const void *, const void *);
static int dsoTableCmpKey(void *, const void *, const void *);
static int phdr_callback(struct dl_phdr_info *, size_t, void *);

struct unw_proc_info_t {
  uintptr_t data_base;       // Base address for data-relative relocations
  uintptr_t start_ip;        // Start address of function
  uintptr_t end_ip;          // First address after end of function
  uintptr_t lsda;            // Address of Language Specific Data Area
  uintptr_t handler;         // Personality routine
  uintptr_t extra_args;      // Extra stack space for frameless routines
  uintptr_t unwind_info;     // Address of DWARF unwind info
};

/// LocalAddressSpace is used as a template parameter to UnwindCursor when
/// unwinding a thread in the same process.  The wrappers compile away,
/// making local unwinds fast.
class LocalAddressSpace {
public:
  typedef uintptr_t pint_t;
  typedef intptr_t sint_t;

  typedef void (*findPCRange_t)(LocalAddressSpace &, pint_t, pint_t &pcStart,
                                pint_t &pcEnd);

  LocalAddressSpace(findPCRange_t findPCRange_)
      : findPCRange(findPCRange_), needsReload(true) {
    static const rb_tree_ops_t segmentTreeOps = {
      rangeCmp, rangeCmpKey, offsetof(Range, range_link), NULL
    };
    static const rb_tree_ops_t dsoTreeOps = {
      dsoTableCmp, dsoTableCmpKey, offsetof(Range, dso_link), NULL
    };
    rb_tree_init(&segmentTree, &segmentTreeOps);
    rb_tree_init(&dsoTree, &dsoTreeOps);
    pthread_rwlock_init(&fdeTreeLock, NULL);
  }

  uint8_t get8(pint_t addr) {
    uint8_t val;
    memcpy(&val, (void *)addr, sizeof(val));
    return val;
  }

  uint16_t get16(pint_t addr) {
    uint16_t val;
    memcpy(&val, (void *)addr, sizeof(val));
    return val;
  }

  uint32_t get32(pint_t addr) {
    uint32_t val;
    memcpy(&val, (void *)addr, sizeof(val));
    return val;
  }

  uint64_t get64(pint_t addr) {
    uint64_t val;
    memcpy(&val, (void *)addr, sizeof(val));
    return val;
  }

  uintptr_t getP(pint_t addr) {
    if (sizeof(uintptr_t) == sizeof(uint32_t))
      return get32(addr);
    else
      return get64(addr);
  }

  uint64_t getULEB128(pint_t &addr, pint_t end) {
    uint64_t result = 0;
    uint8_t byte;
    int bit = 0;
    do {
      uint64_t b;

      assert(addr != end);

      byte = get8(addr++);
      b = byte & 0x7f;

      assert(bit < 64);
      assert(b << bit >> bit == b);

      result |= b << bit;
      bit += 7;
    } while (byte >= 0x80);
    return result;
  }

  int64_t getSLEB128(pint_t &addr, pint_t end) {
    uint64_t result = 0;
    uint8_t byte;
    int bit = 0;
    do {
      uint64_t b;

      assert(addr != end);

      byte = get8(addr++);
      b = byte & 0x7f;

      assert(bit < 64);
      assert(b << bit >> bit == b);

      result |= b << bit;
      bit += 7;
    } while (byte >= 0x80);
    // sign extend negative numbers
    if ((byte & 0x40) != 0)
      result |= (~0ULL) << bit;
    return result;
  }

  pint_t getEncodedP(pint_t &addr, pint_t end, uint8_t encoding,
                     const unw_proc_info_t *ctx) {
    pint_t startAddr = addr;
    const uint8_t *p = (uint8_t *)addr;
    pint_t result;

    if (encoding == DW_EH_PE_omit)
      return 0;
    if (encoding == DW_EH_PE_aligned) {
      addr = (addr + sizeof(pint_t) - 1) & sizeof(pint_t);
      return getP(addr);
    }

    // first get value
    switch (encoding & 0x0F) {
    case DW_EH_PE_ptr:
      result = getP(addr);
      p += sizeof(pint_t);
      addr = (pint_t)p;
      break;
    case DW_EH_PE_uleb128:
      result = getULEB128(addr, end);
      break;
    case DW_EH_PE_udata2:
      result = get16(addr);
      p += 2;
      addr = (pint_t)p;
      break;
    case DW_EH_PE_udata4:
      result = get32(addr);
      p += 4;
      addr = (pint_t)p;
      break;
    case DW_EH_PE_udata8:
      result = get64(addr);
      p += 8;
      addr = (pint_t)p;
      break;
    case DW_EH_PE_sleb128:
      result = getSLEB128(addr, end);
      break;
    case DW_EH_PE_sdata2:
      result = (int16_t)get16(addr);
      p += 2;
      addr = (pint_t)p;
      break;
    case DW_EH_PE_sdata4:
      result = (int32_t)get32(addr);
      p += 4;
      addr = (pint_t)p;
      break;
    case DW_EH_PE_sdata8:
      result = get64(addr);
      p += 8;
      addr = (pint_t)p;
      break;
    case DW_EH_PE_omit:
      result = 0;
      break;
    default:
      assert(0 && "unknown pointer encoding");
    }

    // then add relative offset
    switch (encoding & 0x70) {
    case DW_EH_PE_absptr:
      // do nothing
      break;
    case DW_EH_PE_pcrel:
      result += startAddr;
      break;
    case DW_EH_PE_textrel:
      assert(0 && "DW_EH_PE_textrel pointer encoding not supported");
      break;
    case DW_EH_PE_datarel:
      assert(ctx != NULL && "DW_EH_PE_datarel without context");
      if (ctx)
        result += ctx->data_base;
      break;
    case DW_EH_PE_funcrel:
      assert(ctx != NULL && "DW_EH_PE_funcrel without context");
      if (ctx)
        result += ctx->start_ip;
      break;
    case DW_EH_PE_aligned:
      __builtin_unreachable();
    default:
      assert(0 && "unknown pointer encoding");
      break;
    }

    if (encoding & DW_EH_PE_indirect)
      result = getP(result);

    return result;
  }

  bool findFDE(pint_t pc, pint_t &fdeStart, pint_t &data_base) {
    Range *n;
    for (;;) {
      pthread_rwlock_rdlock(&fdeTreeLock);
      n = (Range *)rb_tree_find_node(&segmentTree, &pc);
      pthread_rwlock_unlock(&fdeTreeLock);
      if (n != NULL)
        break;
      if (!needsReload)
        break;
      lazyReload();
    }
    if (n == NULL)
      return false;
    if (n->hdr_start == 0) {
      fdeStart = n->hdr_base;
      data_base = n->data_base;
      return true;
    }

    pint_t base = n->hdr_base;
    pint_t first = n->hdr_start;
    for (pint_t len = n->hdr_entries; len > 1; ) {
      pint_t next = first + (len / 2) * 8;
      pint_t nextPC = base + (int32_t)get32(next);
      if (nextPC == pc) {
        first = next;
        break;
      }
      if (nextPC < pc) {
        first = next;
        len -= (len / 2);
      } else {
        len /= 2;
      }
    }
    fdeStart = base + (int32_t)get32(first + 4);
    data_base = n->data_base;
    return true;
  }

  bool addFDE(pint_t pcStart, pint_t pcEnd, pint_t fde) {
    Range *n = (Range *)malloc(sizeof(*n));
    n->hdr_base = fde;
    n->hdr_start = 0;
    n->hdr_entries = 0;
    n->first_pc = pcStart;
    n->last_pc = pcEnd;
    n->data_base = 0;
    n->ehframe_base = 0;
    pthread_rwlock_wrlock(&fdeTreeLock);
    if (static_cast<Range *>(rb_tree_insert_node(&segmentTree, n)) == n) {
      pthread_rwlock_unlock(&fdeTreeLock);
      return true;
    }
    free(n);
    pthread_rwlock_unlock(&fdeTreeLock);
    return false;
  }

  bool removeFDE(pint_t pcStart, pint_t pcEnd, pint_t fde) {
    pthread_rwlock_wrlock(&fdeTreeLock);
    Range *n = static_cast<Range *>(rb_tree_find_node(&segmentTree, &pcStart));
    if (n == NULL) {
      pthread_rwlock_unlock(&fdeTreeLock);
      return false;
    }
    assert(n->first_pc == pcStart);
    assert(n->last_pc == pcEnd);
    assert(n->hdr_base == fde);
    assert(n->hdr_start == 0);
    assert(n->hdr_entries == 0);
    assert(n->data_base == 0);
    assert(n->ehframe_base == 0);
    rb_tree_remove_node(&segmentTree, n);
    free(n);
    pthread_rwlock_unlock(&fdeTreeLock);
    return true;
  }

  void removeDSO(pint_t ehFrameBase) {
    pthread_rwlock_wrlock(&fdeTreeLock);
    Range *n;
    n = (Range *)rb_tree_find_node(&dsoTree, &ehFrameBase);
    if (n == NULL) {
      pthread_rwlock_unlock(&fdeTreeLock);
      return;
    }
    rb_tree_remove_node(&dsoTree, n);
    rb_tree_remove_node(&segmentTree, n);
    free(n);
    pthread_rwlock_unlock(&fdeTreeLock);
  }

  void setLazyReload() {
    pthread_rwlock_wrlock(&fdeTreeLock);
    needsReload = true;
    pthread_rwlock_unlock(&fdeTreeLock);
  }

private:
  findPCRange_t findPCRange;
  bool needsReload;
  pthread_rwlock_t fdeTreeLock;
  rb_tree_t segmentTree;
  rb_tree_t dsoTree;

  friend int phdr_callback(struct dl_phdr_info *, size_t, void *);
  friend int rangeCmp(void *, const void *, const void *);
  friend int rangeCmpKey(void *, const void *, const void *);
  friend int dsoTableCmp(void *, const void *, const void *);
  friend int dsoTableCmpKey(void *, const void *, const void *);

  void updateRange();

  struct Range {
    rb_node_t range_link;
    rb_node_t dso_link;
    pint_t hdr_base; // Pointer to FDE if hdr_start == 0
    pint_t hdr_start;
    pint_t hdr_entries;
    pint_t first_pc;
    pint_t last_pc;
    pint_t data_base;
    pint_t ehframe_base;
  };

  void lazyReload() {
    pthread_rwlock_wrlock(&fdeTreeLock);
    dl_iterate_phdr(phdr_callback, this);
    needsReload = false;
    pthread_rwlock_unlock(&fdeTreeLock);
  }

  void addDSO(pint_t header, pint_t data_base) {
    if (header == 0)
      return;
    if (get8(header) != 1)
      return;
    if (get8(header + 3) != (DW_EH_PE_datarel | DW_EH_PE_sdata4))
      return;
    pint_t end = header + 4;
    pint_t ehframe_base = getEncodedP(end, 0, get8(header + 1), NULL);
    pint_t entries = getEncodedP(end, 0, get8(header + 2), NULL);
    pint_t start = (end + 3) & ~pint_t(3);
    if (entries == 0)
      return;
    Range *n = (Range *)malloc(sizeof(*n));
    n->hdr_base = header;
    n->hdr_start = start;
    n->hdr_entries = entries;
    n->first_pc = header + (int32_t)get32(n->hdr_start);
    pint_t tmp;
    (*findPCRange)(
        *this, header + (int32_t)get32(n->hdr_start + (entries - 1) * 8 + 4),
        tmp, n->last_pc);
    n->data_base = data_base;
    n->ehframe_base = ehframe_base;

    if (static_cast<Range *>(rb_tree_insert_node(&segmentTree, n)) != n) {
      free(n);
      return;
    }
    rb_tree_insert_node(&dsoTree, n);
  }
};

static int phdr_callback(struct dl_phdr_info *info, size_t size, void *data_) {
  LocalAddressSpace *data = (LocalAddressSpace *)data_;
  size_t eh_frame = 0, data_base = 0;
  const Elf_Phdr *hdr = info->dlpi_phdr;
  const Elf_Phdr *last_hdr = hdr + info->dlpi_phnum;
  const Elf_Dyn *dyn;

  for (; hdr != last_hdr; ++hdr) {
    switch (hdr->p_type) {
    case PT_GNU_EH_FRAME:
      eh_frame = info->dlpi_addr + hdr->p_vaddr;
      break;
    case PT_DYNAMIC:
      dyn = (const Elf_Dyn *)(info->dlpi_addr + hdr->p_vaddr);
      while (dyn->d_tag != DT_NULL) {
        if (dyn->d_tag == DT_PLTGOT) {
          data_base = info->dlpi_addr + dyn->d_un.d_ptr;
          break;
        }
        ++dyn;
      }
    }
  }

  if (eh_frame)
    data->addDSO(eh_frame, data_base);

  return 0;
}

static int rangeCmp(void *context, const void *n1_, const void *n2_) {
  const LocalAddressSpace::Range *n1 = (const LocalAddressSpace::Range *)n1_;
  const LocalAddressSpace::Range *n2 = (const LocalAddressSpace::Range *)n2_;

  if (n1->first_pc < n2->first_pc)
    return -1;
  if (n1->first_pc > n2->first_pc)
    return 1;
  assert(n1->last_pc == n2->last_pc);
  return 0;
}

static int rangeCmpKey(void *context, const void *n_, const void *pc_) {
  const LocalAddressSpace::Range *n = (const LocalAddressSpace::Range *)n_;
  const LocalAddressSpace::pint_t *pc = (const LocalAddressSpace::pint_t *)pc_;
  if (n->last_pc < *pc)
    return -1;
  if (n->first_pc > *pc)
    return 1;
  return 0;
}

static int dsoTableCmp(void *context, const void *n1_, const void *n2_) {
  const LocalAddressSpace::Range *n1 = (const LocalAddressSpace::Range *)n1_;
  const LocalAddressSpace::Range *n2 = (const LocalAddressSpace::Range *)n2_;

  if (n1->ehframe_base < n2->ehframe_base)
    return -1;
  if (n1->ehframe_base > n2->ehframe_base)
    return 1;
  return 0;
}

static int dsoTableCmpKey(void *context, const void *n_, const void *ptr_) {
  const LocalAddressSpace::Range *n = (const LocalAddressSpace::Range *)n_;
  const LocalAddressSpace::pint_t *ptr = (const LocalAddressSpace::pint_t *)ptr_;
  if (n->ehframe_base < *ptr)
    return -1;
  if (n->ehframe_base > *ptr)
    return 1;
  return 0;
}

} // namespace _Unwind

#endif // __ADDRESSSPACE_HPP__