/*
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * * Redistributions of source code must retain the above copyright
 *   notice, this list of conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright
 *   notice, this list of conditions and the following disclaimer in the
 *   documentation and/or other materials provided with the distribution.
 * * Neither the name of NVIDIA CORPORATION nor the names of its
 *   contributors may be used to endorse or promote products derived
 *   from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

/*! \file jitify2.hpp
 *  \brief The Jitify v2 library header
 */

/*! \mainpage Jitify - A C++ library that simplifies the use of NVRTC
 *  \p Use class jitify2::ProgramCache to manage and launch JIT-compiled CUDA
 *    kernels.
 *
 *  \p Use namespace jitify2::reflection to reflect types and values into
 *    code-strings.
 */

#ifndef JITIFY2_HPP_INCLUDE_GUARD
#define JITIFY2_HPP_INCLUDE_GUARD

// This macro is used by source files generated by jitify_preprocess to avoid
// unnecessary dependencies.
#ifdef JITIFY_SERIALIZATION_ONLY

#include <climits>
#include <iostream>
#include <sstream>
#include <streambuf>
#include <string>
#include <unordered_map>
#include <vector>

#if __cplusplus >= 201703L
#include <string_view>
#endif

#else  // not JITIFY_SERIALIZATION_ONLY

#include <cuda.h>
#include <nvrtc.h>

// Default to being thread-safe.
#ifndef JITIFY_THREAD_SAFE
#define JITIFY_THREAD_SAFE 1
#endif

// Default to using dynamic linking of NVRTC.
#ifndef JITIFY_LINK_NVRTC_STATIC
#define JITIFY_LINK_NVRTC_STATIC 0
#endif

// Users can enable this for easier debugging.
#ifndef JITIFY_FAIL_IMMEDIATELY
#define JITIFY_FAIL_IMMEDIATELY 0
#endif

#ifndef JITIFY_USE_LIBCUFILT
#define JITIFY_USE_LIBCUFILT 0  // Use Jitify's builtin demangler by default
#endif

#if CUDA_VERSION >= 11040 && JITIFY_USE_LIBCUFILT
#include <nv_decode.h>  // For __cu_demangle (requires linking with libcufilt.a)
#endif

#include <algorithm>
#include <cctype>
#include <climits>
#include <cstring>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <regex>
#include <sstream>
#include <streambuf>
#include <string>
#include <thread>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// JITIFY_IF_THREAD_SAFE(x) expands to x only when JITIFY_THREAD_SAFE is
// enabled, and <mutex> is only included in that case. NOTE(review):
// presumably used to wrap mutex/lock code so it compiles out entirely when
// thread safety is disabled.
#if JITIFY_THREAD_SAFE
#include <mutex>
#define JITIFY_IF_THREAD_SAFE(x) x
#else
#define JITIFY_IF_THREAD_SAFE(x)
#endif

#ifdef __linux__
#include <cxxabi.h>             // For abi::__cxa_demangle
#include <dirent.h>             // For struct dirent, opendir etc.
#include <dlfcn.h>              // For ::dlopen, ::dlsym etc.
#include <fcntl.h>              // For open
#include <linux/limits.h>       // For PATH_MAX
#include <sys/stat.h>           // For stat
#include <sys/types.h>          // For DIR etc.
#include <unistd.h>             // For close
#include <cstdlib>              // For realpath
#include <ext/stdio_filebuf.h>  // For __gnu_cxx::stdio_filebuf
#define JITIFY_PATH_MAX PATH_MAX
#elif defined(_WIN32) || defined(_WIN64)
#include <windows.h>  // Must be included first

#include <dbghelp.h>      // For UndecorateSymbolName
#include <direct.h>       // For mkdir
#include <fcntl.h>        // For open, O_RDWR etc.
#include <io.h>           // For _sopen_s
#include <sys/locking.h>  // For _LK_LOCK etc.
#define JITIFY_PATH_MAX MAX_PATH
#else
#error "Unsupported platform"
#endif

#if defined(_WIN32) || defined(_WIN64)
// The current definitions of these macros are saved with push_macro before
// being overridden. NOTE(review): presumably restored with #pragma pop_macro
// at the end of the header (not visible in this section) — confirm.
// WAR for strtok_r being called strtok_s on Windows.
#pragma push_macro("strtok_r")
#undef strtok_r
#define strtok_r strtok_s
// WAR for min and max possibly being macros defined by windows.h
#pragma push_macro("min")
#pragma push_macro("max")
#undef min
#undef max
#endif

#ifndef JITIFY_ENABLE_EXCEPTIONS
// Default to using exceptions.
#define JITIFY_ENABLE_EXCEPTIONS 1
#endif

// JITIFY_THROW_OR_TERMINATE(msg): throws std::runtime_error(msg) when
// exceptions are enabled; otherwise prints msg to std::cerr and calls
// std::terminate(). It is used as a single statement (e.g. as the body of an
// unbraced `if`), so the non-exception expansion is wrapped in do/while(0).
// Without the wrapper, only the print would be guarded by the `if`, and
// std::terminate() would run unconditionally.
#if JITIFY_ENABLE_EXCEPTIONS
#include <stdexcept>
#define JITIFY_THROW_OR_TERMINATE(msg) throw std::runtime_error(msg)
#else
// TODO: Would std::exit or std::abort be better than std::terminate?
#include <exception>
#define JITIFY_THROW_OR_TERMINATE(msg)                  \
  do {                                                  \
    std::cerr << "Fatal error: " << (msg) << std::endl; \
    std::terminate();                                   \
  } while (0)
#endif

// JITIFY_THROW_OR_RETURN(msg): throws std::runtime_error(msg) when exceptions
// are enabled; otherwise returns msg from the enclosing function.
#if JITIFY_ENABLE_EXCEPTIONS
#define JITIFY_THROW_OR_RETURN(msg) throw std::runtime_error(msg)
#else
#define JITIFY_THROW_OR_RETURN(msg) return msg
#endif

// Evaluates a CUDA driver API call. If it does not return CUDA_SUCCESS, the
// driver's description of the error is thrown (or returned when exceptions
// are disabled) via JITIFY_THROW_OR_RETURN. cuGetErrorString reports failure
// and leaves the string NULL for unrecognized codes, so a fixed fallback
// message is used instead of passing a null pointer on.
#define JITIFY_THROW_OR_RETURN_IF_CUDA_ERROR(call)        \
  do {                                                    \
    CUresult jitify_cuda_ret = (call);                    \
    if (jitify_cuda_ret != CUDA_SUCCESS) {                \
      const char* error_c = nullptr;                      \
      if (cuGetErrorString(jitify_cuda_ret, &error_c) !=  \
              CUDA_SUCCESS ||                             \
          !error_c) {                                     \
        error_c = "Unrecognized CUDA error";              \
      }                                                   \
      JITIFY_THROW_OR_RETURN(error_c);                    \
    }                                                     \
  } while (0)

#endif  // not JITIFY_SERIALIZATION_ONLY

namespace jitify2 {

// Convenience aliases.
typedef std::vector<std::string> StringVec;
typedef std::unordered_map<std::string, std::string> StringMap;

#if __cplusplus >= 201703L
// C++17: non-owning views, so passing and slicing strings does not copy.
typedef std::string_view StringRef;
typedef std::string_view StringSlice;
#else
// Pre-C++17 fallback: pass by const reference; slices are owning copies.
typedef const std::string& StringRef;
typedef std::string StringSlice;
#endif

namespace serialization {

// Read-only stream buffer that exposes an existing character range as its get
// area without copying it. The data must outlive the buffer.
// Based on https://stackoverflow.com/a/13059195/7228843
struct membuf : std::streambuf {
  membuf(const char* data, size_t size) {
    // setg() takes non-const pointers, but the get area is only read from.
    char* const first = const_cast<char*>(data);
    char* const last = first + size;
    setg(first, first, last);
  }
};
// Warning: Do not put this inside the serialization::detail namespace, lest the
// wrath of ADL come down upon you from serialization::deserialize(StringRef).
// An std::istream that reads directly from an existing character buffer
// without copying it; the buffer must outlive the stream. membuf is listed as
// a base ahead of std::istream, so it is fully constructed before
// std::istream is handed a pointer to it as its stream buffer.
struct imemstream : virtual membuf, std::istream {
  imemstream(const char* data, size_t size)
      : membuf(data, size), std::istream(static_cast<std::streambuf*>(this)) {}
  // Reads the contents of str in place (str must outlive the stream).
  imemstream(const std::string& str) : imemstream(str.data(), str.size()) {}
#if __cplusplus >= 201703L
  // Reads the viewed characters in place (they must outlive the stream).
  imemstream(std::string_view sv) : imemstream(sv.data(), sv.size()) {}
#endif
};

// Format version written after the "JTFY" magic number and checked on read.
// Bump this whenever the serialization format changes in any incompatible
// way. (constexpr implies const, which already gives internal linkage here.)
constexpr size_t kSerializationVersion = 0x10;

namespace detail {

// Writes u as exactly 8 bytes, least-significant byte first, so the encoding
// does not depend on the host's endianness or sizeof(size_t).
inline void serialize(std::ostream& stream, size_t u) {
  const uint64_t value = u;
  char bytes[8];
  for (size_t index = 0; index < sizeof(bytes); ++index) {
    const unsigned shift = static_cast<unsigned>(index * CHAR_BIT);
    bytes[index] = static_cast<char>(static_cast<unsigned char>(value >> shift));
  }
  stream.write(bytes, sizeof(bytes));
}

// Reads a value written by serialize(std::ostream&, size_t): 8 bytes,
// least-significant first. Returns stream.good() after the read.
inline bool deserialize(std::istream& stream, size_t* size) {
  char bytes[8];
  stream.read(bytes, sizeof(bytes));
  uint64_t value = 0;
  // Accumulate starting from the most-significant (last) byte.
  for (int index = (int)sizeof(bytes) - 1; index >= 0; --index) {
    value = (value << CHAR_BIT) | static_cast<unsigned char>(bytes[index]);
  }
  *size = value;
  return stream.good();
}

// Obfuscate so that embedded serializations don't show up in `strings`.
// Every byte is replaced by its negation (mod 256). Negating twice restores
// the original byte, so deobfuscate() applies the same transformation.
inline std::string obfuscate(std::string s) {
  for (size_t index = 0; index < s.size(); ++index) {
    s[index] = static_cast<char>(-s[index]);
  }
  return s;
}
inline std::string deobfuscate(std::string s) {
  std::string plain = obfuscate(std::move(s));
  return plain;
}

inline void serialize(std::ostream& stream, std::string s) {
  serialize(stream, s.size());
  stream.write(obfuscate(s).data(), s.size());
}

// Reads a length-prefixed, obfuscated string written by
// serialize(std::ostream&, std::string). Returns false if any read fails.
inline bool deserialize(std::istream& stream, std::string* s) {
  size_t length;
  if (!deserialize(stream, &length)) {
    return false;
  }
  std::string& out = *s;
  out.resize(length);
  if (!out.empty()) {
    stream.read(&out[0], out.size());
  }
  out = deobfuscate(std::move(out));
  return stream.good();
}

inline void serialize(std::ostream& stream, const StringVec& v) {
  serialize(stream, v.size());
  for (const auto& s : v) {
    serialize(stream, s);
  }
}

inline bool deserialize(std::istream& stream, StringVec* v) {
  size_t size;
  if (!deserialize(stream, &size)) return false;
  v->resize(size);
  for (auto& s : *v) {
    if (!deserialize(stream, &s)) return false;
  }
  return true;
}

inline void serialize(std::ostream& stream, const StringMap& m) {
  serialize(stream, m.size());
  for (const auto& kv : m) {
    serialize(stream, kv.first);
    serialize(stream, kv.second);
  }
}

inline bool deserialize(std::istream& stream, StringMap* m) {
  size_t size;
  if (!deserialize(stream, &size)) return false;
  for (size_t i = 0; i < size; ++i) {
    std::string key;
    if (!deserialize(stream, &key)) return false;
    if (!deserialize(stream, &(*m)[key])) return false;
  }
  return true;
}

// Serializes head, then recurses on the remaining values so that every
// argument is written in order by its single-value overload.
template <typename T, typename... Rest>
inline void serialize(std::ostream& stream, const T& head,
                      const Rest&... tail) {
  serialize(stream, head);
  serialize(stream, tail...);
}

// Deserializes into each pointer in order, stopping at the first failure.
template <typename T, typename... Rest>
inline bool deserialize(std::istream& stream, T* head, Rest*... tail) {
  return deserialize(stream, head) && deserialize(stream, tail...);
}

// Writes the 4-byte "JTFY" tag followed by kSerializationVersion.
inline void serialize_magic_number(std::ostream& stream) {
  const char tag[4] = {'J', 'T', 'F', 'Y'};
  stream.write(tag, sizeof(tag));
  serialize(stream, kSerializationVersion);
}

inline bool deserialize_magic_number(std::istream& stream) {
  char magic_number[4] = {0, 0, 0, 0};
  stream.read(&magic_number[0], 4);
  if (!(magic_number[0] == 'J' && magic_number[1] == 'T' &&
        magic_number[2] == 'F' && magic_number[3] == 'Y')) {
    return false;
  }
  size_t serialization_version;
  if (!deserialize(stream, &serialization_version)) return false;
  return serialization_version == kSerializationVersion;
}

}  // namespace detail

/*! Serialize one or more values to a stream. The values are preceded by the
 *  magic number and format version, so deserialize() can validate the data.
 */
template <typename... Values>
inline void serialize(std::ostream& stream, const Values&... items) {
  detail::serialize_magic_number(stream);
  detail::serialize(stream, items...);
}

/*! Serialize one or more values to a new binary string, with the magic
 *  number and version first. This overload is disabled when the first
 *  argument is a stream, so serialize(stream, ...) picks the stream overload.
 */
template <typename T, typename... Rest,
          typename std::enable_if<
              !std::is_convertible<T&, std::ostream&>::value, int>::type = 0>
inline std::string serialize(const T& value, const Rest&... rest) {
  std::ostringstream buffer(std::stringstream::binary);
  detail::serialize_magic_number(buffer);
  detail::serialize(buffer, value, rest...);
  return buffer.str();
}

/*! Deserialize values from a stream produced by serialize(). Returns false if
 *  the magic number or version does not match, or if any read fails.
 */
template <typename... Values>
inline bool deserialize(std::istream& stream, Values*... values) {
  return detail::deserialize_magic_number(stream) &&
         detail::deserialize(stream, values...);
}

/*! Deserialize values from an in-memory buffer (the data is not copied). */
template <typename... Values>
inline bool deserialize(StringRef serialized, Values*... values) {
  imemstream stream(serialized);
  return deserialize(stream, values...);
}

template <class Subclass>
class Serializable {
  struct SerializeImpl {
    std::ostream& stream_;
    SerializeImpl(std::ostream& stream) : stream_(stream) {}
    template <typename... Values>
    bool operator()(const Values&... values) const {
      serialization::serialize(stream_, values...);
      return true;
    }
  };
  struct DeserializeImpl {
    std::istream& stream_;
    DeserializeImpl(std::istream& stream) : stream_(stream) {}
    template <typename... Values>
    bool operator()(Values&... values) const {
      return serialization::deserialize(stream_, &values...);
    }
  };

 public:
  /*! Serialize the object to a stream.
   *  \param stream The stream to output serialized data to.
   */
  void serialize(std::ostream& stream) const {
    const auto* subclass = static_cast<const Subclass*>(this);
    subclass->serialize_members(SerializeImpl(stream));
  }
  /*! Serialize the object to a string.
   *  \return A string containing the serialized data.
   */
  std::string serialize() const {
    std::ostringstream ss(std::stringstream::binary);
    serialize(ss);
    return ss.str();
  }
  static bool deserialize(std::istream& stream, Subclass* subclass) {
    return subclass->deserialize_members(DeserializeImpl(stream));
  }
  static bool deserialize(StringRef serialized, Subclass* subclass) {
    imemstream ms(serialized);
    return subclass->deserialize_members(DeserializeImpl(ms));
  }
};

// Declares the hooks used by serialization::Serializable<ClassName>. Place it
// inside the class body. It expands to a friend declaration for the CRTP base,
// plus deserialize_members()/serialize_members() templates that pass the
// listed data members, in order, to the given functor. Both directions use
// the same member list, so the on-stream layout follows the order given.
// Usage: JITIFY_DEFINE_SERIALIZABLE_MEMBERS(MyClass, member1, member2)
#define JITIFY_DEFINE_SERIALIZABLE_MEMBERS(ClassName, ...) \
  friend class serialization::Serializable<ClassName>;     \
  template <typename Deserializer>                         \
  bool deserialize_members(Deserializer deserializer) {    \
    return deserializer(__VA_ARGS__);                      \
  }                                                        \
  template <typename Serializer>                           \
  bool serialize_members(Serializer serializer) const {    \
    return serializer(__VA_ARGS__);                        \
  }

}  // namespace serialization

#ifndef JITIFY_SERIALIZATION_ONLY

namespace detail {

// Identity overload used by string_concat(). It lets string-like arguments
// pass through unchanged, while arithmetic arguments go through
// std::to_string. Before C++17, StringRef is `const std::string&`, so the
// result refers to the argument itself. It is only valid while that argument
// is alive (for a temporary, until the end of the enclosing full-expression).
inline StringRef to_string(StringRef s) {
  return s;
}

// Invokes `function` once on each argument, in left-to-right order.
template <class Func, typename... Args>
inline void for_each(Func function, Args&&... args) {
  // Expanding the pack inside a braced initializer guarantees left-to-right
  // evaluation. The leading 0 keeps the array non-empty for an empty pack.
  using expand_type = int[];
  (void)expand_type{0, (function(std::forward<Args>(args)), 0)...};
}

template <typename... Args>
inline std::string string_concat_strings(const Args&... args) {
  size_t size = 0;
  for_each([&](StringRef arg) { size += arg.size(); }, args...);
  std::string result;
  result.reserve(size);
  for_each([&](StringRef arg) { result += arg; }, args...);
  return result;
}

// Concatenates args into one string. Arithmetic arguments are formatted with
// std::to_string; string-like arguments go through jitify2::detail::to_string
// unchanged.
template <typename... Args>
inline std::string string_concat(const Args&... args) {
  using std::to_string;
  using ::jitify2::detail::to_string;
  return string_concat_strings(to_string(args)...);
}

inline std::string string_join(const StringVec& args, StringRef sep = ",",
                               StringRef prefix = "", StringRef suffix = "") {
  std::string result;
  size_t args_size = 0;
  for (const std::string& arg : args) {
    args_size += arg.size();
  }
  result.reserve(prefix.size() + args_size +
                 sep.size() * (std::max(args.size(), size_t(1)) - 1) +
                 suffix.size());
  result += prefix;
  for (int i = 0; i < (int)args.size(); ++i) {
    if (i > 0) result += sep;
    result += args[i];
  }
  result += suffix;
  return result;
}

// Strip whitespace (as classified by std::isspace) from a string in-place.
inline void ltrim(std::string* s) {
  size_t first = 0;
  while (first < s->size() &&
         std::isspace(static_cast<unsigned char>((*s)[first]))) {
    ++first;
  }
  s->erase(0, first);
}
inline void rtrim(std::string* s) {
  size_t last = s->size();
  while (last > 0 &&
         std::isspace(static_cast<unsigned char>((*s)[last - 1]))) {
    --last;
  }
  s->erase(last);
}
inline void trim(std::string* s) {
  rtrim(s);
  ltrim(s);
}

// Strip whitespace from a string view. The input is untouched; the trimmed
// part is returned as a StringSlice (a view in C++17, a copy before that).
inline StringSlice ltrim(StringRef s) {
  size_t begin = 0;
  while (begin < s.size() &&
         std::isspace(static_cast<unsigned char>(s[begin]))) {
    ++begin;
  }
  return s.substr(begin);
}
inline StringSlice rtrim(StringRef s) {
  size_t end = s.size();
  while (end > 0 && std::isspace(static_cast<unsigned char>(s[end - 1]))) {
    --end;
  }
  return s.substr(0, end);
}
inline StringSlice trim(StringRef s) { return ltrim(rtrim(s)); }

}  // namespace detail

/*! Reflection utilities namespace. */
namespace reflection {

/*! Wraps the compile-time value VALUE (of type T) in a type, so a value can
 *  be passed wherever a type is expected, e.g. reflect<NonType<int, 7>>().
 */
template <typename T, T VALUE>
struct NonType {};

// Forward declaration of the value overload (defined further down). It is
// needed because detail::ReflectType<NonType<...>> calls reflect(VALUE)
// before that definition appears.
template <typename T>
inline std::string reflect(const T& value);

namespace detail {

// Formats a value as source text using std::to_string. Note: for
// floating-point types std::to_string formats as if by "%f" (six decimals).
template <typename T>
inline std::string value_string(const T& x) {
  std::string text = std::to_string(x);
  return text;
}

// bool gets its C++ keyword spelling rather than std::to_string's "1"/"0".
template <>
inline std::string value_string<bool>(const bool& x) {
  if (x) {
    return "true";
  }
  return "false";
}

// Returns the demangled name corresponding to the given typeinfo structure,
// or an empty string on failure. MSVC uses UnDecorateSymbolName (dbghelp).
// Other compilers use abi::__cxa_demangle; a name that it rejects as invalid
// mangling (status -2) is returned unchanged.
inline std::string get_type_name(const std::type_info& typeinfo) {
#ifdef _MSC_VER  // MSVC compiler
  // Get the decorated name and skip over the leading '.'.
  const char* raw_name = typeinfo.raw_name();
  if (!raw_name || raw_name[0] != '.') return {};  // Unexpected error
  const char* decorated_name = raw_name + 1;
  // Fixed-size output buffer; UnDecorateSymbolName is told its capacity.
  char undecorated_name[4096];
  // Note: UNDNAME_NO_MS_KEYWORDS removes __cdecl, __ptr64 etc. but has a bug in
  // some versions that breaks function types. Instead, we leave these tokens in
  // and #define them away as necessary.
  if (!UnDecorateSymbolName(
          decorated_name, undecorated_name,
          sizeof(undecorated_name) / sizeof(*undecorated_name),
          UNDNAME_NO_ARGUMENTS |          // Treat input as a type name
              UNDNAME_NAME_ONLY           // No "class" and "struct" prefixes
          /*UNDNAME_NO_MS_KEYWORDS*/)) {  // No "__cdecl", "__ptr64" etc. BUGGED
    return {};                            // Error
  }
  return undecorated_name;
#else   // not MSVC
  const char* mangled_name = typeinfo.name();
  size_t bufsize = 0;
  char* buf = nullptr;
  int status;
  // With a null buffer, __cxa_demangle allocates the result itself; the
  // unique_ptr releases it with std::free on every return path below.
  auto demangled_ptr = std::unique_ptr<char, void (*)(void*)>(
      abi::__cxa_demangle(mangled_name, buf, &bufsize, &status), std::free);
  // clang-format off
  switch (status) {
  case 0: return demangled_ptr.get();  // Demangled successfully
  case -2: return mangled_name;        // Interpret as plain unmangled name
  case -1: // fall-through             // Memory allocation failure
  case -3: // fall-through             // Invalid argument
  default: return {};
  }
    // clang-format on
#endif  // not MSVC
}

// Empty wrapper used by get_type_name<T>(). typeid(JitifyTypeNameWrapper_<T>)
// keeps T's cv-qualifiers in the reflected name, whereas typeid(T) would
// drop them.
template <typename>
class JitifyTypeNameWrapper_ {};

// Returns the demangled name of T, including top-level cv-qualifiers, or an
// empty string if the wrapper name cannot be found.
template <typename T>
inline std::string get_type_name() {
  // typeid() discards top-level cv-qualifiers. To keep them, name the dummy
  // instantiation JitifyTypeNameWrapper_<T> instead, then strip the wrapper.
  // The wrapper's demangled name carries namespace prefixes, so search for it.
  const char kPrefix[] = "JitifyTypeNameWrapper_<";
  const std::string wrapped = get_type_name(typeid(JitifyTypeNameWrapper_<T>));
  const size_t pos = wrapped.find(kPrefix);
  if (pos == std::string::npos) return {};  // Unexpected error
  const size_t first = pos + (sizeof(kPrefix) - 1);
  // Drop the wrapper's closing '>' at the end.
  return wrapped.substr(first, wrapped.size() - first - 1);
}

template <typename T>
struct ReflectType {
  const std::string& operator()() const {
    // Storing this statically means it is cached after the first call.
    static const std::string type_name = get_type_name<T>();
    return type_name;
  }
};

template <typename T, T VALUE>
struct ReflectType<NonType<T, VALUE>> {
  std::string operator()() const { return reflect(VALUE); }
};

}  // namespace detail

/*! A wrapper used for representing types as values: an empty tag object,
 *  e.g. reflect(Type<float>()).
 */
template <typename T>
struct Type {};

/*! Holds a const reference to an object so its run-time (e.g. derived) type
 *  can be reflected. This supports templating on the derived type when only
 *  the abstract base type is known at compile time.
 */
template <typename T>
struct Instance {
  // Non-owning: the referenced object must outlive this Instance.
  const T& value;
  Instance(const T& object) : value(object) {}
};

/*! Wrap value in an Instance so its run-time type can be reflected.
 *  \param value The const value to be captured (by reference).
 */
template <typename T>
inline Instance<T const> instance_of(T const& value) {
  Instance<T const> wrapped(value);
  return wrapped;
}

/*! Generate a code-string for a type.
 *  \code{.cpp}reflect<float>() --> "float"\endcode
 */
template <typename T>
inline std::string reflect() {
  return detail::ReflectType<T>()();
}

/*! Generate a code-string for a value: a C-style cast to the reflected type,
 *  followed by the value as produced by detail::value_string().
 *  \code{.cpp}reflect(3.14f) --> "(float)3.140000"\endcode
 */
template <typename T>
inline std::string reflect(const T& value) {
  std::string code = "(";
  code += reflect<T>();
  code += ")";
  code += detail::value_string(value);
  return code;
}

/*! Generate a code-string for an integer non-type template argument
 *  (via implicit conversion to int64_t).
 *  \code{.cpp}reflect<7>() --> "(int64_t)7"\endcode
 */
template <int64_t N>
inline std::string reflect() {
  return reflect<NonType<int64_t, N>>();
}

/*! Generate a code-string for a generic non-type template argument.
 *  \code{.cpp} reflect<int,7>() --> "(int)7" \endcode
 */
template <typename T, T N>
inline std::string reflect() {
  return reflect<NonType<T, N>>();
}

/*! Generate a code-string for a type wrapped as a Type tag.
 *  \code{.cpp}reflect(Type<float>()) --> "float"\endcode
 */
template <typename T>
inline std::string reflect(Type<T> /*type_tag*/) {
  return detail::ReflectType<T>()();
}

/*! Generate a code-string for the run-time type of the object wrapped in an
 *  Instance.
 *  \code{.cpp}reflect(Instance<float>(3.1f)) --> "float"\endcode
 *  or, more simply, via the instance_of helper:
 *  \code{.cpp}reflect(instance_of(3.1f)) --> "float"\endcode
 *  This is meant for extracting the derived type of an object seen through a
 *  base-class reference. typeid only reports the dynamic type for
 *  polymorphic class types.
 */
template <typename T>
inline std::string reflect(const Instance<T>& value) {
  const std::type_info& runtime_type = typeid(value.value);
  return detail::get_type_name(runtime_type);
}

// TODO: Would there ever be a need to reflect a string literal?
// Existing code strings pass through unchanged, so callers can mix literal
// code with reflected types and values (e.g. in reflect_template()).
/*! Use an existing code string as-is. */
inline std::string reflect(const std::string& s) {
  return s;
}
/*! Use an existing code string as-is. */
inline const char* reflect(const char* s) {
  return s;
}
#if __cplusplus >= 201703L
/*! Use an existing code string as-is. */
inline std::string_view reflect(std::string_view s) {
  return s;
}
#endif

/*! Create a Type object representing the type of a non-const lvalue.
 *  \code{.cpp}float x; type_of(x) -> Type<float>()\endcode
 *  \param [unnamed] The value whose type is to be captured.
 */
template <typename T>
inline Type<T> type_of(T&) {
  return {};
}

/*! Create a Type object representing a value's type with const added. This
 *  overload is chosen for const lvalues and for rvalues.
 *  \code{.cpp}type_of(3.14f) -> Type<float const>()\endcode
 *  \param [unnamed] The const value whose type is to be captured.
 */
template <typename T>
inline Type<T const> type_of(const T&) {
  return {};
}

/*! Generate a code-string for a template instantiation. */
inline std::string reflect_template(const StringVec& args) {
  // Note: The space in " >" is a WAR to avoid '>>' appearing
  return jitify2::detail::string_join(args, ",", "<", " >");
}

/*! Generate a code-string for a template instantiation. */
template <typename... Ts>
inline std::string reflect_template() {
  return reflect_template({reflect<Ts>()...});
}

/*! Generate a code-string for a template instantiation. */
template <typename... Args>
inline std::string reflect_template(const Args&... args) {
  return reflect_template({reflect(args)...});
}

/*! Convenience class for generating code-strings for template
 *  instantiations, e.g. Template("foo").instantiate<int>() --> "foo<int >".
 */
class Template {
 public:
  /*! Construct the class.
   *  \param name The name of the template.
   */
  Template(StringRef name) : name_(name) {}

  /*! Instantiate with template arguments given as code strings. */
  std::string instantiate(const StringVec& template_args = {}) const {
    std::string code = name_;
    code += reflect_template(template_args);
    return code;
  }

  /*! Instantiate with template arguments given as types. */
  template <typename... TemplateArgs>
  std::string instantiate() const {
    std::string code = name_;
    code += reflect_template<TemplateArgs...>();
    return code;
  }

  /*! Instantiate with template arguments given as values, Type<> tags, or
   *  code strings. */
  template <typename... TemplateArgs>
  std::string instantiate(const TemplateArgs&... targs) const {
    std::string code = name_;
    code += reflect_template(targs...);
    return code;
  }

 private:
  std::string name_;
};

}  // namespace reflection

// Simple error type wrapping a string error message. An empty message means
// "no error", so ok() and operator bool() have opposite senses: the object is
// truthy exactly when it holds an error.
class ErrorMsg : public std::string {
 public:
  using std::string::string;
  ErrorMsg(std::string&& str) : std::string(std::move(str)) {}
  ErrorMsg(const std::string& str) : std::string(str) {}

  /*! Returns true if the error message is empty. */
  bool ok() const { return size() == 0; }
  /*! Returns true if the error message is non-empty. */
  explicit operator bool() const { return !ok(); }
};

namespace detail {

/*! Represents either a value type or an error state.
 *
 *  Access to the underlying value is checked and will throw/terminate if it is
 *  in the error state. The error state can be queried via the ok() method or
 *  operator bool(), and the error data can be accessed via the error() method.
 *  This type has value semantics but provides operator* and operator-> for
 *  accessing the underlying value (similar to std::optional).
 */
template <typename ValueType, typename ErrorType>
class FallibleValue {
 public:
  using value_type = ValueType;
  using error_type = ErrorType;

 private:
  // TODO: Ideally would use std::variant here to avoid storing both.
  // TODO: Consider making this a unique_ptr too to avoid needing default
  // constructors. Only downside is the need to allocate on the heap.
  value_type value_;  // Not meaningful while in the error state.
  std::unique_ptr<error_type> error_;  // Non-null iff in the error state.

  // True iff Args is exactly one argument whose decayed type is FallibleValue
  // or a class derived from it, i.e. the call is really a copy or a move.
  template <typename... Args>
  struct IsSelfArg : std::false_type {};
  template <typename Arg>
  struct IsSelfArg<Arg>
      : std::is_base_of<FallibleValue, typename std::decay<Arg>::type> {};

 public:
  // Helper type for constructing in error state. It stores its own copy of
  // the error rather than a reference, so it cannot dangle when built from a
  // temporary (e.g. a string literal converted to error_type).
  class Error {
    error_type error_;

   public:
    Error(const error_type& error) : error_(error) {}
    const error_type& value() const { return error_; }
  };

  // Default-construct the value in ok state.
  FallibleValue() = default;

  // Construct the value in ok state by forwarding args to value_type's
  // constructor. Disabled for a single FallibleValue (or derived) argument.
  // Otherwise, direct-initialization from a non-const lvalue would pick this
  // template over the copy constructor and try to build value_type from a
  // FallibleValue.
  template <typename... Args,
            typename std::enable_if<!IsSelfArg<Args...>::value, int>::type = 0>
  explicit FallibleValue(Args&&... args)
      : value_(std::forward<Args>(args)...) {}

  // Construct in error state.
  FallibleValue(Error error) : error_(new error_type(error.value())) {
#if JITIFY_FAIL_IMMEDIATELY
    // Fail now for easier debugging via backtrace.
    if (error.value() != "Uninitialized") {
      JITIFY_THROW_OR_TERMINATE(error.value());
    }
#endif
  }

  // Support copy and assign. In the error state only the error is copied;
  // the value is reset to value_type().
  FallibleValue(const FallibleValue& rhs)
      : value_(rhs.error_ ? value_type() : rhs.value_),
        error_(rhs.error_ ? new error_type(*rhs.error_) : nullptr) {}
  FallibleValue& operator=(const FallibleValue& rhs) {
    value_ = rhs.error_ ? value_type() : rhs.value_;
    error_.reset(rhs.error_ ? new error_type(*rhs.error_) : nullptr);
    return *this;
  }
  FallibleValue(FallibleValue&& rhs) = default;
  FallibleValue& operator=(FallibleValue&& rhs) = default;

  /*! Returns true iff not in error state. */
  bool ok() const noexcept { return !error_; }

  /*! Returns true iff not in error state. */
  explicit operator bool() const noexcept { return ok(); }

  /*! Get the error value. Throws/terminates if not in error state. */
  const error_type& error() const {
    if (ok()) JITIFY_THROW_OR_TERMINATE("Object is not in error state");
    return *error_;
  }

  /*! Get the underlying value. Throws/terminates if in error state. */
  value_type& value() {
    if (!ok()) JITIFY_THROW_OR_TERMINATE(static_cast<std::string>(*error_));
    return value_;
  }
  /*! Get the underlying value. Throws/terminates if in error state. */
  const value_type& value() const {
    if (!ok()) JITIFY_THROW_OR_TERMINATE(static_cast<std::string>(*error_));
    return value_;
  }

  /*! Access the underlying value. Throws/terminates if in error state. */
  value_type* operator->() { return &value(); }
  /*! Access the underlying value. Throws/terminates if in error state. */
  const value_type* operator->() const { return &value(); }
  /*! Access the underlying value. Throws/terminates if in error state. */
  value_type& operator*() { return value(); }
  /*! Access the underlying value. Throws/terminates if in error state. */
  const value_type& operator*() const { return value(); }
};

// Defines common constructors for the user-visible classes: a forwarding
// constructor for the ok state, implicit conversion from Error, a default
// "Uninitialized" error state, and deserialize() factories.
template <class Subclass, class ValueType, class ErrorType = ErrorMsg>
class FallibleObjectBase : public detail::FallibleValue<ValueType, ErrorType> {
  using super_type = detail::FallibleValue<ValueType, ErrorType>;

 public:
  // Constructs the wrapped ValueType in the ok state from args.
  template <typename... Args>
  explicit FallibleObjectBase(Args&&... args)
      : super_type(std::forward<Args>(args)...) {}

  // A default-constructed object is in the error state "Uninitialized".
  FallibleObjectBase()
      : FallibleObjectBase(typename super_type::Error("Uninitialized")) {}

  // Allow implicit conversion from Error.
  FallibleObjectBase(typename super_type::Error error)
      : super_type(std::move(error)) {}

  /*! Deserialize the object from a stream.
   *  \return An object holding either the deserialized ValueType or the error
   *    "Deserialization failed".
   */
  static Subclass deserialize(std::istream& stream) {
    ValueType impl;
    const bool parsed = ValueType::deserialize(stream, &impl);
    if (parsed) {
      return Subclass(impl);
    }
    return Subclass(typename super_type::Error("Deserialization failed"));
  }
  /*! Deserialize the object from a string (the data is not copied).
   *  \return An object holding either the deserialized ValueType or the error
   *    "Deserialization failed".
   */
  static Subclass deserialize(StringRef serialized) {
    serialization::imemstream stream(serialized);
    return deserialize(stream);
  }
};

// djb2 algorithm (xor variant) by Dan Bernstein, see
// http://www.cse.yorku.ca/~oz/hash.html
// Each byte is read as unsigned char. Plain char is signed on some platforms,
// where bytes >= 0x80 would otherwise be sign-extended before the xor, making
// the hash differ between platforms.
inline uint64_t hash_value(const char* data, size_t size,
                           uint64_t seed = 5381) {
  uint64_t hash = seed;
  for (size_t i = 0; i < size; ++i) {
    hash = ((hash << 5) + hash) ^ static_cast<unsigned char>(data[i]);
  }
  return hash;
}
// Mixes hash b into hash a. NOTE(review): the additive constant is meant to
// be 2**64 / phi (golden ratio), which is 0x9E3779B97F4A7C15. This value
// differs in the low bits; it is kept unchanged because changing it would
// alter every combined hash.
inline uint64_t hash_combine(uint64_t a, uint64_t b) {
  const uint64_t mixed = 0x9E3779B97F4A7C17ull + b + (b >> 2) + (a << 6);
  return a ^ mixed;
}
// Hashes the bytes of a string using the char-buffer (djb2) overload. Note
// that the default seed here is HashType{} (zero), not the buffer overload's
// default.
template <typename HashType>
inline HashType hash_value(const std::string& s, HashType seed = {}) {
  const char* bytes = s.data();
  const size_t length = s.size();
  return hash_value(bytes, length, seed);
}
// Hashes a vector by folding each element's hash into seed, in order.
template <typename HashType, typename T>
inline HashType hash_value(const std::vector<T>& v, HashType seed = {}) {
  HashType h = seed;
  for (size_t index = 0; index < v.size(); ++index) {
    h = hash_combine(h, hash_value<HashType>(v[index]));
  }
  return h;
}
// Hashes a map. Iteration order of std::unordered_map is unspecified, so
// entries are folded in sorted key order (key hash, then value hash) to make
// the result independent of bucket layout.
template <typename HashType, typename Key, typename Val>
inline HashType hash_value(const std::unordered_map<Key, Val>& m,
                           HashType seed = {}) {
  std::vector<Key> sorted_keys;
  sorted_keys.reserve(m.size());
  for (auto it = m.begin(); it != m.end(); ++it) {
    sorted_keys.push_back(it->first);
  }
  std::sort(sorted_keys.begin(), sorted_keys.end());
  HashType h = seed;
  for (const Key& key : sorted_keys) {
    h = hash_combine(h, hash_value<HashType>(key));
    h = hash_combine(h, hash_value<HashType>(m.at(key)));
  }
  return h;
}

// 64-bit bit mixer based on fast-hash's mix step
// (https://github.com/ztanml/fast-hash): xor-shift, multiply, xor-shift.
inline uint64_t fasthash64(uint64_t h) {
  constexpr uint64_t kMultiplier = 0x2127599bf4325c37ull;
  const uint64_t shifted = h ^ (h >> 23);
  const uint64_t product = shifted * kMultiplier;
  return product ^ (product >> 47);
}

// Formats a CUDA driver error as "CUDA error <code>: <description>".
// cuGetErrorString reports failure and sets the string to NULL for
// unrecognized codes. Appending a null char* to std::string is undefined
// behavior, so a fixed description is used in that case.
inline std::string get_cuda_error_string(CUresult ret) {
  const char* error_c = nullptr;
  if (cuGetErrorString(ret, &error_c) != CUDA_SUCCESS || !error_c) {
    error_c = "unrecognized error code";
  }
  return "CUDA error " + std::to_string(ret) + ": " + error_c;
}

// Returns the SHA-256 digest of the `size` bytes at `data` as a string of 64
// uppercase hex digits.
inline std::string sha256(const char* data, size_t size) {
  // This implementation is based on pseudocode from Wikipedia.
  // Initialize array of round constants to first 32 bits of the fractional
  // parts of the cube roots of the first 64 primes 2..311.
  static constexpr uint32_t k[64] = {
      0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1,
      0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
      0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786,
      0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
      0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147,
      0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
      0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
      0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
      0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a,
      0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
      0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2};
  // Initialize hash values to first 32 bits of the fractional parts of the
  // square roots of the first 8 primes 2..19.
  uint32_t h[8] = {0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
                   0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19};
  // Pre-processing with padding.
  std::string padded;
  const auto pad = [](size_t n, size_t mult) {
    return ((n - 1) / mult + 1) * mult;
  };
  static constexpr const unsigned kChunkSize = 64;
  size_t extended_size = size + 1 + sizeof(uint64_t);
  size_t padded_size = pad(extended_size, kChunkSize);
  padded.reserve(padded_size);
  padded.append(data, size);
  // Append a 1 bit (and 7 padding bits).
  padded += static_cast<char>(0x80);
  // Pad with zeros.
  padded.append(padded_size - extended_size, '\0');
  // Append the message size in bits as a 64-bit big-endian integer. Use
  // uint64_t rather than size_t so that the multiplication cannot overflow and
  // the shifts by up to 56 bits below are well-defined on 32-bit platforms.
  const uint64_t size_bits = static_cast<uint64_t>(size) * CHAR_BIT;
  for (int v = 0; v < 64; v += 8) {
    padded += static_cast<char>((size_bits >> (64 - 8 - v)) & 0xFF);
  }
  // Circular shift (rotate) function.
  const auto rotr = [](uint32_t x, unsigned n) {
    return (x >> n) | (x << (32 - n));
  };
  // Process the message in successive 512-bit chunks.
  for (size_t c = 0; c < padded_size; c += kChunkSize) {
    // Create a 64-entry message schedule array.
    uint32_t w[64];
    // Copy chunk into first 16 words (big-endian).
    for (int i = 0; i < 16; ++i) {
      size_t offset = c + i * 4;
      w[i] = (static_cast<uint8_t>(padded[offset + 0]) << 24) |
             (static_cast<uint8_t>(padded[offset + 1]) << 16) |
             (static_cast<uint8_t>(padded[offset + 2]) << 8) |
             (static_cast<uint8_t>(padded[offset + 3]));
    }
    // Extend the first 16 words into the remaining 48 words w[16..63] of the
    // message schedule array.
    for (int i = 16; i < 64; ++i) {
      uint32_t s0 = rotr(w[i - 15], 7) ^ rotr(w[i - 15], 18) ^ (w[i - 15] >> 3);
      uint32_t s1 = rotr(w[i - 2], 17) ^ rotr(w[i - 2], 19) ^ (w[i - 2] >> 10);
      w[i] = w[i - 16] + s0 + w[i - 7] + s1;
    }
    // Initialize working variables to current hash value.
    uint32_t x[8];
    for (int j = 0; j < 8; ++j) {
      x[j] = h[j];
    }
    // Compression function main loop. x[0..7] hold the working variables
    // a..h; shifting the array down one slot performs h=g, g=f, ..., b=a.
    for (int i = 0; i < 64; ++i) {
      uint32_t e = x[4];
      uint32_t S1 = rotr(e, 6) ^ rotr(e, 11) ^ rotr(e, 25);
      uint32_t ch = (e & x[5]) ^ ((~e) & x[6]);
      uint32_t temp1 = x[7] + S1 + ch + k[i] + w[i];
      uint32_t a = x[0];
      uint32_t S0 = rotr(a, 2) ^ rotr(a, 13) ^ rotr(a, 22);
      uint32_t maj = (x[0] & x[1]) ^ (x[0] & x[2]) ^ (x[1] & x[2]);
      uint32_t temp2 = S0 + maj;
      for (int j = 7; j > 0; --j) {
        x[j] = x[j - 1];
      }
      x[4] += temp1;  // e = d + temp1
      x[0] = temp1 + temp2;
    }
    // Add the compressed chunk to the current hash value.
    for (int j = 0; j < 8; ++j) {
      h[j] += x[j];
    }
  }
  // Unpack and render the computed digest as a hex string (8 words x 8 hex
  // digits = 64 characters).
  std::string result;
  result.reserve(64);
  for (uint32_t val : h) {
    for (int i = 0; i < 32; i += 4) {
      result += "0123456789ABCDEF"[(val >> (32 - 4 - i)) & 0xF];
    }
  }
  return result;
}

inline std::string sha256(StringRef s) {
  // Convenience overload: hash the referenced character range.
  const char* const bytes = s.data();
  return sha256(bytes, s.size());
}

// Normalizes an unmangled CUDA symbol name to match what cu++filt produces.
inline std::string normalize_cuda_symbol_name(const std::string& symbol_name) {
  // Replace every "(anonymous namespace)" (c++filt) with "<unnamed>"
  // (cu++filt), scanning left to right over non-overlapping occurrences.
  static const std::string kCxxFiltName = "(anonymous namespace)";
  static const std::string kCuCxxFiltName = "<unnamed>";
  std::string normalized = symbol_name;
  for (size_t pos = normalized.find(kCxxFiltName); pos != std::string::npos;
       pos = normalized.find(kCxxFiltName, pos + kCuCxxFiltName.size())) {
    normalized.replace(pos, kCxxFiltName.size(), kCuCxxFiltName);
  }
  return normalized;
}

}  // namespace detail

class Kernel;

/*! Deleter used by UniqueCudaModule: unloads the module with cuModuleUnload.
 *    The CUresult returned by cuModuleUnload is ignored.
 */
struct CudaModuleDestructor {
  void operator()(CUmodule cu_module) const {
    cuModuleUnload(cu_module);
  }
};
/*! Owning handle for a CUmodule; the module is unloaded when the handle is
 *    destroyed.
 */
using UniqueCudaModule =
    std::unique_ptr<std::remove_pointer<CUmodule>::type, CudaModuleDestructor>;

/*! An object containing a CUDA module that has been loaded into a CUDA context,
 *    along with other metadata.
 */
class LoadedProgramData {
  // We store the members in a shared_ptr so that the object can be cheaply and
  // safely copied, particularly from inside a cache data structure. Note that
  // these members are mostly immutable (the exception being the ability to
  // modify global variables and to set attributes on kernels in the module,
  // which need to be used carefully by the user). Being able to copy this
  // object also avoids Kernel objects needing to store a reference to it, which
  // would present lifetime management issues for the user.
  struct Data {
    UniqueCudaModule module;     // Owns the module (unloaded on destruction).
    StringMap lowered_name_map;  // Name expression -> lowered symbol name.
  };
  std::shared_ptr<Data> data_;

  // Looks up the global variable `name` via get_global_ptr and checks that its
  // size equals given_size_bytes. On success, stores its address in *ptr and
  // returns an empty string; otherwise returns an error message. Errors are
  // reported through the JITIFY_THROW_OR_RETURN* macros, whose definitions are
  // outside this view (presumably they throw when exceptions are enabled).
  std::string get_global_ptr_with_size(std::string name,
                                       size_t given_size_bytes,
                                       CUdeviceptr* ptr) const {
    size_t size_bytes;
    std::string error = get_global_ptr(name, ptr, &size_bytes);
    if (!error.empty()) return error;
    if (size_bytes != given_size_bytes) {
      error = std::string("Value for global variable ") + name +
              " has wrong size: got " + std::to_string(given_size_bytes) +
              " bytes, expected " + std::to_string(size_bytes);
      JITIFY_THROW_OR_RETURN(error);
    }
    return {};
  }

 public:
  // Leaves data_ null: module(), lowered_name_map() and every method that uses
  // them must not be called on a default-constructed object.
  LoadedProgramData() = default;  // Needed only for FallibleValue constructor
  /*! Construct from a loaded module, taking ownership of it.
   *  \param module The loaded CUDA module.
   *  \param lowered_name_map (optional) Map of name expressions to lowered
   *    (mangled) symbol names.
   */
  LoadedProgramData(UniqueCudaModule module, StringMap lowered_name_map = {})
      : data_(new Data{std::move(module), std::move(lowered_name_map)}) {}

  /*! Get the CUDA module of the loaded program. */
  CUmodule module() const { return data_->module.get(); }
  /*! Get the map of name expressions to lowered (mangled) symbol names. */
  const StringMap& lowered_name_map() const { return data_->lowered_name_map; }

  /*! Get a kernel from the loaded program.
   *  \param name The full name of the instantiated kernel (e.g.,
   *    `&quot;my_namespace::my_kernel<123, float>&quot;`).
   *  \return A Kernel object that contains either a valid KernelData object or
   *    an error state.
   */
  Kernel get_kernel(std::string name) const;

  /*! Get the address of a global variable from the loaded program.
   *  \param symbol_name The full name of the variable (e.g.,
   *    `&quot;my_namespace::my_variable&quot;`). The name is normalized with
   *    detail::normalize_cuda_symbol_name; if the result is a key of
   *    lowered_name_map(), the corresponding lowered (mangled) name is looked
   *    up instead, otherwise the normalized name is used as-is.
   *  \param ptr A pointer to where the result should be stored.
   *  \param size (optional) A pointer to where the size of the variable (in
   *    bytes) should be stored, or nullptr if it is not needed.
   *  \return An empty string on success, otherwise an error message.
   */
  ErrorMsg get_global_ptr(std::string symbol_name, CUdeviceptr* ptr,
                          size_t* size = nullptr) const {
    symbol_name = detail::normalize_cuda_symbol_name(symbol_name);
    auto iter = lowered_name_map().find(symbol_name);
    if (iter != lowered_name_map().end()) {
      symbol_name = iter->second;  // Replace name with lowered name.
    }
    JITIFY_THROW_OR_RETURN_IF_CUDA_ERROR(
        cuModuleGetGlobal(ptr, size, module(), symbol_name.c_str()));
    return {};
  }

  /*! Read the data from a global variable in the loaded program.
   *  \param name The full name of the variable (e.g.,
   *    `&quot;my_namespace::my_variable&quot;`).
   *  \param data Pointer to where the resulting data should be written.
   *  \param count The number of elements to read. count * sizeof(T) must equal
   *    the size of the variable, otherwise an error is reported.
   *  \param stream (optional) The CUDA stream to use to transfer the data.
   *  \return An empty string on success, otherwise an error message.
   *  \note The copy is issued with cuMemcpyDtoHAsync on \p stream, so \p data
   *    may not be populated until the stream has been synchronized.
   */
  template <typename T>
  ErrorMsg get_global_data(std::string name, T* data, size_t count,
                           CUstream stream = 0) const {
    size_t size_bytes = count * sizeof(T);
    CUdeviceptr ptr;
    std::string error =
        get_global_ptr_with_size(std::move(name), size_bytes, &ptr);
    if (!error.empty()) return error;
    JITIFY_THROW_OR_RETURN_IF_CUDA_ERROR(
        cuMemcpyDtoHAsync(data, ptr, size_bytes, stream));
    return {};
  }

  /*! Write data to a global variable in the loaded program.
   *  \param name The full name of the variable (e.g.,
   *    `&quot;my_namespace::my_variable&quot;`).
   *  \param data Pointer to the data that should be written.
   *  \param count The number of elements to write. count * sizeof(T) must
   *    equal the size of the variable, otherwise an error is reported.
   *  \param stream (optional) The CUDA stream to use to transfer the data.
   *  \return An empty string on success, otherwise an error message.
   *  \note The copy is issued with cuMemcpyHtoDAsync on \p stream and may
   *    complete asynchronously.
   *  \warning Though this is a const method, it results in a change of state
   *    that may affect shared references to the program. Care should be taken
   *    when using this from multiple threads.
   */
  template <typename T>
  ErrorMsg set_global_data(std::string name, const T* data, size_t count,
                           CUstream stream = 0) const {
    size_t size_bytes = count * sizeof(T);
    CUdeviceptr ptr;
    std::string error =
        get_global_ptr_with_size(std::move(name), size_bytes, &ptr);
    if (!error.empty()) return error;
    JITIFY_THROW_OR_RETURN_IF_CUDA_ERROR(
        cuMemcpyHtoDAsync(ptr, data, size_bytes, stream));
    return {};
  }

  /*! Read the value of a global variable in the loaded program.
   *  \param name The full name of the variable (e.g.,
   *    `&quot;my_namespace::my_variable&quot;`).
   *  \param value Pointer to where the resulting value should be written.
   *    sizeof(T) must equal the size of the variable.
   *  \param stream (optional) The CUDA stream to use to transfer the data.
   *  \return An empty string on success, otherwise an error message.
   *  \note See get_global_data regarding asynchronous completion.
   */
  template <typename T>
  ErrorMsg get_global_value(std::string name, T* value,
                            CUstream stream = 0) const {
    return get_global_data(std::move(name), value, 1, stream);
  }

  /*! Write a value to a global variable in the loaded program.
   *  \param name The full name of the variable (e.g.,
   *    `&quot;my_namespace::my_variable&quot;`).
   *  \param value Reference to the value that should be written. sizeof(T)
   *    must equal the size of the variable.
   *  \param stream (optional) The CUDA stream to use to transfer the data.
   *  \return An empty string on success, otherwise an error message.
   *  \warning Though this is a const method, it results in a change of state
   *    that may affect shared references to the program. Care should be taken
   *    when using this from multiple threads.
   */
  template <typename T>
  ErrorMsg set_global_value(std::string name, const T& value,
                            CUstream stream = 0) const {
    return set_global_data(std::move(name), &value, 1, stream);
  }
};

class ConfiguredKernel;

// Replacement for dim3 to avoid needing to include the CUDA runtime headers.
struct Dim3 {
  unsigned int x, y, z;
  /*! Construct from explicit extents; omitted extents default to 1. */
  constexpr Dim3(unsigned int nx = 1, unsigned int ny = 1, unsigned int nz = 1)
      : x(nx), y(ny), z(nz) {}
  /*! Construct from any object exposing x, y and z members (e.g., dim3).
   *    Disabled for types convertible to unsigned int so that scalars select
   *    the constructor above.
   */
  template <typename Vec3,
            typename std::enable_if<
                !std::is_convertible<Vec3, unsigned int>::value, int>::type = 0>
  constexpr Dim3(const Vec3& vec) : x(vec.x), y(vec.y), z(vec.z) {}
};

/*! An object containing a loaded CUDA kernel and associated metadata.
 */
class KernelData {
  // We keep a program by value instead of reference to avoid the program object
  // needing to outlive the kernel object. The program uses a shared_ptr
  // internally, so this is cheap.
  LoadedProgramData program_;
  CUfunction function_ = nullptr;
  std::string lowered_name_;

 public:
  // A default-constructed KernelData has a null function().
  KernelData() = default;
  /*! Construct from a program and one of its kernels.
   *  \param program The program that contains the kernel.
   *  \param function The CUDA function handle of the kernel.
   *  \param lowered_name (optional) The lowered (mangled) name of the kernel.
   */
  KernelData(LoadedProgramData program, CUfunction function,
             std::string lowered_name = {})
      : program_(std::move(program)),
        function_(function),
        lowered_name_(std::move(lowered_name)) {}

  /*! Get the CUDA function object of the kernel. */
  CUfunction function() const { return function_; }
  /*! Get the lowered (mangled) name of the kernel. */
  const std::string& lowered_name() const { return lowered_name_; }
  /*! Get the program that contains the kernel. */
  const LoadedProgramData& program() const { return program_; }

  /*! Set an attribute of the kernel (via cuFuncSetAttribute).
   *  \param attribute The attribute identifier.
   *  \param value The value to set.
   *  \return An empty string on success, otherwise an error message.
   *  \warning Though this is a const method, it results in a change of state
   *    that may affect shared references to the kernel. Care should be taken
   *    when using this from multiple threads.
   */
  ErrorMsg set_attribute(CUfunction_attribute attribute, int value) const {
    JITIFY_THROW_OR_RETURN_IF_CUDA_ERROR(
        cuFuncSetAttribute(function_, attribute, value));
    return {};
  }

  /*! Get an attribute of the kernel (via cuFuncGetAttribute).
   *  \param attribute The attribute identifier.
   *  \param value Pointer to where the result value should be written.
   *  \return An empty string on success, otherwise an error message.
   */
  ErrorMsg get_attribute(CUfunction_attribute attribute, int* value) const {
    JITIFY_THROW_OR_RETURN_IF_CUDA_ERROR(
        cuFuncGetAttribute(value, attribute, function_));
    return {};
  }

  /*! Configure a kernel launch using the provided parameters.
   *  \param grid The grid dimensions for the kernel launch.
   *  \param block The block dimensions for the kernel launch.
   *  \param shared_memory_bytes (optional) The dynamic shared memory to
   *    allocate for the kernel launch (in bytes).
   *  \param stream (optional) The CUDA stream to use for the kernel launch.
   *  \return A ConfiguredKernel object that contains either a valid
   *    ConfiguredKernelData object or an error state.
   */
  ConfiguredKernel configure(Dim3 grid, Dim3 block,
                             unsigned int shared_memory_bytes = 0,
                             CUstream stream = 0) const;
  /*! Configure a kernel launch for maximum occupancy with 1-dimensional grid
   *    and block dimensions.
   *
   *  The block size is chosen by cuOccupancyMaxPotentialBlockSizeWithFlags,
   *  and the grid size is the minimum grid size that API reports as needed to
   *  reach maximum occupancy.
   *  \param max_block_size (optional) Upper limit on the chosen block size, or
   *    0 for no limit.
   *  \param shared_memory_bytes (optional) The dynamic shared memory to
   *    allocate for the kernel launch (in bytes).
   *  \param shared_memory_bytes_callback (optional) Callback function that
   *    returns the required shared memory size (in bytes) for a given block
   *    size. If provided, this overrides \p shared_memory_bytes.
   *  \param stream (optional) The CUDA stream to use for the kernel launch.
   *  \param flags (optional) Flags to pass to the underlying
   *    cuOccupancyMaxPotentialBlockSizeWithFlags API.
   *  \return A ConfiguredKernel object that contains either a valid
   *    ConfiguredKernelData object or an error state.
   */
  ConfiguredKernel configure_1d_max_occupancy(
      int max_block_size = 0, unsigned int shared_memory_bytes = 0,
      CUoccupancyB2DSize shared_memory_bytes_callback = nullptr,
      CUstream stream = 0, unsigned int flags = 0) const;

  // TODO: Add a similar method wrapping another occupancy API.
  // NOTE(review): this TODO originally named
  // cuOccupancyMaxPotentialBlockSizeWithFlags, which configure_1d_max_occupancy
  // already wraps; confirm which API was intended.
};

/*! A fallible wrapper that holds either a valid KernelData object or an error
 *    state.
 */
class Kernel : public detail::FallibleObjectBase<Kernel, KernelData> {
  using super_type = detail::FallibleObjectBase<Kernel, KernelData>;
  friend class detail::FallibleObjectBase<Kernel, KernelData>;
  using super_type::super_type;  // Inherit the base class constructors.

 public:
  /*! \see LoadedProgramData::get_kernel */
  static Kernel get_kernel(LoadedProgramData program, std::string name);
};

inline Kernel Kernel::get_kernel(LoadedProgramData program, std::string name) {
  // Normalize the name, then substitute its lowered (mangled) form if the
  // program recorded one; otherwise look the name up as given.
  name = detail::normalize_cuda_symbol_name(name);
  const auto& name_map = program.lowered_name_map();
  const auto found = name_map.find(name);
  if (found != name_map.end()) name = found->second;

  CUfunction func;
  const CUresult status =
      cuModuleGetFunction(&func, program.module(), name.c_str());
  if (status == CUDA_SUCCESS) {
    return Kernel(std::move(program), func, std::move(name));
  }
  return Error("get_kernel with name=\"" + name +
               "\" failed: " + detail::get_cuda_error_string(status));
}

inline Kernel LoadedProgramData::get_kernel(std::string name) const {
  // Forward to the static factory, which takes the program by value; copying
  // *this only copies the shared_ptr to the underlying data.
  return Kernel::get_kernel(LoadedProgramData(*this), std::move(name));
}

/*! An object containing a configured CUDA kernel and associated metadata.
 */
class ConfiguredKernelData {
  // We keep a kernel by value instead of reference to avoid the kernel object
  // needing to outlive the configured kernel object.
  KernelData kernel_;
  Dim3 grid_;
  Dim3 block_;
  unsigned int shared_memory_bytes_ = 0;
  CUstream stream_ = 0;

 public:
  ConfiguredKernelData() = default;
  /*! Construct a launch configuration for a kernel.
   *  \param kernel The kernel to launch.
   *  \param grid The grid dimensions.
   *  \param block The block dimensions.
   *  \param shared_memory_bytes (optional) Dynamic shared memory in bytes.
   *  \param stream (optional) The CUDA stream to launch on.
   */
  ConfiguredKernelData(KernelData kernel, Dim3 grid, Dim3 block,
                       unsigned int shared_memory_bytes = 0,
                       CUstream stream = 0)
      : kernel_(std::move(kernel)),
        grid_(std::move(grid)),
        block_(std::move(block)),
        shared_memory_bytes_(shared_memory_bytes),
        stream_(stream) {}

  /*! Get the underlying kernel object. */
  const KernelData& kernel() const { return kernel_; }
  /*! Get the configured grid dimensions. */
  const Dim3& grid() const { return grid_; }
  /*! Get the configured block dimensions. */
  const Dim3& block() const { return block_; }
  /*! Get the configured dynamic shared memory size in bytes. */
  unsigned int shared_memory_bytes() const { return shared_memory_bytes_; }
  /*! Get the configured CUDA stream. */
  CUstream stream() const { return stream_; }

  /*! Launch the configured kernel via cuLaunchKernel, using the stored grid,
   *    block, dynamic shared memory size and stream.
   *  \param arg_ptrs Array of pointers to kernel arguments (one per kernel
   *    parameter); launch() passes nullptr for a kernel without parameters.
   *  \return An empty string on success, otherwise an error message.
   *  \note cuLaunchKernel is asynchronous; errors that occur while the kernel
   *    executes are not reported by this call.
   */
  ErrorMsg launch(void** arg_ptrs) const {
    JITIFY_THROW_OR_RETURN_IF_CUDA_ERROR(cuLaunchKernel(
        kernel_.function(), grid_.x, grid_.y, grid_.z, block_.x, block_.y,
        block_.z, shared_memory_bytes_, stream_, arg_ptrs, nullptr));
    return {};
  }

  /*! Launch the configured kernel.
   *  \param arg_ptrs Vector of pointers to kernel arguments; its data() is
   *    passed to the array overload.
   *  \return An empty string on success, otherwise an error message.
   */
  ErrorMsg launch(std::vector<void*>& arg_ptrs) const {
    return launch(arg_ptrs.data());
  }

  /*! Launch the configured kernel with no arguments (a null argument array
   *    is passed to cuLaunchKernel).
   *  \return An empty string on success, otherwise an error message.
   */
  ErrorMsg launch() const {
    return this->launch(nullptr);
  }
};

/*! A fallible wrapper that holds either a valid ConfiguredKernelData object or
 *    an error state.
 */
class ConfiguredKernel
    : public detail::FallibleObjectBase<ConfiguredKernel,
                                        ConfiguredKernelData> {
  using super_type =
      detail::FallibleObjectBase<ConfiguredKernel, ConfiguredKernelData>;
  friend class detail::FallibleObjectBase<ConfiguredKernel,
                                          ConfiguredKernelData>;
  using super_type::super_type;  // Inherit the base class constructors.

 public:
  /*! \see KernelData::configure */
  static ConfiguredKernel configure(KernelData kernel, Dim3 grid, Dim3 block,
                                    unsigned int shared_memory_bytes,
                                    CUstream stream) {
    // Dim3 is a plain aggregate of three unsigned ints, so copying is fine.
    return ConfiguredKernel(std::move(kernel), grid, block,
                            shared_memory_bytes, stream);
  }

  /*! \see KernelData::configure_1d_max_occupancy */
  static ConfiguredKernel configure_1d_max_occupancy(
      KernelData kernel, int max_block_size = 0,
      unsigned int shared_memory_bytes = 0,
      CUoccupancyB2DSize shared_memory_bytes_callback = nullptr,
      CUstream stream = 0, unsigned int flags = 0);
};

inline ConfiguredKernel KernelData::configure(Dim3 grid, Dim3 block,
                                              unsigned int shared_memory_bytes,
                                              CUstream stream) const {
  // Delegate to the static factory with a copy of this kernel.
  return ConfiguredKernel::configure(KernelData(*this), grid, block,
                                     shared_memory_bytes, stream);
}

inline ConfiguredKernel KernelData::configure_1d_max_occupancy(
    int max_block_size, unsigned int shared_memory_bytes,
    CUoccupancyB2DSize shared_memory_bytes_callback, CUstream stream,
    unsigned int flags) const {
  // Delegate to the static factory with a copy of this kernel.
  return ConfiguredKernel::configure_1d_max_occupancy(
      KernelData(*this), max_block_size, shared_memory_bytes,
      shared_memory_bytes_callback, stream, flags);
}

inline ConfiguredKernel ConfiguredKernel::configure_1d_max_occupancy(
    KernelData kernel, int max_block_size, unsigned int shared_memory_bytes,
    CUoccupancyB2DSize shared_memory_bytes_callback, CUstream stream,
    unsigned int flags) {
  // Ask the driver for the occupancy-maximizing block size and the minimum
  // grid size needed to reach that occupancy.
  int min_grid_size = 0;
  int block_size = 0;
  const CUresult status = cuOccupancyMaxPotentialBlockSizeWithFlags(
      &min_grid_size, &block_size, kernel.function(),
      shared_memory_bytes_callback, shared_memory_bytes, max_block_size, flags);
  if (status != CUDA_SUCCESS) {
    return Error("Configure failed: " + detail::get_cuda_error_string(status));
  }
  // When a callback is supplied, it decides the dynamic shared memory size for
  // the chosen block size.
  const unsigned int smem_bytes =
      shared_memory_bytes_callback
          ? static_cast<unsigned int>(shared_memory_bytes_callback(block_size))
          : shared_memory_bytes;
  return ConfiguredKernel(std::move(kernel), min_grid_size, block_size,
                          smem_bytes, stream);
}

/*! A fallible wrapper that holds either a valid LoadedProgramData object or an
 *    error state.
 */
class LoadedProgram
    : public detail::FallibleObjectBase<LoadedProgram, LoadedProgramData> {
  using super_type =
      detail::FallibleObjectBase<LoadedProgram, LoadedProgramData>;
  friend class detail::FallibleObjectBase<LoadedProgram, LoadedProgramData>;
  using super_type::super_type;  // Inherit the base class constructors.

 public:
  /*! Load a CUBIN image as a module into the current CUDA context.
   *  \see LinkedProgramData::load */
  static LoadedProgram load(StringRef cubin, StringMap lowered_name_map);
};

inline LoadedProgram LoadedProgram::load(StringRef cubin,
                                         StringMap lowered_name_map) {
  // Load the image with cuModuleLoadData; on success the module is handed to
  // a UniqueCudaModule so it is unloaded with the program data.
  CUmodule cu_module;
  const CUresult status = cuModuleLoadData(&cu_module, cubin.data());
  if (status == CUDA_SUCCESS) {
    return LoadedProgram(UniqueCudaModule(cu_module),
                         std::move(lowered_name_map));
  }
  return Error("Loading failed: " + detail::get_cuda_error_string(status));
}

/*! An object containing a binary CUBIN string and associated metadata.
 *
 *  Only cubin_ and lowered_name_map_ are listed in
 *  JITIFY_DEFINE_SERIALIZABLE_MEMBERS, so (presumably) the linker log and
 *  linker options are not preserved by serialization.
 */
class LinkedProgramData
    : public serialization::Serializable<LinkedProgramData> {
  std::string cubin_;           // Linked CUBIN binary
  StringMap lowered_name_map_;  // Name expression -> lowered symbol name
  std::string log_;             // Linker log
  StringVec linker_options_;    // Linker options that were used

  JITIFY_DEFINE_SERIALIZABLE_MEMBERS(LinkedProgramData, cubin_,
                                     lowered_name_map_)

 public:
  LinkedProgramData() = default;
  /*! \param cubin The linked CUBIN binary.
   *  \param lowered_name_map (optional) Map of name expressions to lowered
   *    (mangled) symbol names.
   *  \param log (optional) The log returned from the linker.
   *  \param linker_options (optional) The options passed to the linker.
   */
  LinkedProgramData(std::string cubin, StringMap lowered_name_map = {},
                    std::string log = {}, StringVec linker_options = {})
      : cubin_(std::move(cubin)),
        lowered_name_map_(std::move(lowered_name_map)),
        log_(std::move(log)),
        linker_options_(std::move(linker_options)) {}

  /*! The binary CUBIN of the linked program. */
  const std::string& cubin() const { return cubin_; }
  /*! The map of name expressions to lowered (mangled) symbol names. */
  const StringMap& lowered_name_map() const { return lowered_name_map_; }
  /*! The log returned from the linker. */
  const std::string& log() const { return log_; }
  /*! The options that were passed to the linker. */
  const StringVec& linker_options() const { return linker_options_; }

  /*! Load the program as a module into the current CUDA context.
   *  \return A LoadedProgram object that contains either a valid
   *    LoadedProgramData object or an error state.
   */
  LoadedProgram load() const {
    // The name map is passed by value, so the loaded program gets its own copy.
    return LoadedProgram::load(cubin(), lowered_name_map());
  }
};

class CompiledProgramData;
class CompiledProgram;

/*! A fallible wrapper that holds either a valid LinkedProgramData object or an
 *    error state.
 */
class LinkedProgram
    : public detail::FallibleObjectBase<LinkedProgram, LinkedProgramData> {
  using super_type =
      detail::FallibleObjectBase<LinkedProgram, LinkedProgramData>;
  friend class detail::FallibleObjectBase<LinkedProgram, LinkedProgramData>;
  using super_type::super_type;  // Inherit the base class constructors.

 public:
  /*! Link a single program of the given JIT input type.
   *  \see CompiledProgramData::link */
  static LinkedProgram link(const std::string& program,
                            CUjitInputType program_type,
                            StringMap lowered_name_map = {},
                            StringVec options = {});
  /*! Link multiple programs given as an array of pointers.
   *  \note Remaining linker options in each program must match.
   *  \see CompiledProgramData::link */
  static LinkedProgram link(size_t num_programs,
                            const CompiledProgramData* compiled_programs[],
                            StringVec options = {});
  /*! Link multiple programs given as a vector of pointers.
   *  \see CompiledProgramData::link */
  static LinkedProgram link(
      const std::vector<const CompiledProgram*>& compiled_programs,
      StringVec options = {});

 private:
  // Links num_programs raw programs of the given types (presumably the shared
  // implementation behind the public link overloads).
  static LinkedProgram link_impl(size_t num_programs,
                                 const std::string* programs[],
                                 const CUjitInputType program_types[],
                                 StringMap lowered_name_map, StringVec options);
};

namespace detail {

using OptionsMap = std::unordered_map<std::string, StringVec>;

// Parses a vector of option strings into a map of key -> values (one value for
// each time the key is repeated in the input). Accepted forms:
//   "-key=val"        (split at the first '=')
//   "-key" "val"      (two consecutive entries; the second must not start
//                      with '-')
//   "-Kval"           (single uppercase-letter key, or the "-l<lib>" flag)
//   "-key"            (flag; recorded with an empty value)
// Whitespace surrounding keys and values is stripped. Returns false if an
// entry that should be an option does not start with '-', otherwise true.
inline bool parse_options(const StringVec& options, OptionsMap* options_map) {
  for (size_t i = 0; i < options.size(); ++i) {
    std::string option = options[i];
    trim(&option);                       // Strip whitespace
    if (option[0] != '-') return false;  // Expected an option, got a value.
    // Whether the following entry is a value rather than another option. It is
    // trimmed first so that the decision matches how that entry would be
    // parsed on its own (e.g., " -G" is an option, not a value).
    bool next_is_value = false;
    if (i + 1 < options.size()) {
      std::string next = options[i + 1];
      trim(&next);
      next_is_value = next[0] != '-';
    }
    std::string key, val;
    size_t eql = option.find('=');
    if (eql != std::string::npos) {
      // Parse "-key=val". Checked before the two-entry form so that a bare
      // value following "-key=val" is not mistaken for this option's value.
      key = option.substr(0, eql);
      val = option.substr(eql + 1);
    } else if (next_is_value) {
      // Parse "-key" "val".
      key = option;
      val = options[++i];
    } else if (option.size() > 2 &&
               // HACK: Special case for '-l<lib>' linker flag.
               // Cast to unsigned char: std::isupper on a negative char is UB.
               (std::isupper(static_cast<unsigned char>(option[1])) ||
                option[1] == 'l')) {
      // Parse "-Kval".
      key = option.substr(0, 2);
      val = option.substr(2);
    } else {
      // Parse "-key" (no value).
      key = option;
    }
    trim(&val);  // Strip whitespace
    (*options_map)[key].push_back(std::move(val));
  }
  return true;
}

inline std::string path_base(const std::string& p) {
  // Returns everything before the last path separator, or "" if there is none:
  //   "/usr/local/myfile.dat" -> "/usr/local"
  //   "foo/bar"  -> "foo"
  //   "foo/bar/" -> "foo/bar"
#if defined _WIN32 || defined _WIN64
  // Windows accepts both backslash and forward slash as separators.
  const size_t last_sep = p.find_last_of("\\/");
#else
  const size_t last_sep = p.find_last_of('/');
#endif
  if (last_sep == std::string::npos) return std::string();
  return p.substr(0, last_sep);
}

inline std::string path_join(StringRef p1, StringRef p2) {
#if defined _WIN32 || defined _WIN64
  // Windows accepts both backslash and forward slash as separators.
  const char* sep = "\\/";
#else
  const char* sep = "/";
#endif
  // Note: std::strchr also matches the terminating NUL, so '\0' counts as a
  // separator here.
  const auto is_sep = [sep](char c) { return std::strchr(sep, c) != nullptr; };
  const size_t n1 = p1.size();
  const size_t n2 = p2.size();
  // Joining a non-empty prefix onto an absolute p2 is an error, signalled by
  // returning an empty string.
  if (n1 != 0 && n2 != 0 && is_sep(p2[0])) return {};
  std::string joined;
  joined.reserve(n1 + 1 + n2);
  joined += p1;
  // Insert a separator only if p1 is non-empty and doesn't already end in one.
  const bool needs_sep = n1 != 0 && !is_sep(p1[n1 - 1]);
  if (needs_sep) joined += sep[0];
  joined += p2;
  return joined;
}

// Returns the path of the running executable, or nullptr if it cannot be
// determined. The path is computed once (function-local static) and cached.
inline const char* get_current_executable_path() {
  static const char* path = []() -> const char* {
    static char buffer[JITIFY_PATH_MAX] = {};
#ifdef __linux__
    if (!::realpath("/proc/self/exe", buffer)) return nullptr;
#elif defined(_WIN32) || defined(_WIN64)
    // GetModuleFileNameA returns 0 on failure, and returns the buffer size when
    // the path was truncated; treat both as failure rather than returning a
    // truncated path.
    const DWORD len = GetModuleFileNameA(nullptr, buffer, JITIFY_PATH_MAX);
    if (len == 0 || len >= static_cast<DWORD>(JITIFY_PATH_MAX)) return nullptr;
#endif
    // NOTE(review): on platforms other than Linux and Windows the buffer is
    // left zeroed, so "" (not nullptr) is returned.
    return buffer;
  }();
  return path;
}

inline bool startswith(StringRef str, StringRef prefix) {
  // True if the first prefix.size() characters of str equal prefix.
  if (prefix.size() > str.size()) return false;
  return std::equal(prefix.begin(), prefix.end(), str.begin());
}

inline bool endswith(StringRef str, StringRef suffix) {
  // True if the last suffix.size() characters of str equal suffix.
  if (suffix.size() > str.size()) return false;
  const auto tail = str.end() - suffix.size();
  return std::equal(suffix.begin(), suffix.end(), tail);
}

// Infers the JIT input type from the filename suffix. If no known suffix is
// present, the filename is assumed to refer to a library, and the platform's
// library suffix (and, on non-Windows, the "lib" prefix) is added to
// *filename unless it already ends with that suffix.
inline CUjitInputType get_cuda_jit_input_type(std::string* filename) {
#if defined _WIN32 || defined _WIN64
  const char* const object_suffix = ".obj";
#else  // Linux
  const char* const object_suffix = ".o";
#endif
  if (endswith(*filename, ".ptx")) return CU_JIT_INPUT_PTX;
  if (endswith(*filename, ".cubin")) return CU_JIT_INPUT_CUBIN;
  if (endswith(*filename, ".fatbin")) return CU_JIT_INPUT_FATBINARY;
  if (endswith(*filename, object_suffix)) return CU_JIT_INPUT_OBJECT;
  // Anything else is assumed to be a library.
#if defined _WIN32 || defined _WIN64
  if (!endswith(*filename, ".lib")) *filename += ".lib";
#else  // Linux
  if (!endswith(*filename, ".a")) *filename = "lib" + *filename + ".a";
#endif
  return CU_JIT_INPUT_LIBRARY;
}

// Links the given programs (and any libraries requested via "-l") into a
// single CUBIN using the CUDA driver's cuLink* API.
//   num_programs   Number of entries in programs and program_types (must be
//                  non-zero).
//   programs       The program data, in the formats given by program_types.
//   program_types  The CUjitInputType of each program.
//   options        Linker options (e.g., "-arch=sm_70", "-lfoo", "-L/path").
//   error          If non-null, set to an error message on failure.
//   log            If non-null, the linker info and error logs are appended.
//   linked_cubin   If non-null, set to the linked CUBIN on success.
// Returns true on success, false on failure.
// A single CUBIN input with no "-l" option is returned as-is without invoking
// the linker.
// Note that this appends to *log if it is provided.
inline bool link_programs(size_t num_programs, const std::string* programs[],
                          const CUjitInputType program_types[],
                          const StringVec& options, std::string* error,
                          std::string* log, std::string* linked_cubin) {
#define JITIFY_CHECK_CULINK(call)                                 \
  do {                                                            \
    CUresult jitify_cuda_ret = call;                              \
    if (jitify_cuda_ret != CUDA_SUCCESS) {                        \
      if (error) *error = get_cuda_error_string(jitify_cuda_ret); \
      set_log();                                                  \
      return false;                                               \
    }                                                             \
  } while (0)

  if (num_programs == 0) {
    if (error) *error = "Require at least one program to link";
    return false;
  }
  std::vector<CUjit_option> option_keys;
  std::vector<void*> option_vals;
  OptionsMap options_map;
  if (!parse_options(options, &options_map)) {
    if (error) *error = "Syntax error in linker options";
    return false;
  }
  if (num_programs == 1 && program_types[0] == CU_JIT_INPUT_CUBIN &&
      !options_map.count("-l")) {
    // No linking required, just return the given CUBIN.
    if (linked_cubin) *linked_cubin = *programs[0];
    return true;
  }
#if CUDA_VERSION >= 11040
  // Enable link-time optimization if any input is NVVM IR.
  for (size_t i = 0; i < num_programs; ++i) {
    if (program_types[i] == CU_JIT_INPUT_NVVM) {
      option_keys.push_back(CU_JIT_LTO);
      option_vals.push_back((void*)1);
      break;
    }
  }
#endif
  StringVec link_files, link_paths;
  for (const auto& key_val : options_map) {
    const std::string& key = key_val.first;
    const StringVec& vals = key_val.second;
    std::string val = !vals.empty() ? vals.back() : "";
    // Note: ptxas actually uses "-g" (lowercase), but we use "-G" to be
    // consistent with NVRTC and NVCC.
    if (key == "-G" || key == "--device-debug") {
      option_keys.push_back(CU_JIT_GENERATE_DEBUG_INFO);
      option_vals.push_back((void*)(intptr_t)1);
      // HACK: Can't allow -lineinfo due to ambiguity with "-l<lib>".
    } else if (/*key == "-lineinfo" ||*/ key == "--generate-line-info") {
      option_keys.push_back(CU_JIT_GENERATE_LINE_INFO);
      option_vals.push_back((void*)(intptr_t)1);
    } else if (key == "-arch" || key == "--gpu-name") {
      if (val.substr(0, 3) != "sm_") {
        if (error) *error = "-arch/--gpu-name value must start with \"sm_\"";
        return false;
      }
      int arch = std::atoi(val.substr(3).c_str());
      option_keys.push_back(CU_JIT_TARGET);
      option_vals.push_back((void*)(intptr_t)arch);
    } else if (key == "-maxrregcount" || key == "--maxrregcount") {
      int max_regs = std::atoi(val.c_str());
      option_keys.push_back(CU_JIT_MAX_REGISTERS);
      option_vals.push_back((void*)(intptr_t)max_regs);
    } else if (key == "-O" || key == "--opt-level") {
      option_keys.push_back(CU_JIT_OPTIMIZATION_LEVEL);
      int opt_level = std::atoi(val.c_str());
      option_vals.push_back((void*)(intptr_t)opt_level);
    } else if (key == "-v" || key == "--verbose") {
      option_keys.push_back(CU_JIT_LOG_VERBOSE);
      option_vals.push_back((void*)(intptr_t)1);
    } else if (key == "-l") {
      link_files = vals;
    } else if (key == "-L") {
      link_paths = vals;
#if CUDA_VERSION >= 11040
      // LTO optimization options.
    } else if (key == "-ftz" || key == "--ftz") {
      option_keys.push_back(CU_JIT_FTZ);
      option_vals.push_back((void*)(intptr_t)1);
    } else if (key == "-prec-div" || key == "--prec-div") {
      option_keys.push_back(CU_JIT_PREC_DIV);
      option_vals.push_back((void*)(intptr_t)1);
    } else if (key == "-prec-sqrt" || key == "--prec-sqrt") {
      option_keys.push_back(CU_JIT_PREC_SQRT);
      option_vals.push_back((void*)(intptr_t)1);
    } else if (key == "-fmad" || key == "--fmad") {
      option_keys.push_back(CU_JIT_FMA);
      option_vals.push_back((void*)(intptr_t)1);
    } else if (key == "-use_fast_math" || key == "--use_fast_math") {
      option_keys.push_back(CU_JIT_FTZ);
      option_vals.push_back((void*)(intptr_t)1);
      option_keys.push_back(CU_JIT_PREC_DIV);
      option_vals.push_back((void*)(intptr_t)1);
      option_keys.push_back(CU_JIT_PREC_SQRT);
      option_vals.push_back((void*)(intptr_t)1);
      option_keys.push_back(CU_JIT_FMA);
      option_vals.push_back((void*)(intptr_t)1);
#endif
    } else {
      if (error) *error = "Unknown option: " + key;
      return false;
    }
  }
  constexpr const long kLogSize = 8192;
  // Zero-initialized so that both logs are valid (empty) C strings even if the
  // driver never writes to them (e.g., when cuLinkCreate itself fails).
  char info_log[kLogSize] = {};
  char error_log[kLogSize] = {};
  if (log) {
    option_keys.push_back(CU_JIT_INFO_LOG_BUFFER);
    option_vals.push_back((void*)info_log);
    option_keys.push_back(CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES);
    option_vals.push_back((void*)(intptr_t)kLogSize);
    option_keys.push_back(CU_JIT_ERROR_LOG_BUFFER);
    option_vals.push_back((void*)error_log);
    option_keys.push_back(CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES);
    option_vals.push_back((void*)(intptr_t)kLogSize);
  }
  // Appends the info log, a newline, and the error log to *log (if provided).
  auto set_log = [&]() {
    if (log) {
      // Bounded length: never read past the end of a buffer, even if it was
      // filled without a terminating NUL.
      size_t info_log_size =
          std::find(info_log, info_log + kLogSize, '\0') - info_log;
      size_t error_log_size =
          std::find(error_log, error_log + kLogSize, '\0') - error_log;
      log->reserve(log->size() + info_log_size + 1 + error_log_size);
      log->append(info_log, info_log + info_log_size);
      *log += '\n';
      log->append(error_log, error_log + error_log_size);
    }
  };

  CUlinkState culink_state;
  JITIFY_CHECK_CULINK(cuLinkCreate((unsigned)option_keys.size(),
                                   option_keys.data(), option_vals.data(),
                                   &culink_state));
  // Ensures cuLinkDestroy is called on every exit path after creation.
  struct ScopedCULinkStateDestroyer {
    CUlinkState& culink_state_;
    ~ScopedCULinkStateDestroyer() { cuLinkDestroy(culink_state_); }
  } culink_state_scope_guard{culink_state};

  for (size_t i = 0; i < num_programs; ++i) {
    JITIFY_CHECK_CULINK(cuLinkAddData(
        culink_state, program_types[i], (void*)programs[i]->data(),
        programs[i]->size(), "jitified_source", 0, 0, 0));
  }

  for (std::string link_file : link_files) {
    CUjitInputType jit_input_type;
    if (link_file == ".") {
      // Special case for linking to current executable.
      link_file = get_current_executable_path();
      jit_input_type = CU_JIT_INPUT_OBJECT;
    } else {
      // Infer based on filename.
      jit_input_type = get_cuda_jit_input_type(&link_file);
    }
    CUresult result =
        cuLinkAddFile(culink_state, jit_input_type, link_file.c_str(), 0, 0, 0);
    // If not found as given, try each "-L" search path in order.
    int path_num = 0;
    while (result == CUDA_ERROR_FILE_NOT_FOUND &&
           path_num < (int)link_paths.size()) {
      std::string filename = path_join(link_paths[path_num++], link_file);
      result = cuLinkAddFile(culink_state, jit_input_type, filename.c_str(), 0,
                             0, 0);
    }
    if (log) {
      if (result == CUDA_ERROR_FILE_NOT_FOUND) {
        log->append("Linker error: Device library not found: ");
        log->append(link_file);
        *log += '\n';
      } else if (result != CUDA_SUCCESS) {
        log->append("Linker error: Failed to add file: ");
        log->append(link_file);
        *log += '\n';
      }
    }
    JITIFY_CHECK_CULINK(result);
  }

  size_t cubin_size;
  void* cubin_ptr;
  JITIFY_CHECK_CULINK(cuLinkComplete(culink_state, &cubin_ptr, &cubin_size));
  set_log();
  // Copy out before the scope guard destroys the link state that owns the
  // CUBIN memory.
  if (linked_cubin) {
    linked_cubin->assign((char*)cubin_ptr, (char*)cubin_ptr + cubin_size);
  }
#undef JITIFY_CHECK_CULINK
  return true;
}

}  // namespace detail

/*! An object containing a PTX (and maybe CUBIN and/or NVVM) source string and
 *  associated metadata.
 *  \note Only the members listed in JITIFY_DEFINE_SERIALIZABLE_MEMBERS below
 *    (ptx, cubin, nvvm, lowered_name_map, remaining_linker_options) take part
 *    in serialization; the compilation log and compiler options are not
 *    listed there.
 */
class CompiledProgramData
    : public serialization::Serializable<CompiledProgramData> {
  std::string ptx_;
  std::string cubin_;  // Only available with NVRTC version >= 11.2
  std::string nvvm_;   // Only available with NVRTC version >= 11.4
  // Maps name expressions to lowered symbol names (aka. unmangled to mangled).
  StringMap lowered_name_map_;
  StringVec remaining_linker_options_;  // Passed on to LinkedProgram::link.
  std::string log_;                     // Compilation log
  StringVec compiler_options_;          // Compiler options that were used.

  JITIFY_DEFINE_SERIALIZABLE_MEMBERS(CompiledProgramData, ptx_, cubin_, nvvm_,
                                     lowered_name_map_,
                                     remaining_linker_options_)

 public:
  CompiledProgramData() = default;
  CompiledProgramData(std::string ptx, std::string cubin = {},
                      std::string nvvm = {}, StringMap lowered_name_map = {},
                      StringVec linker_options = {}, std::string log = {},
                      StringVec compiler_options = {})
      : ptx_(std::move(ptx)),
        cubin_(std::move(cubin)),
        nvvm_(std::move(nvvm)),
        lowered_name_map_(std::move(lowered_name_map)),
        remaining_linker_options_(std::move(linker_options)),
        log_(std::move(log)),
        compiler_options_(std::move(compiler_options)) {}

  /*! Get the PTX source of the compiled program. */
  const std::string& ptx() const { return ptx_; }
  /*! Get the CUBIN binary of the compiled program.
   * \note The CUBIN is only available here with NVRTC version >= 11.2; older
   * versions will return an empty string. The linked CUBIN is always available
   * from LinkedProgramData::cubin.
   */
  const std::string& cubin() const { return cubin_; }
  /*! Get the NVVM IR of the compiled program.
   * \note The NVVM is only available here with NVRTC version >= 11.4 and the
   * "-dlto" compiler option.
   */
  const std::string& nvvm() const { return nvvm_; }
  /*! Get the map of name expressions to lowered (mangled) symbol names. */
  const StringMap& lowered_name_map() const { return lowered_name_map_; }
  /*! Get the remaining options that will be passed on to the linker (by
   *  link()).
   */
  const StringVec& remaining_linker_options() const {
    return remaining_linker_options_;
  }
  /*! Get the log returned from the compiler. */
  const std::string& log() const { return log_; }
  /*! Get the options that were passed to the compiler. */
  const StringVec& compiler_options() const { return compiler_options_; }

  /*! Link the program into a binary CUBIN object.
   *  \param extra_linker_options List of additional linker options, appended
   *    after remaining_linker_options().
   *  \return A LinkedProgram object that contains either a valid
   *    LinkedProgramData object or an error state.
   */
  LinkedProgram link(StringVec extra_linker_options = {}) const {
    const CompiledProgramData* compiled_programs[] = {this};
    return LinkedProgram::link(1, compiled_programs,
                               std::move(extra_linker_options));
  }
};

/*! The result of compiling a program: either a valid CompiledProgramData
 *  object or an error state.
 */
class CompiledProgram
    : public detail::FallibleObjectBase<CompiledProgram, CompiledProgramData> {
  friend class detail::FallibleObjectBase<CompiledProgram, CompiledProgramData>;
  using super_type =
      detail::FallibleObjectBase<CompiledProgram, CompiledProgramData>;
  using super_type::super_type;

 public:
  // Returns either a valid program or an error state.
  /*! \see PreprocessedProgramData::compile */
  static CompiledProgram compile(const std::string& name,
                                 const std::string& source,
                                 const StringMap& header_sources = {},
                                 const StringVec& name_expressions = {},
                                 StringVec compiler_options = {},
                                 StringVec linker_options = {});

  /*! Convenience overload that takes a single name expression.
   *  \note header_sources and name_expression deliberately have no default
   *    values: with defaults, calls such as compile(name, source) would be
   *    ambiguous with the overload above. An empty name_expression means "no
   *    name expressions".
   *  \see PreprocessedProgramData::compile
   */
  static CompiledProgram compile(const std::string& name,
                                 const std::string& source,
                                 const StringMap& header_sources,
                                 const std::string& name_expression,
                                 StringVec compiler_options = {},
                                 StringVec linker_options = {}) {
    StringVec name_expressions;
    if (!name_expression.empty()) name_expressions.push_back(name_expression);
    return compile(name, source, header_sources, name_expressions,
                   std::move(compiler_options), std::move(linker_options));
  }
};

// Links one or more compiled programs into a single LinkedProgram.
// All programs must have identical remaining_linker_options(); those options
// are prepended to `options`. The lowered name maps of all programs are
// merged. Each program contributes its NVVM IR if present, otherwise its CUBIN
// if present, otherwise its PTX.
// Returns an Error if num_programs is 0, if the programs' linker options
// differ, if NVVM IR is given but CUDA (toolkit or driver) is older than 11.4,
// or if linking fails.
inline LinkedProgram LinkedProgram::link(
    size_t num_programs, const CompiledProgramData* compiled_programs[],
    StringVec options) {
  if (num_programs == 0) return Error("Must have at least one program to link");
  const StringVec& prog_linker_options =
      compiled_programs[0]->remaining_linker_options();
  StringMap lowered_name_map = compiled_programs[0]->lowered_name_map();
  size_t total_lowered_names = lowered_name_map.size();
  for (size_t i = 1; i < num_programs; ++i) {
    if (compiled_programs[i]->remaining_linker_options() !=
        prog_linker_options) {
      return Error("Program linker options must match");
    }
    total_lowered_names += compiled_programs[i]->lowered_name_map().size();
  }
  options.insert(options.begin(), prog_linker_options.begin(),
                 prog_linker_options.end());
  lowered_name_map.reserve(total_lowered_names);
  for (size_t i = 1; i < num_programs; ++i) {
    lowered_name_map.insert(compiled_programs[i]->lowered_name_map().begin(),
                            compiled_programs[i]->lowered_name_map().end());
  }
  // The effective CUDA version is the older of the toolkit this header was
  // compiled against and the installed driver. The driver is queried once
  // (outside the loop); if the query fails, 0 is used (i.e., NVVM linking is
  // treated as unsupported) instead of reading an uninitialized value.
  int cuda_driver_version = 0;
  if (cuDriverGetVersion(&cuda_driver_version) != CUDA_SUCCESS) {
    cuda_driver_version = 0;
  }
  const int effective_cuda_version =
      std::min(CUDA_VERSION, cuda_driver_version);
  std::vector<const std::string*> programs;
  std::vector<CUjitInputType> program_types;
  programs.reserve(num_programs);
  program_types.reserve(num_programs);
  for (size_t i = 0; i < num_programs; ++i) {
    const CompiledProgramData& compiled_program = *compiled_programs[i];
    if (effective_cuda_version < 11040 && !compiled_program.nvvm().empty()) {
      return Error("Linking NVVM IR is not supported with CUDA < 11.4");
    }
    // Prefer NVVM IR (enables LTO), then CUBIN, then PTX.
    const std::string& program = !compiled_program.nvvm().empty()
                                     ? compiled_program.nvvm()
                                     : !compiled_program.cubin().empty()
                                           ? compiled_program.cubin()
                                           : compiled_program.ptx();
    CUjitInputType program_type =
#if CUDA_VERSION >= 11040
        !compiled_program.nvvm().empty() ? CU_JIT_INPUT_NVVM :
#endif
                                         !compiled_program.cubin().empty()
                                             ? CU_JIT_INPUT_CUBIN
                                             : CU_JIT_INPUT_PTX;
    programs.emplace_back(&program);
    program_types.emplace_back(program_type);
  }
  return link_impl(num_programs, programs.data(), program_types.data(),
                   std::move(lowered_name_map), std::move(options));
}

// Links a list of CompiledProgram objects. If any of them is in an error
// state, that error is returned; otherwise their underlying
// CompiledProgramData objects are forwarded to the pointer-array overload.
inline LinkedProgram LinkedProgram::link(
    const std::vector<const CompiledProgram*>& compiled_programs,
    StringVec options) {
  const size_t count = compiled_programs.size();
  std::vector<const CompiledProgramData*> data_ptrs;
  data_ptrs.reserve(count);
  for (size_t idx = 0; idx < count; ++idx) {
    const CompiledProgram& program = *compiled_programs[idx];
    if (!program) {
      // Propagate the first compilation error unchanged.
      return Error(program.error());
    }
    data_ptrs.push_back(&*program);
  }
  return link(count, data_ptrs.data(), std::move(options));
}

// Links a single raw program of the given input type, using the provided
// lowered name map and linker options.
inline LinkedProgram LinkedProgram::link(const std::string& program,
                                         CUjitInputType program_type,
                                         StringMap lowered_name_map,
                                         StringVec options) {
  // link_impl takes parallel arrays; wrap the single program in them.
  const std::string* program_ptrs[1] = {&program};
  const CUjitInputType program_types[1] = {program_type};
  return link_impl(1, program_ptrs, program_types, std::move(lowered_name_map),
                   std::move(options));
}

// Runs the linker on the given programs. The resulting log starts with the
// linker options that were used, followed by whatever link_programs appends.
// Returns an Error containing the linker error and log on failure.
inline LinkedProgram LinkedProgram::link_impl(
    size_t num_programs, const std::string* programs[],
    const CUjitInputType program_types[], StringMap lowered_name_map,
    StringVec options) {
  std::string log =
      detail::string_join(options, " ", "Linker options: \"", "\"\n");
  std::string error;
  std::string linked_cubin;
  const bool linked_ok =
      detail::link_programs(num_programs, programs, program_types, options,
                            &error, &log, &linked_cubin);
  if (!linked_ok) return Error("Linking failed: " + error + '\n' + log);
  return LinkedProgram(std::move(linked_cubin), std::move(lowered_name_map),
                       std::move(log), std::move(options));
}

namespace detail {

// The plain function type "ResultType(Args...)", spelled with a trailing
// return type; pointers to it are written function_type<R, A...>*.
template <typename ResultType, typename... Args>
using function_type = auto(Args...) -> ResultType;

#if !JITIFY_LINK_NVRTC_STATIC
// Owns a handle to a dynamically loaded shared library (LoadLibraryA /
// FreeLibrary on Windows, dlopen / dlclose elsewhere). The handle is released
// automatically on destruction, on close(), or when open() is called again.
class DynamicLibrary {
  using handle_type =
#if defined(_WIN32) || defined(_WIN64)
      HMODULE;
#else
      void*;
#endif

 private:
  // Releases a library handle; null handles are ignored.
  struct Deleter {
    void operator()(handle_type handle) const {
      if (handle) {
#if defined(_WIN32) || defined(_WIN64)
        ::FreeLibrary(handle);
#else
        ::dlclose(handle);
#endif
      }
    }
  };

  std::unique_ptr<std::remove_pointer<handle_type>::type, Deleter> lib_;
  std::string error_;

 public:
  DynamicLibrary() = default;
  DynamicLibrary(const char* name) { open(name); }

  // Loads the named library, releasing any previously held handle. Returns
  // false and stores a message retrievable via error() on failure.
  bool open(const char* name) {
    error_.clear();
#if defined(_WIN32) || defined(_WIN64)
    lib_.reset(::LoadLibraryA(name));
    if (!lib_) {
      DWORD error_code = ::GetLastError();
      LPSTR buffer = nullptr;
      size_t size = ::FormatMessageA(
          FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
              FORMAT_MESSAGE_IGNORE_INSERTS,
          NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
          (LPSTR)&buffer, 0, NULL);
      if (buffer && size) {
        error_.assign(buffer, size);
      } else {
        // FormatMessageA failed and allocated nothing; report the raw code.
        error_ = "LoadLibraryA failed with error code " +
                 std::to_string(static_cast<unsigned long>(error_code));
      }
      if (buffer) ::LocalFree(buffer);
      return false;
    }
#else
    ::dlerror();  // Clear any existing error
    lib_.reset(::dlopen(name, RTLD_LAZY));
    if (!lib_) {
      // dlerror() may return null; assigning null to std::string is UB.
      const char* message = ::dlerror();
      error_ = message ? message : "dlopen failed";
      return false;
    }
#endif
    return true;
  }

  // Releases the library handle (if any).
  void close() { lib_.reset(); }

  explicit operator bool() const { return static_cast<bool>(lib_); }
  // Message describing the most recent open() failure (empty after success).
  const std::string& error() const { return error_; }

  // Looks up the named symbol and returns it as a pointer to a function of
  // type ResultType(Args...), or nullptr if it is not found or no library is
  // loaded. The caller is responsible for the signature being correct.
  template <typename ResultType, typename... Args>
  function_type<ResultType, Args...>* function(const char* func_name) const {
    // Without this check a null handle would be passed to dlsym; on glibc a
    // null handle is RTLD_DEFAULT, which searches the global symbol scope.
    if (!lib_) return nullptr;
    auto* func =
#if defined(_WIN32) || defined(_WIN64)
        ::GetProcAddress(lib_.get(), func_name);
#else
        ::dlsym(lib_.get(), func_name);
#endif
    return reinterpret_cast<function_type<ResultType, Args...>*>(func);
  }
};
#endif  // !JITIFY_LINK_NVRTC_STATIC

}  // namespace detail

class LibNvrtc
#if !JITIFY_LINK_NVRTC_STATIC
    : public detail::DynamicLibrary
#endif
{
 public:
  LibNvrtc() {
#if !JITIFY_LINK_NVRTC_STATIC
    int compiled_major = CUDA_VERSION / 1000;
    std::string major_str = std::to_string(compiled_major);
    // Try to load the major-versioned-only file.
    std::string libname =
#if defined(_WIN32) || defined(_WIN64)
        "nvrtc64_" + major_str + ".dll";
#else
        "libnvrtc.so." + major_str;
#endif
    if (!this->open(libname.c_str())) {
      // Fall back to a brute-force search over minor versions.
      for (int minor = 9; minor >= 0; --minor) {
#if defined(_WIN32) || defined(_WIN64)
        // TODO: Why does the filename have _0 on the end (not in docs)?
        libname = "nvrtc64_" + major_str + std::to_string(minor) + "_0.dll";
#else
        libname = "libnvrtc.so." + major_str + "." + std::to_string(minor);
#endif
        if (this->open(libname.c_str())) break;
      }
    }
#endif  // !JITIFY_LINK_NVRTC_STATIC
  }

#if JITIFY_LINK_NVRTC_STATIC
  operator bool() { return true; }
  const std::string& error() const {
    static std::string err;
    return err;
  }
#define JITIFY_DEFINE_NVRTC_WRAPPER(name, result_type, ...)       \
  detail::function_type<result_type, __VA_ARGS__>* name() const { \
    return &nvrtc##name;                                          \
  }
#else  // dynamic linking
#define JITIFY_DEFINE_NVRTC_WRAPPER(name, result_type, ...)       \
  detail::function_type<result_type, __VA_ARGS__>* name() const { \
    static const auto func =                                      \
        this->function<result_type, __VA_ARGS__>("nvrtc" #name);  \
    return func;                                                  \
  }
#endif
  JITIFY_DEFINE_NVRTC_WRAPPER(AddNameExpression, nvrtcResult, nvrtcProgram,
                              const char* const)
  JITIFY_DEFINE_NVRTC_WRAPPER(CompileProgram, nvrtcResult, nvrtcProgram, int,
                              const char* const*)
  JITIFY_DEFINE_NVRTC_WRAPPER(CreateProgram, nvrtcResult, nvrtcProgram*,
                              const char*, const char*, int, const char* const*,
                              const char* const*)
  JITIFY_DEFINE_NVRTC_WRAPPER(DestroyProgram, nvrtcResult, nvrtcProgram*)
  JITIFY_DEFINE_NVRTC_WRAPPER(GetLoweredName, nvrtcResult, nvrtcProgram,
                              const char* const, const char**)
#if JITIFY_LINK_NVRTC_STATIC && CUDA_VERSION < 11010
  detail::function_type<nvrtcResult, nvrtcProgram, char*>* GetCUBIN() {
    return nullptr;
  }
  detail::function_type<nvrtcResult, nvrtcProgram, size_t*>* GetCUBINSize() {
    return nullptr;
  }
#else
  JITIFY_DEFINE_NVRTC_WRAPPER(GetCUBIN, nvrtcResult, nvrtcProgram, char*)
  JITIFY_DEFINE_NVRTC_WRAPPER(GetCUBINSize, nvrtcResult, nvrtcProgram, size_t*)
#endif
#if JITIFY_LINK_NVRTC_STATIC && CUDA_VERSION < 11020
  detail::function_type<nvrtcResult, nvrtcProgram, int*>*
  GetNumSupportedArchs() {
    return nullptr;
  }
  detail::function_type<nvrtcResult, nvrtcProgram, int*>* GetSupportedArchs() {
    return nullptr;
  }
#else
  JITIFY_DEFINE_NVRTC_WRAPPER(GetNumSupportedArchs, nvrtcResult, int*)
  JITIFY_DEFINE_NVRTC_WRAPPER(GetSupportedArchs, nvrtcResult, int*)
#endif
#if JITIFY_LINK_NVRTC_STATIC && CUDA_VERSION < 11040
  detail::function_type<nvrtcResult, nvrtcProgram, char*>* GetNVVM() {
    return nullptr;
  }
  detail::function_type<nvrtcResult, nvrtcProgram, size_t*>* GetNVVMSize() {
    return nullptr;
  }
#else
  JITIFY_DEFINE_NVRTC_WRAPPER(GetNVVM, nvrtcResult, nvrtcProgram, char*)
  JITIFY_DEFINE_NVRTC_WRAPPER(GetNVVMSize, nvrtcResult, nvrtcProgram, size_t*)
#endif
  JITIFY_DEFINE_NVRTC_WRAPPER(GetErrorString, const char*, nvrtcResult)
  JITIFY_DEFINE_NVRTC_WRAPPER(GetPTX, nvrtcResult, nvrtcProgram, char*)
  JITIFY_DEFINE_NVRTC_WRAPPER(GetPTXSize, nvrtcResult, nvrtcProgram, size_t*)
  JITIFY_DEFINE_NVRTC_WRAPPER(GetProgramLog, nvrtcResult, nvrtcProgram, char*)
  JITIFY_DEFINE_NVRTC_WRAPPER(GetProgramLogSize, nvrtcResult, nvrtcProgram,
                              size_t*)
  JITIFY_DEFINE_NVRTC_WRAPPER(Version, nvrtcResult, int*, int*)
#undef JITIFY_DEFINE_NVRTC_WRAPPER

  // Returns the runtime NVRTC version the same format as CUDA_VERSION.
  int get_version() const {
    static const int version = [this] {
      int major, minor;
      Version()(&major, &minor);
      return major * 1000 + minor * 10;
    }();
    return version;
  }
};

// Returns the process-wide LibNvrtc instance. It is constructed on first use,
// which (in dynamic builds) is when the NVRTC library is loaded.
inline LibNvrtc& nvrtc() {
  static LibNvrtc instance;
  return instance;
}

namespace detail {

// Parses the (first) architecture flag in the given vector of options. The
// vector itself is not modified; use *beg_idx/*end_idx to erase the entries.
// Recognized keys are "-arch", "--gpu-architecture" and "--gpu-name" (the
// ptxas spelling). The value is taken from after an '=' in the same entry, or
// otherwise from the following entry.
// Returns 0 on failure or if no architecture option is found.
// Sets *error on failure (if provided).
// Returns -1 if the arch value is the special string "compute_." or "sm_.".
// Otherwise returns the integer arch value (e.g., 70 for "sm_70").
// On success, sets *is_virtual (which must be non-null) to true if a
// "compute_" value was found, or false for an "sm_" value, and *beg_idx and
// *end_idx are set to the half-open index range [*beg_idx, *end_idx) of the
// arch entries in the options vector (e.g., so that they can be erased by the
// caller). If no arch flag is found, *beg_idx and *end_idx are both set to 0.
inline int parse_arch_flag(const StringVec& options, bool* is_virtual,
                           std::string* error = nullptr,
                           size_t* beg_idx = nullptr,
                           size_t* end_idx = nullptr) {
  for (int i = 0; i < (int)options.size(); ++i) {
    StringRef option = ltrim(options[i]);
    size_t key_end = 0;
    if (startswith(option, "-arch")) {
      key_end = std::strlen("-arch");
    } else if (startswith(option, "--gpu-architecture")) {
      key_end = std::strlen("--gpu-architecture");
    } else if (startswith(option, "--gpu-name")) {  // ptxas flag name
      key_end = std::strlen("--gpu-name");
    } else {
      continue;
    }
    size_t eql_beg = option.find("=", key_end);
    StringSlice value;
    bool found_eql = eql_beg != std::string::npos;
    if (found_eql) {
      value = trim(option.substr(eql_beg + 1));
    } else {
      // The value is expected in the next entry (e.g., {"-arch", "sm_70"}).
      if (i + 1 == (int)options.size()) {
        if (error) *error = "Expected value after option.";
        return 0;
      }
      value = trim(options[i + 1]);
    }
    if (startswith(value, "compute_")) {
      *is_virtual = true;
      value = value.substr(std::strlen("compute_"));
    } else if (startswith(value, "sm_")) {
      *is_virtual = false;
      value = value.substr(std::strlen("sm_"));
    } else {
      if (error) *error = "Expected value to begin with 'compute_' or 'sm_'.";
      return 0;
    }
    int result;
    if (value == ".") {
      result = -1;
    } else {
      int cc = std::atoi(std::string(value).c_str());
      if (cc == 0) {
        if (error) *error = "Failed to parse a valid architecture number.";
        return 0;
      }
      result = cc;
    }
    // Store the index range of the arch entries in options (one entry for
    // "key=value", two for "key value").
    if (beg_idx) *beg_idx = i;
    if (end_idx) *end_idx = i + (1 + !found_eql);
    return result;
  }
  if (beg_idx) *beg_idx = 0;
  if (end_idx) *end_idx = 0;
  return 0;  // No architecture option found
}

// Returns 0 on failure and sets *error if provided. Otherwise returns a compute
// capability such as 61 for sm_61.
inline int get_current_device_compute_capability(std::string* error = nullptr) {
  CUdevice device;
  int cc_major, cc_minor;
  CUresult ret;
  if ((ret = cuCtxGetDevice(&device)) != CUDA_SUCCESS ||
      (ret = cuDeviceGetAttribute(
           &cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)) !=
          CUDA_SUCCESS ||
      (ret = cuDeviceGetAttribute(
           &cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)) !=
          CUDA_SUCCESS) {
    if (error) *error = get_cuda_error_string(ret);
    return 0;
  }
  int cc = cc_major * 10 + cc_minor;
  return cc;
}

// Returns 0 on failure and sets *error if provided. Otherwise returns a compute
// capability that is supported by the current version of NVRTC: cc itself for
// the hard-coded Tegra chips below; otherwise cc clamped to the newest
// architecture reported by NVRTC, or, if NVRTC cannot report it, to the newest
// architecture known for the CUDA version this header was compiled against.
inline int limit_to_supported_compute_capability(int cc,
                                                 std::string* error = nullptr) {
  // Note: We limit virtual architectures to the max supported by the current
  // version of NVRTC to avoid errors when using older versions of NVRTC with
  // newer hardware+driver. Forward compatibility of PTX allows this to work.
  // Tegra chips do not have forwards compatibility so we need to special case
  // them.
  // TODO: It would be better to detect these somehow, rather than hard-coding.
  bool is_tegra = (cc == 32 ||  // Logan
                   cc == 53 ||  // Erista
                   cc == 62 ||  // Parker
                   cc == 72);   // Xavier
  if (is_tegra) return cc;

  // Newest architecture supported by the loaded NVRTC, or 0 if unknown.
  int max_supported_arch = 0;
  if (nvrtc() && nvrtc().GetNumSupportedArchs() &&
      nvrtc().GetSupportedArchs()) {
    // Queried once and cached; 0 if the query fails or reports no archs.
    static const int nvrtc_max_supported_arch = [] {
      int num_supported_archs = 0;
      nvrtcResult nvrtc_ret =
          nvrtc().GetNumSupportedArchs()(&num_supported_archs);
      if (nvrtc_ret != NVRTC_SUCCESS || num_supported_archs <= 0) return 0;
      std::vector<int> supported_archs(num_supported_archs);
      nvrtc_ret = nvrtc().GetSupportedArchs()(supported_archs.data());
      if (nvrtc_ret != NVRTC_SUCCESS) return 0;
      return *std::max_element(supported_archs.begin(), supported_archs.end());
    }();
    max_supported_arch = nvrtc_max_supported_arch;
  }
  if (max_supported_arch <= 0) {
    // NVRTC could not tell us; use the compile-time CUDA version instead.
    // Ensure that future CUDA versions just work (even if suboptimal).
    const int cuda_major = std::min(11, CUDA_VERSION / 1000);
    // clang-format off
    switch (cuda_major) {
      case 11: max_supported_arch = 80; break; // Ampere
      case 10: max_supported_arch = 75; break; // Turing
      case  9: max_supported_arch = 70; break; // Volta
      case  8: max_supported_arch = 61; break; // Pascal
      case  7: max_supported_arch = 52; break; // Maxwell
      default:
        if (error) *error = "Unsupported CUDA version";
        return 0;
    }
    // clang-format on
  }
  return std::min(cc, max_supported_arch);
}

// Parses compiler_options and applies automatic architecture detection if
// necessary, filling in the architecture flag in both compiler_options and
// linker_options.
// The real (sm_) architecture is taken, in order of preference, from an
// existing arch flag in *linker_options, an explicit "sm_N" flag in
// *compiler_options, or the device of the current CUDA context.
// The first arch flag found in *compiler_options is erased and a single
// "-arch=compute_N" or "-arch=sm_N" entry is appended in its place (at the
// end). "-arch=sm_<real>" is appended to *linker_options unless it already
// contained an arch flag.
// Returns false on failure and sets *error_ptr if provided.
inline bool process_architecture_flags(StringVec* compiler_options,
                                       StringVec* linker_options,
                                       std::string* error_ptr = nullptr) {
  std::string error;
  // Returns false (after forwarding the message to *error_ptr, if provided)
  // when a helper has reported an error through `error`.
  auto check_error = [&]() {
    if (!error.empty()) {
      if (error_ptr) *error_ptr = error;
      return false;
    }
    return true;
  };
  // Shared by both parse_arch_flag calls below. It is only read when the
  // corresponding returned cc is non-zero, i.e. after that call has set it.
  bool is_virtual;
  // First identify any existing real arch in linker_options (e.g., from a
  // previous call to this function).
  int linker_cc = parse_arch_flag(*linker_options, &is_virtual, &error);
  if (!check_error()) return false;
  if (linker_cc < 0) {
    // We do not allow "-arch=sm_." to be given as a linker option.
    if (error_ptr) {
      *error_ptr = "Linker architecture must be explicit if provided.";
    }
    return false;
  }
  if (linker_cc > 0 && is_virtual) {
    if (error_ptr) {
      *error_ptr = "Linker architecture flag must be sm_ not compute_.";
    }
    return false;
  }
  // Now parse compiler options.
  size_t beg_idx, end_idx;
  int given_cc = parse_arch_flag(*compiler_options, &is_virtual, &error,
                                 &beg_idx, &end_idx);
  if (!check_error()) return false;
  // Remove the parsed arch flag entries; they are replaced below.
  // (If no flag was found, beg_idx == end_idx == 0 and nothing is erased.)
  compiler_options->erase(compiler_options->begin() + beg_idx,
                          compiler_options->begin() + end_idx);
  // Select the real architecture: linker flag, then explicit compiler sm_N,
  // then the current device.
  int real_cc;
  if (linker_cc != 0) {
    real_cc = linker_cc;
  } else if (given_cc > 0 && !is_virtual) {
    real_cc = given_cc;
  } else {
    real_cc = get_current_device_compute_capability(&error);
    if (!check_error()) return false;
  }
  // Select the virtual architecture for the compiler; 0 means "pass the real
  // architecture (sm_<real_cc>) to NVRTC instead".
  int virt_cc;
  if (!given_cc) {
    // No arch flag was given. Detect the real arch and use a supported
    // virtual arch for the compiler.
    virt_cc = limit_to_supported_compute_capability(real_cc, &error);
    if (!check_error()) return false;
  } else if (is_virtual) {
    // A virtual arch flag was given. Detect the real arch and convert it to a
    // supported virtual arch for the compiler if one was not specified.
    if (given_cc != -1) {
      virt_cc = given_cc;
    } else {
      virt_cc = limit_to_supported_compute_capability(real_cc, &error);
      if (!check_error()) return false;
    }
  } else {
    // A real arch flag was given. Detect the real arch if it was not specified,
    // and use either the real or a supported virtual arch for the compiler
    // depending on the NVRTC version.
    if (!nvrtc()) {
      if (error_ptr) *error_ptr = nvrtc().error();
      return false;
    }
    int supported_real_cc =
        limit_to_supported_compute_capability(real_cc, &error);
    if (!check_error()) return false;
    if (!nvrtc().GetCUBIN() || supported_real_cc != real_cc) {
      // This NVRTC version does not support compiling to a/the real arch.
      virt_cc = supported_real_cc;
    } else {
      // Pass the real arch to NVRTC.
      virt_cc = 0;
    }
  }
  // Add the computed arch flag back to the compiler options and to the linker
  // options.
  if (virt_cc) {
    compiler_options->push_back("-arch=compute_" + std::to_string(virt_cc));
  } else {
    compiler_options->push_back("-arch=sm_" + std::to_string(real_cc));
  }
  if (linker_cc == 0) {
    linker_options->push_back("-arch=sm_" + std::to_string(real_cc));
  }
  return true;
}

// Appends "-std=<value>" to *options unless a language-standard flag ("-std..."
// or "--std...") is already present. A flag is only recognized where it starts
// an option (at the beginning of the string or after whitespace), so that an
// unrelated option merely containing "-std" (e.g., the include path
// "-I/opt/my-std/include") does not suppress the default.
inline void add_std_flag_if_not_specified(StringVec* options,
                                          std::string value = "c++11") {
  for (const std::string& option : *options) {
    for (size_t pos = option.find("-std"); pos != std::string::npos;
         pos = option.find("-std", pos + 1)) {
      // Step back over the extra dash of the "--std" spelling.
      size_t flag_beg = (pos > 0 && option[pos - 1] == '-') ? pos - 1 : pos;
      if (flag_beg == 0 || option[flag_beg - 1] == ' ' ||
          option[flag_beg - 1] == '\t') {
        // A standard was explicitly specified, don't change anything.
        return;
      }
    }
  }
  // Jitify must be compiled with C++11 support, so we default to enabling it
  // for the JIT-compiled code too.
  options->push_back("-std=" + value);
}

// Appends "-default-device" to *options unless some option already contains
// "--device-as-default-execution-space" or "-default-device".
inline void add_default_device_flag_if_not_specified(StringVec* options) {
  static const char* const kEquivalentFlags[] = {
      "--device-as-default-execution-space", "-default-device"};
  for (const std::string& option : *options) {
    for (const char* flag : kEquivalentFlags) {
      if (option.find(flag) != std::string::npos) return;  // Already given.
    }
  }
  options->push_back("-default-device");
}

// Removes every entry of *options that is exactly short_flag or long_flag,
// preserving the order of the remaining entries. Returns true if anything was
// removed.
inline bool pop_flag(StringVec* options, const std::string& short_flag,
                     const std::string& long_flag) {
  const size_t size_before = options->size();
  options->erase(std::remove_if(options->begin(), options->end(),
                                [&](const std::string& entry) {
                                  return entry == long_flag ||
                                         entry == short_flag;
                                }),
                 options->end());
  return options->size() != size_before;
}

// Demangles nested variable names using the PTX name mangling scheme
// (which mostly follows the Itanium64 ABI). E.g., _ZN1a3Foo2bcE -> a::Foo::bc.
// Returns mangled_name unchanged if it does not start with "_Z" (i.e., it is
// not mangled), and an empty string if it is malformed, truncated, or uses
// mangling that is not supported here.
inline std::string demangle_ptx_variable_name(const char* mangled_name) {
#if CUDA_VERSION >= 11040 && JITIFY_USE_LIBCUFILT
  size_t bufsize = 0;
  char* buf = nullptr;
  int status;
  auto demangled_ptr = std::unique_ptr<char, void (*)(void*)>(
      __cu_demangle(mangled_name, buf, &bufsize, &status), std::free);
  // clang-format off
  switch (status) {
  case 0: return demangled_ptr.get();  // Demangled successfully
  case -2: return mangled_name;        // Interpret as plain unmangled name
  case -1: // fall-through             // Memory allocation failure
  case -3: // fall-through             // Invalid argument
  default: return "";
  }
  // clang-format on
#else
  std::stringstream ss;
  const char* c = mangled_name;
  if (*c++ != '_' || *c++ != 'Z') return mangled_name;  // Non-mangled name
  if (*c++ != 'N') return "";  // Not a nested name, unsupported
  // No identifier can be longer than the whole name; used to reject absurd
  // length prefixes before they can overflow.
  const size_t max_len = std::strlen(mangled_name);
  while (true) {
    // Parse identifier length. (Explicit range check instead of std::isdigit,
    // which is undefined for negative char values.)
    size_t n = 0;
    while (*c >= '0' && *c <= '9') {
      n = n * 10 + static_cast<size_t>(*c - '0');
      if (n > max_len) return "";  // Longer than the input; malformed
      c++;
    }
    if (!n) return "";  // Invalid or unsupported mangled name
    // Parse identifier.
    const char* c0 = c;
    while (n-- && *c) c++;
    if (!*c) return "";  // Mangled name is truncated
    std::string id(c0, c);
    if (id.substr(0, 7) == "_GLOBAL") {
      // Identifiers starting with "_GLOBAL" are anonymous namespaces.
      // Note: c++filt gives "(anonymous namespace)" instead of "<unnamed>", but
      // we use the latter to match cu++filt.
      ss << "<unnamed>";
    } else if (id.substr(0, 10) == "_INTERNAL_") {
      // Identifiers starting with "_INTERNAL" represent internal linkage and
      // are replaced with the program name (which is embedded in them).
      // (These appear as of CUDA >=11.3).
      char* program_name;
      long program_name_len = std::strtol(id.c_str() + 10, &program_name, 10);
      // Note: Program name is never empty, so a non-positive length (or no
      // room for the separator) means the identifier is malformed.
      if (program_name_len <= 0 || *program_name == '\0') return "";
      ++program_name;  // Skip a '_' that follows the length
      // Make sure the program name lies entirely within the identifier.
      const size_t consumed = static_cast<size_t>(program_name - id.c_str());
      if (static_cast<size_t>(program_name_len) > id.size() - consumed) {
        return "";
      }
      ss.write(program_name, program_name_len);
    } else {
      ss << id;
    }
    // Nested name specifiers end with 'E'.
    if (*c == 'E') break;
    // There are more identifiers to come, add join token.
    ss << "::";
  }
  return ss.str();
#endif
}

// Finds global __constant__ and __device__ variable declarations in ptx,
// demangles their lowered names, and adds them to *lowered_name_map (existing
// keys are not overwritten). Declarations are located by searching for
// ".const .align" and ".global .align"; the symbol is taken from the text
// before the terminating ';' or initializer '='. If that terminator is missing
// (e.g., truncated PTX), the search stops.
// Note that this does not support template variables (they will be ignored).
inline void find_lowered_global_variables(StringRef ptx,
                                          StringMap* lowered_name_map) {
  size_t pos = 0;
  while (pos < ptx.size()) {
    pos = std::min(ptx.find(".const .align", pos),
                   ptx.find(".global .align", pos));
    if (pos == std::string::npos) break;
    size_t end = ptx.find_first_of(";=", pos);
    if (end == std::string::npos) break;  // Unterminated declaration
    std::string decl(ptx.substr(pos, end - pos));
    pos = end + 1;
    // Drop any whitespace between the declarator and the ';' or '='.
    size_t last = decl.find_last_not_of(" \t");
    decl.resize(last == std::string::npos ? 0 : last + 1);
    // The symbol ends at its first '[' (array declarations) or at the end of
    // the declarator, and begins just after the preceding whitespace.
    size_t symbol_end = std::min(decl.find('['), decl.size());
    size_t ws = decl.find_last_of(" \t", symbol_end);
    size_t symbol_start = (ws == std::string::npos) ? 0 : ws + 1;
    std::string entry = decl.substr(symbol_start, symbol_end - symbol_start);
    if (entry.empty()) continue;
    std::string key = demangle_ptx_variable_name(entry.c_str());
    // Skip unsupported mangled names. E.g., a static variable defined inside
    // a function (such variables are not directly addressable from outside
    // the function, so skipping them is the correct behavior).
    if (key.empty()) continue;
    lowered_name_map->emplace(key, entry);
  }
}

// Returns false on error.
// Sets *error on failure if provided.
// Sets *log if provided.
// Sets *ptx on success if provided.
// Adds one entry to *lowered_name_map for each entry in name_expressions.
inline bool compile_program(
    const std::string& name, const std::string& source,
    const StringMap& header_sources, const StringVec& options,
    std::string* error = nullptr, std::string* log = nullptr,
    std::string* ptx = nullptr, std::string* cubin = nullptr,
    std::string* nvvm = nullptr, const StringVec& name_expressions = {},
    StringMap* lowered_name_map = nullptr) {
  if (!nvrtc()) {
    if (error) *error = nvrtc().error();
    return false;
  }

  std::vector<const char*> header_names_c;
  std::vector<const char*> header_sources_c;
  size_t num_headers = header_sources.size();
  header_names_c.reserve(num_headers);
  header_sources_c.reserve(num_headers);
  for (const auto& name_source : header_sources) {
    header_names_c.push_back(name_source.first.c_str());
    header_sources_c.push_back(name_source.second.c_str());
  }

  std::vector<const char*> options_c;
  options_c.reserve(options.size());
  for (const std::string& option : options) {
    if (nvrtc().get_version() < 11010) {
      // This NVRTC doesn't support specifying c++03 explicitly, so remove it.
      // TODO: Should also support "-std" and "c++03" as separate entries.
      if (option == "-std=c++03" || option == "--std=c++03") continue;
    }
    options_c.push_back(option.c_str());
  }

#define JITIFY_CHECK_NVRTC(call)                                      \
  do {                                                                \
    nvrtcResult jitify_nvrtc_ret = call;                              \
    if (jitify_nvrtc_ret != NVRTC_SUCCESS) {                          \
      if (error) *error = nvrtc().GetErrorString()(jitify_nvrtc_ret); \
      return false;                                                   \
    }                                                                 \
  } while (0)

  nvrtcProgram nvrtc_program;
  JITIFY_CHECK_NVRTC(nvrtc().CreateProgram()(
      &nvrtc_program, source.c_str(), name.c_str(), (int)num_headers,
      header_sources_c.data(), header_names_c.data()));
  struct ScopedNvrtcProgramDestroyer {
    nvrtcProgram& nvrtc_program_;
    ~ScopedNvrtcProgramDestroyer() {
      nvrtc().DestroyProgram()(&nvrtc_program_);
    }
  } nvrtc_program_scope_guard{nvrtc_program};

  for (const auto& name_expression : name_expressions) {
    JITIFY_CHECK_NVRTC(
        nvrtc().AddNameExpression()(nvrtc_program, name_expression.c_str()));
  }

  nvrtcResult ret = nvrtc().CompileProgram()(
      nvrtc_program, (int)options_c.size(), options_c.data());
  if (log) {
    size_t log_size;
    JITIFY_CHECK_NVRTC(nvrtc().GetProgramLogSize()(nvrtc_program, &log_size));
    // Note: log_size includes NULL terminator, and std::string is guaranteed to
    // include its own.
    log->resize(log_size - 1);
    JITIFY_CHECK_NVRTC(nvrtc().GetProgramLog()(nvrtc_program, &(*log)[0]));
  }
  JITIFY_CHECK_NVRTC(ret);
  if (ptx) {
    size_t ptx_size;
    JITIFY_CHECK_NVRTC(nvrtc().GetPTXSize()(nvrtc_program, &ptx_size));
    if (ptx_size == 1) ptx_size = 0;  // WAR for issue in CUDA 11.4 NVRTC -dlto
    if (ptx_size) {
      // Note: ptx_size includes NULL terminator, and std::string is guaranteed
      // to include its own.
      ptx->resize(ptx_size - 1);
      JITIFY_CHECK_NVRTC(nvrtc().GetPTX()(nvrtc_program, &(*ptx)[0]));
    }
  }

  // Note that direct-to-CUBIN compilation is only supported with NVRTC >= 11.2.
  if (cubin && nvrtc().GetCUBIN()) {
    size_t cubin_size;
    JITIFY_CHECK_NVRTC(nvrtc().GetCUBINSize()(nvrtc_program, &cubin_size));
    if (cubin_size) {
      cubin->resize(cubin_size, 'x');
      JITIFY_CHECK_NVRTC(nvrtc().GetCUBIN()(nvrtc_program, &(*cubin)[0]));
    }
  }

  // Note that NVVM compilation is only supported with NVRTC >= 11.4.
  if (nvvm && nvrtc().GetNVVM()) {
    size_t nvvm_size;
    JITIFY_CHECK_NVRTC(nvrtc().GetNVVMSize()(nvrtc_program, &nvvm_size));
    if (nvvm_size) {
      nvvm->resize(nvvm_size, 'x');
      JITIFY_CHECK_NVRTC(nvrtc().GetNVVM()(nvrtc_program, &(*nvvm)[0]));
    }
  }

  for (const auto& name_expression : name_expressions) {
    const char* lowered_name_c;
    JITIFY_CHECK_NVRTC(nvrtc().GetLoweredName()(
        nvrtc_program, name_expression.c_str(), &lowered_name_c));
    lowered_name_map->emplace(name_expression, lowered_name_c);
  }

  if (ptx && lowered_name_map) {
    // Automatically add global variables to lowered_name_map. This avoids
    // needing to specify them explicitly in name_expressions. Note that this
    // does not support template variables.
    find_lowered_global_variables(*ptx, lowered_name_map);
  }

#undef JITIFY_CHECK_NVRTC
  return true;
}

// Splits str on any of the characters in delims. Runs of delimiters, and
// leading/trailing delimiters, never produce empty tokens.
// If maxsplit > 0, at most maxsplit tokens are split off and the rest of the
// string (leading delimiters removed, trailing ones kept) is appended as a
// final element if non-empty. maxsplit == 0 returns {str} unchanged. A
// negative maxsplit means no limit.
// Implemented with std::string searches (rather than the POSIX-only strtok_r
// on a copied buffer) so it is portable and treats every character of str,
// including embedded NULs, as part of the input.
inline StringVec split_string(std::string str, long maxsplit = -1,
                              std::string delims = " \t") {
  StringVec results;
  if (maxsplit == 0) {
    results.push_back(str);
    return results;
  }
  size_t pos = 0;
  for (long i = 0; i != maxsplit; ++i) {
    size_t start = str.find_first_not_of(delims, pos);
    if (start == std::string::npos) return results;  // No more tokens
    size_t end = str.find_first_of(delims, start);
    if (end == std::string::npos) end = str.size();
    results.push_back(str.substr(start, end - start));
    pos = end;
  }
  // Check if there's a final piece (the unsplit remainder).
  size_t rest = str.find_first_not_of(delims, pos);
  if (rest != std::string::npos) {
    results.push_back(str.substr(rest));
  }
  return results;
}

// Extracts the symbol name from a PTX .global/.const declaration line: the
// text between the last blank (space or tab) preceding the first '[' or ';'
// and that '[' / ';'. Returns false, leaving *name unchanged, if the line has
// no '[' or ';' or no blank before it.
inline bool ptx_parse_decl_name(const std::string& line, std::string* name) {
  const size_t stop = line.find_first_of("[;");
  if (stop == std::string::npos) {
    return false;  // Malformed declaration: expected a semicolon.
  }
  const size_t blank = line.find_last_of(" \t", stop);
  if (blank == std::string::npos) {
    return false;  // Malformed declaration: expected whitespace.
  }
  const size_t first = blank + 1;
  name->assign(line, first, stop - first);
  return true;
}

// Removes .global/.const variable declarations that are not referenced
// elsewhere in the PTX. Only lines whose first term starts with ".global" or
// ".const" are candidates, so declarations with leading qualifiers such as
// ".visible" are always kept. A name counts as referenced if it appears as an
// instruction operand or in the initializer of a global declaration
// (e.g., "= generic(foo);"), so globals used only via another global's
// initializer are not removed.
// Returns false (leaving *ptx unmodified) if a candidate declaration cannot be
// parsed.
inline bool ptx_remove_unused_globals(std::string* ptx) {
  // Note: The characters '.' and '%' are not treated as delimiters.
  const char* const token_delims = " \t()[]{},;+-*/~&|^?:=!<>\"'\\";
  // Whether a token looks like a symbol name (rather than a number, register,
  // or dotted directive/type).
  auto is_name_like = [](const std::string& token) {
    const unsigned char first = static_cast<unsigned char>(token[0]);
    return (std::isalpha(first) || first == '_' || first == '$') &&
           token.find('.') == std::string::npos;
  };
  std::istringstream iss(*ptx);
  StringVec lines;
  std::unordered_map<size_t, std::string> line_num_to_global_name;
  std::unordered_set<std::string> name_set;
  for (std::string line; std::getline(iss, line);) {
    size_t line_num = lines.size();
    lines.push_back(line);
    auto terms = split_string(line);
    if (terms.size() <= 1) continue;  // Ignore lines with no arguments
    if (terms[0].substr(0, 2) == "//") continue;  // Ignore comment lines
    if (terms[0].substr(0, 7) == ".global" ||
        terms[0].substr(0, 6) == ".const") {
      std::string decl_name;
      if (!ptx_parse_decl_name(line, &decl_name)) return false;
      line_num_to_global_name.emplace(line_num, std::move(decl_name));
      // Names in an initializer refer to other globals; record them as uses.
      size_t eq_pos = line.find('=');
      if (eq_pos != std::string::npos) {
        for (const std::string& token :
             split_string(line.substr(eq_pos + 1), -1, token_delims)) {
          if (is_name_like(token)) name_set.insert(token);
        }
      }
      continue;
    }
    if (terms[0][0] == '.') continue;  // Ignore .version, .reg, .param etc.
    // Note: The first term will always be an instruction name; starting at 1
    // also allows unchecked inspection of the previous term.
    for (size_t i = 1; i < terms.size(); ++i) {
      if (terms[i].substr(0, 2) == "//") break;  // Ignore comments
      for (const std::string& token :
           split_string(terms[i], -1, token_delims)) {
        if (  // Ignore non-names
            !is_name_like(token) ||
            // Ignore variable/parameter declarations
            terms[i - 1][0] == '.' ||
            // Ignore branch instructions
            (token == "bra" && terms[i - 1][0] == '@') ||
            // Ignore branch labels
            (token.substr(0, 2) == "BB" &&
             terms[i - 1].substr(0, 3) == "bra")) {
          continue;
        }
        name_set.insert(token);
      }
    }
  }
  std::ostringstream oss;
  for (size_t line_num = 0; line_num < lines.size(); ++line_num) {
    auto it = line_num_to_global_name.find(line_num);
    if (it != line_num_to_global_name.end() && !name_set.count(it->second)) {
      continue;  // Remove unused .global/.const declaration.
    }
    oss << lines[line_num] << '\n';
  }
  *ptx = oss.str();
  return true;
}

// Returns false if there is a syntax error in the options.
inline bool copy_compiler_option_for_driver_ptxas(
    const StringVec& compiler_options, StringVec* linker_options,
    bool has_value, StringRef short_key, StringRef long_key,
    StringRef output_key = {}) {
  // First check if the option is already specified in linker_options.
  for (const std::string& raw_option : *linker_options) {
    StringRef option = ltrim(raw_option);
    if (startswith(option, short_key) || startswith(option, long_key)) {
      return true;  // Already specified, don't do anything
    }
  }
  // Now find the option in compiler_options.
  StringSlice key, val;
  for (size_t i = 0; i < compiler_options.size(); ++i) {
    StringRef option = ltrim(compiler_options[i]);
    if (startswith(option, short_key) || startswith(option, long_key)) {
      if (!has_value) {
        key = option;
        break;
      }
      size_t eql_pos = option.find('=', short_key.size());
      if (eql_pos != std::string::npos) {
        key = option.substr(0, eql_pos);
        val = option.substr(eql_pos + 1);
      } else {
        key = option;
        if (i + 1 == compiler_options.size()) return false;  // Syntax error
        val = compiler_options[i + 1];
      }
      break;
    }
  }
  if (key.empty()) return true;  // Nothing to copy
  if (!output_key.empty()) {
    key = output_key;
  }
  if (has_value) {
    linker_options->push_back(string_concat(key, "=", val));
  } else {
    linker_options->push_back(static_cast<std::string>(key));
  }
  return true;
}

}  // namespace detail

// Compiles a program: normalizes the option lists (architecture flags, default
// -std and -default-device flags, the jitify-only -remove-unused-globals flag),
// invokes NVRTC, optionally prunes unused globals from the PTX, and forwards
// selected compiler options to the linker options.
inline CompiledProgram CompiledProgram::compile(
    const std::string& name, const std::string& source,
    const StringMap& header_sources, const StringVec& name_expressions,
    StringVec compiler_options, StringVec linker_options) {
  std::string error;
  const bool arch_flags_ok = detail::process_architecture_flags(
      &compiler_options, &linker_options, &error);
  if (!arch_flags_ok) {
    return Error("Failed to process architecture flags: " + error);
  }
  detail::add_std_flag_if_not_specified(&compiler_options, "c++11");
  detail::add_default_device_flag_if_not_specified(&compiler_options);
  // This flag is consumed here and never reaches NVRTC.
  const bool prune_globals = detail::pop_flag(
      &compiler_options, "-remove-unused-globals", "--remove-unused-globals");

  std::string log, ptx, cubin, nvvm;
  StringMap lowered_name_map;
  const bool compiled_ok = detail::compile_program(
      name, source, header_sources, compiler_options, &error, &log, &ptx,
      &cubin, &nvvm, name_expressions, &lowered_name_map);
  if (!compiled_ok) {
    std::vector<std::string> sorted_header_names;
    sorted_header_names.reserve(header_sources.size());
    for (const auto& header : header_sources) {
      sorted_header_names.push_back(header.first);
    }
    std::sort(sorted_header_names.begin(), sorted_header_names.end());
    const std::string options_desc = detail::string_join(
        compiler_options, " ", "Compiler options: \"", "\"\n");
    const std::string headers_desc = detail::string_join(
        sorted_header_names, "\n  ", "Header names:\n  ", "\n");
    return Error("Compilation failed: " + error + "\n" + options_desc +
                 headers_desc + "\n" + log);
  }
  if (prune_globals && !ptx.empty()) {
    detail::ptx_remove_unused_globals(&ptx);  // Ignores errors from this
  }

  // We copy certain compiler options to linker_options so that they are used if
  // the linker does ptx->cubin compilation prior to linking. This allows users
  // to specify these options in compiler_options without having to worry about
  // whether they also need to be passed in linker_options.
  detail::copy_compiler_option_for_driver_ptxas(
      compiler_options, &linker_options, /*has_value = */ false, "-G",
      "--device-debug");
  // Note that the linker doesn't support "-lineinfo", so the long form is used.
  detail::copy_compiler_option_for_driver_ptxas(
      compiler_options, &linker_options, /*has_value = */ false, "-lineinfo",
      "--generate-line-info", "--generate-line-info");
  detail::copy_compiler_option_for_driver_ptxas(
      compiler_options, &linker_options, /*has_value = */ true, "-maxrregcount",
      "--maxrregcount");

  return CompiledProgram(std::move(ptx), std::move(cubin), std::move(nvvm),
                         std::move(lowered_name_map), std::move(linker_options),
                         std::move(log), std::move(compiler_options));
}

namespace detail {

// Returns the union of maps a and b; for keys present in both, b's value wins.
// When either map is empty the other one is returned by reference and nothing
// is copied (if both are empty, b is returned). Otherwise *tmp is cleared,
// filled with the union, and returned by reference.
template <typename Key, typename Value>
const std::unordered_map<Key, Value>& merge(
    const std::unordered_map<Key, Value>& a,
    const std::unordered_map<Key, Value>& b,
    std::unordered_map<Key, Value>* tmp) {
  if (a.empty() || b.empty()) {
    return a.empty() ? b : a;
  }
  tmp->clear();
  tmp->reserve(a.size() + b.size());
  // insert() never overwrites, so adding b's entries first gives them priority.
  for (const auto& entry : b) tmp->insert(entry);
  for (const auto& entry : a) tmp->insert(entry);
  return *tmp;
}

}  // namespace detail

/*! A preprocessed program: its name, CUDA source, header sources, the compiler
 *  and linker options still to be applied, and the logs produced while
 *  preprocessing.
 */
class PreprocessedProgramData
    : public serialization::Serializable<PreprocessedProgramData> {
  std::string name_;
  std::string source_;
  StringMap header_sources_;
  // Options handed on unchanged to the Compiled/Linked program stages.
  StringVec remaining_compiler_options_;
  StringVec remaining_linker_options_;
  // The logs are not written by serialize().
  std::string header_log_;
  std::string compile_log_;

  JITIFY_DEFINE_SERIALIZABLE_MEMBERS(PreprocessedProgramData, name_, source_,
                                     header_sources_,
                                     remaining_compiler_options_,
                                     remaining_linker_options_)

 public:
  PreprocessedProgramData() = default;
  PreprocessedProgramData(std::string name, std::string source,
                          StringMap header_sources = {},
                          StringVec remaining_compiler_options = {},
                          StringVec remaining_linker_options = {},
                          std::string header_log = {},
                          std::string compile_log = {})
      : name_(std::move(name)),
        source_(std::move(source)),
        header_sources_(std::move(header_sources)),
        remaining_compiler_options_(std::move(remaining_compiler_options)),
        remaining_linker_options_(std::move(remaining_linker_options)),
        header_log_(std::move(header_log)),
        compile_log_(std::move(compile_log)) {}

  // Hand-written serializers so that header sources can be left out.
  /*! Write the serialized program to a stream.
   *  \param stream Destination for the serialized bytes.
   *  \param include_headers If false, an empty header map is written in
   *    place of the header sources; the program source is always written.
   */
  void serialize(std::ostream& stream, bool include_headers = true) const {
    serialization::serialize(
        stream, name_, source_, include_headers ? header_sources_ : StringMap(),
        remaining_compiler_options_, remaining_linker_options_);
  }

  /*! Serialize the program into a string.
   *  \param include_headers If false, an empty header map is written in
   *    place of the header sources; the program source is always written.
   *  \return The serialized bytes.
   */
  std::string serialize(bool include_headers = true) const {
    std::ostringstream buffer(std::stringstream::binary);
    this->serialize(buffer, include_headers);
    return buffer.str();
  }

  /*! The program's name. */
  const std::string& name() const { return name_; }
  /*! The program's CUDA source code. */
  const std::string& source() const { return source_; }
  /*! Map of header names to header sources. */
  const StringMap& header_sources() const { return header_sources_; }
  /*! Compiler options still to be passed to the compiler. */
  const StringVec& remaining_compiler_options() const {
    return remaining_compiler_options_;
  }
  /*! Linker options still to be passed to the linker. */
  const StringVec& remaining_linker_options() const {
    return remaining_linker_options_;
  }
  /*! Log of the header lookups performed during preprocessing. */
  const std::string& header_log() const { return header_log_; }
  /*! Log of the compiler invocation performed during preprocessing. */
  const std::string& compile_log() const { return compile_log_; }

  /*! Compile the program to PTX (and maybe CUBIN).
   *  \param name_expressions Name expressions to include during compilation
   *    (e.g.,
   *    `{&quot;my_namespace::my_kernel<123, float>&quot;, &quot;v<7>&quot;}`).
   *  \param extra_header_sources Additional header names and sources. They are
   *    combined with this program's headers and take precedence when names
   *    match.
   *  \param extra_compiler_options Compiler options appended after the stored
   *    remaining compiler options.
   *  \param extra_linker_options Linker options appended after the stored
   *    remaining linker options.
   *  \return A CompiledProgram holding either a valid CompiledProgramData
   *    object or an error state.
   */
  CompiledProgram compile(const StringVec& name_expressions = {},
                          const StringMap& extra_header_sources = {},
                          StringVec extra_compiler_options = {},
                          StringVec extra_linker_options = {}) const {
    // Stored options go first, ahead of the caller-supplied extras.
    auto prepend = [](StringVec* dst, const StringVec& src) {
      dst->insert(dst->begin(), src.begin(), src.end());
    };
    prepend(&extra_compiler_options, remaining_compiler_options_);
    prepend(&extra_linker_options, remaining_linker_options_);
    StringMap merged_storage;
    const StringMap& all_headers =
        detail::merge(header_sources_, extra_header_sources, &merged_storage);
    return CompiledProgram::compile(name_, source_, all_headers,
                                    name_expressions,
                                    std::move(extra_compiler_options),
                                    std::move(extra_linker_options));
  }

  /*! Compile the program to PTX (and maybe CUBIN).
   *  \param name_expression A single name expression to include during
   *    compilation (e.g.,`&quot;my_namespace::my_kernel<123, float>&quot;`);
   *    an empty string means none.
   *  \param extra_header_sources Additional header names and sources. They are
   *    combined with this program's headers and take precedence when names
   *    match.
   *  \param extra_compiler_options Additional compiler options.
   *  \param extra_linker_options Additional linker options.
   *  \return A CompiledProgram holding either a valid CompiledProgramData
   *    object or an error state.
   */
  CompiledProgram compile(const std::string& name_expression,
                          const StringMap& extra_header_sources = {},
                          StringVec extra_compiler_options = {},
                          StringVec extra_linker_options = {}) const {
    // An empty string stands in for "no name expressions", since passing {}
    // would be ambiguous with the StringVec overload.
    StringVec name_expressions;
    if (!name_expression.empty()) name_expressions.push_back(name_expression);
    return compile(name_expressions, extra_header_sources,
                   std::move(extra_compiler_options),
                   std::move(extra_linker_options));
  }

  /*! Compile, link, and load the preprocessed program.
   *  \return A LoadedProgram holding either a valid LoadedProgramData object
   *    or the first error encountered.
   *  \see compile
   */
  LoadedProgram load(const StringVec& name_expressions = {},
                     const StringMap& extra_header_sources = {},
                     StringVec extra_compiler_options = {},
                     StringVec extra_linker_options = {}) const {
    CompiledProgram compiled = compile(name_expressions, extra_header_sources,
                                       std::move(extra_compiler_options),
                                       std::move(extra_linker_options));
    if (!compiled) {
      return LoadedProgram::Error(compiled.error());
    }
    LinkedProgram linked = compiled->link();
    if (!linked) {
      return LoadedProgram::Error(linked.error());
    }
    return linked->load();
  }

  /*! Compile, link, and load the program, then get the named kernel from it.
   *  \return A Kernel holding either a valid KernelData object or the first
   *    error encountered.
   *  \see compile
   */
  Kernel get_kernel(std::string name, StringVec other_name_expressions = {},
                    const StringMap& extra_header_sources = {},
                    StringVec extra_compiler_options = {},
                    StringVec extra_linker_options = {}) const {
    // The kernel's own name expression is compiled along with the others.
    StringVec& all_name_expressions = other_name_expressions;
    all_name_expressions.push_back(name);
    CompiledProgram compiled = compile(
        all_name_expressions, extra_header_sources,
        std::move(extra_compiler_options), std::move(extra_linker_options));
    if (!compiled) {
      return Kernel::Error(compiled.error());
    }
    LinkedProgram linked = compiled->link();
    if (!linked) {
      return Kernel::Error(linked.error());
    }
    LoadedProgram loaded = linked->load();
    if (!loaded) {
      return Kernel::Error(loaded.error());
    }
    return Kernel::get_kernel(std::move(*loaded), std::move(name));
  }
};

/*! Callback type taking a file name and a std::string output pointer and
 *  returning bool (presumably: true if the file was found and loaded -- TODO
 *  confirm against the preprocess implementation).
 */
typedef std::function<bool(const std::string&, std::string*)> FileCallback;

/*! Result of preprocessing: wraps a PreprocessedProgramData object or an error,
 *  via detail::FallibleObjectBase.
 */
class PreprocessedProgram
    : public detail::FallibleObjectBase<PreprocessedProgram,
                                        PreprocessedProgramData> {
  typedef detail::FallibleObjectBase<PreprocessedProgram,
                                     PreprocessedProgramData>
      super_type;
  friend class detail::FallibleObjectBase<PreprocessedProgram,
                                          PreprocessedProgramData>;
  // Reuse the base class constructors.
  using super_type::super_type;

 public:
  /*! \see ProgramData::preprocess */
  static PreprocessedProgram preprocess(std::string name, std::string source,
                                        StringMap header_sources = {},
                                        StringVec compiler_options = {},
                                        StringVec linker_options = {},
                                        FileCallback header_callback = nullptr);
};

namespace detail {

// TODO: Check all of these WARs.
// Source text for a "preinclude" header made of compatibility workarounds
// (NOTE(review): presumably injected ahead of user code -- confirm at the
// use site). It makes __host__ expand to nothing (WAR for Thrust and CUB) and
// makes `try` and `catch(...)` expand to nothing so that exception-handling
// syntax can be parsed. When this file is compiled on Windows, the string also
// defines _WIN64 if it is not already defined (WAR for NVRTC <= 11.0).
static const char* const jitsafe_header_preinclude_h =
    R"(
// WAR for Thrust and CUB.
#ifdef __host__
#undef __host__
#endif
#define __host__

// WAR to allow exceptions to be parsed.
#define try
#define catch(...)
)"
#if defined(_WIN32) || defined(_WIN64)
    // WAR for NVRTC <= 11.0 not defining _WIN64.
    R"(
#ifndef _WIN64
#define _WIN64 1
#endif
)"
#endif
    ;

// Defines two static header-source strings for a C standard header `name`
// (presumably served as "<name>.h" and "<cname>" -- confirm at the use site):
//   jitsafe_header_<name>_h: "#pragma once", `header`, then
//     std_and_global_impl at global scope.
//   jitsafe_header_c<name>: "#pragma once", `header`, then std_only_impl and
//     std_and_global_impl inside namespace std, then std_and_global_impl
//     again at global scope.
// All arguments except `name` must be string literals; they are concatenated.
#define JITIFY_DEFINE_C_AND_CXX_HEADERS_EX(name, header, std_and_global_impl, \
                                           std_only_impl)                     \
  static const char* const jitsafe_header_##name##_h =                        \
      "#pragma once\n" header "\n" std_and_global_impl;                       \
  static const char* const jitsafe_header_c##name =                           \
      "#pragma once\n" header                                                 \
      "\n"                                                                    \
      "namespace std {\n" std_only_impl std_and_global_impl                   \
      "}  // namespace std\n" std_and_global_impl

// As JITIFY_DEFINE_C_AND_CXX_HEADERS_EX, with no std-only declarations.
#define JITIFY_DEFINE_C_AND_CXX_HEADERS(name, header, std_and_global_impl) \
  JITIFY_DEFINE_C_AND_CXX_HEADERS_EX(name, header, std_and_global_impl, "")

// assert headers with no declarations of their own (the C++ variant only adds
// an empty namespace std). NOTE(review): the assert macro itself is not
// defined here -- presumably NVRTC supplies it; confirm.
JITIFY_DEFINE_C_AND_CXX_HEADERS(assert, "", "");

// <float.h>/<cfloat> limits for IEEE-754 binary32 (float) and binary64
// (double). DBL_EPSILON is 2^-52 = 2.2204460492503131e-16.
JITIFY_DEFINE_C_AND_CXX_HEADERS(float, R"(
#define FLT_RADIX       2
#define FLT_MANT_DIG    24
#define DBL_MANT_DIG    53
#define FLT_DIG         6
#define DBL_DIG         15
#define FLT_MIN_EXP     -125
#define DBL_MIN_EXP     -1021
#define FLT_MIN_10_EXP  -37
#define DBL_MIN_10_EXP  -307
#define FLT_MAX_EXP     128
#define DBL_MAX_EXP     1024
#define FLT_MAX_10_EXP  38
#define DBL_MAX_10_EXP  308
#define FLT_MAX         3.4028234e38f
#define DBL_MAX         1.7976931348623157e308
#define FLT_EPSILON     1.19209289e-7f
#define DBL_EPSILON     2.2204460492503131e-16
#define FLT_MIN         1.1754943e-38f
#define DBL_MIN         2.2250738585072013e-308
#define FLT_ROUNDS      1
#if defined __cplusplus && __cplusplus >= 201103L
#define FLT_EVAL_METHOD 0
#define DECIMAL_DIG     21
#endif
)",
                                "");

// <limits.h>/<climits>. __WORDSIZE (and therefore LONG_MAX/ULONG_MAX) is 64
// on LP64 targets, detected via __LP64__ (e.g., aarch64, ppc64le) or x86-64
// without the ILP32 ABI, and 32 otherwise (including Windows, where long is
// 32 bits).
JITIFY_DEFINE_C_AND_CXX_HEADERS(limits, R"(
#if defined _WIN32 || defined _WIN64
 #define __WORDSIZE 32
#else
 #if defined __LP64__ || (defined __x86_64__ && !defined __ILP32__)
  #define __WORDSIZE 64
 #else
  #define __WORDSIZE 32
 #endif
#endif
#define MB_LEN_MAX  16
#define CHAR_BIT    8
#define SCHAR_MIN   (-128)
#define SCHAR_MAX   127
#define UCHAR_MAX   255
#define _JITIFY_CHAR_IS_UNSIGNED ((char)-1 >= 0)
#define CHAR_MIN    (_JITIFY_CHAR_IS_UNSIGNED ? 0 : SCHAR_MIN)
#define CHAR_MAX    (_JITIFY_CHAR_IS_UNSIGNED ? UCHAR_MAX : SCHAR_MAX)
#define SHRT_MIN    (-32768)
#define SHRT_MAX    32767
#define USHRT_MAX   65535
#define INT_MIN     (-INT_MAX - 1)
#define INT_MAX     2147483647
#define UINT_MAX    4294967295U
#if __WORDSIZE == 64
 # define LONG_MAX  9223372036854775807L
#else
 # define LONG_MAX  2147483647L
#endif
#define LONG_MIN    (-LONG_MAX - 1L)
#if __WORDSIZE == 64
 #define ULONG_MAX  18446744073709551615UL
#else
 #define ULONG_MAX  4294967295UL
#endif
#define LLONG_MAX  9223372036854775807LL
#define LLONG_MIN  (-LLONG_MAX - 1LL)
#define ULLONG_MAX 18446744073709551615ULL
)",
                                "");

// Note: Global namespace already includes CUDA math funcs
// <math.h>/<cmath>: defines M_PI, and (C++ header only) std:: wrappers that
// forward to the global CUDA math functions. std::scalbln forwards to
// ::scalbln so that its long exponent is not narrowed to int.
JITIFY_DEFINE_C_AND_CXX_HEADERS_EX(math, "#define M_PI 3.14159265358979323846",
                                   "", R"(
#if __cplusplus >= 201103L
#define DEFINE_MATH_UNARY_FUNC_WRAPPER(f)                       \
  inline double f(double x) { return ::f(x); }                  \
  inline float f##f(float x) { return ::f(x); }                 \
  /*inline long double f##l(long double x) { return ::f(x); }*/ \
  inline float f(float x) { return ::f(x); }                    \
  /*inline long double f(long double x)    { return ::f(x); }*/
#else
#define DEFINE_MATH_UNARY_FUNC_WRAPPER(f)       \
  inline double f(double x) { return ::f(x); }  \
  inline float f##f(float x) { return ::f(x); } \
  /*inline long double f##l(long double x) { return ::f(x); }*/
#endif
DEFINE_MATH_UNARY_FUNC_WRAPPER(cos)
DEFINE_MATH_UNARY_FUNC_WRAPPER(sin)
DEFINE_MATH_UNARY_FUNC_WRAPPER(tan)
DEFINE_MATH_UNARY_FUNC_WRAPPER(acos)
DEFINE_MATH_UNARY_FUNC_WRAPPER(asin)
DEFINE_MATH_UNARY_FUNC_WRAPPER(atan)
template <typename T>
inline T atan2(T y, T x) {
  return ::atan2(y, x);
}
DEFINE_MATH_UNARY_FUNC_WRAPPER(cosh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(sinh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(tanh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(exp)
template <typename T>
inline T frexp(T x, int* exp) {
  return ::frexp(x, exp);
}
template <typename T>
inline T ldexp(T x, int exp) {
  return ::ldexp(x, exp);
}
DEFINE_MATH_UNARY_FUNC_WRAPPER(log)
DEFINE_MATH_UNARY_FUNC_WRAPPER(log10)
template <typename T>
inline T modf(T x, T* intpart) {
  return ::modf(x, intpart);
}
template <typename T>
inline T pow(T x, T y) {
  return ::pow(x, y);
}
DEFINE_MATH_UNARY_FUNC_WRAPPER(sqrt)
DEFINE_MATH_UNARY_FUNC_WRAPPER(ceil)
DEFINE_MATH_UNARY_FUNC_WRAPPER(floor)
template <typename T>
inline T fmod(T n, T d) {
  return ::fmod(n, d);
}
DEFINE_MATH_UNARY_FUNC_WRAPPER(fabs)
template <typename T>
inline T abs(T x) {
  return ::abs(x);
}
#if __cplusplus >= 201103L
DEFINE_MATH_UNARY_FUNC_WRAPPER(acosh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(asinh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(atanh)
DEFINE_MATH_UNARY_FUNC_WRAPPER(exp2)
DEFINE_MATH_UNARY_FUNC_WRAPPER(expm1)
template <typename T>
inline int ilogb(T x) {
  return ::ilogb(x);
}
DEFINE_MATH_UNARY_FUNC_WRAPPER(log1p)
DEFINE_MATH_UNARY_FUNC_WRAPPER(log2)
DEFINE_MATH_UNARY_FUNC_WRAPPER(logb)
template <typename T>
inline T scalbn(T x, int n) {
  return ::scalbn(x, n);
}
template <typename T>
inline T scalbln(T x, long n) {
  return ::scalbln(x, n);
}
DEFINE_MATH_UNARY_FUNC_WRAPPER(cbrt)
template <typename T>
inline T hypot(T x, T y) {
  return ::hypot(x, y);
}
DEFINE_MATH_UNARY_FUNC_WRAPPER(erf)
DEFINE_MATH_UNARY_FUNC_WRAPPER(erfc)
DEFINE_MATH_UNARY_FUNC_WRAPPER(tgamma)
DEFINE_MATH_UNARY_FUNC_WRAPPER(lgamma)
DEFINE_MATH_UNARY_FUNC_WRAPPER(trunc)
DEFINE_MATH_UNARY_FUNC_WRAPPER(round)
template <typename T>
inline long lround(T x) {
  return ::lround(x);
}
template <typename T>
inline long long llround(T x) {
  return ::llround(x);
}
DEFINE_MATH_UNARY_FUNC_WRAPPER(rint)
template <typename T>
inline long lrint(T x) {
  return ::lrint(x);
}
template <typename T>
inline long long llrint(T x) {
  return ::llrint(x);
}
DEFINE_MATH_UNARY_FUNC_WRAPPER(nearbyint)
// TODO: remainder, remquo, copysign, nan, nextafter, nexttoward, fdim,
// fmax, fmin, fma
#endif  // __cplusplus >= 201103L
#undef DEFINE_MATH_UNARY_FUNC_WRAPPER
)");

// TODO: offsetof
// Jitify-safe <stddef.h>/<cstddef>, generated by
// JITIFY_DEFINE_C_AND_CXX_HEADERS_EX (defined elsewhere in this file).
// The string arguments are:
//   1. "#include <climits>".
//   2. Declarations: nullptr_t and max_align_t (C++11 and later; max_align_t
//      is double under MSVC, long double under Apple, and otherwise a struct
//      mirroring GCC's definition), plus `enum class byte` (C++17 and later).
//   3. using-declarations re-exporting NVRTC's built-in ::size_t and
//      ::ptrdiff_t.
// NOTE(review): which parts the macro places at global scope vs. inside
// namespace std is not visible here -- confirm against the macro definition.
JITIFY_DEFINE_C_AND_CXX_HEADERS_EX(stddef, "#include <climits>", R"(
#if __cplusplus >= 201103L
typedef decltype(nullptr) nullptr_t;
#if defined(_MSC_VER)
  typedef double max_align_t;
#elif defined(__APPLE__)
  typedef long double max_align_t;
#else
  // Define max_align_t to match the GCC definition.
  typedef struct {
    long long __jitify_max_align_nonce1
        __attribute__((__aligned__(__alignof__(long long))));
    long double __jitify_max_align_nonce2
        __attribute__((__aligned__(__alignof__(long double))));
  } max_align_t;
#endif
#endif  // __cplusplus >= 201103L
#if __cplusplus >= 201703L
enum class byte : unsigned char {};
#endif  // __cplusplus >= 201703L
)",
                                   R"(
// NVRTC provides built-in definitions of ::size_t and ::ptrdiff_t.
using ::size_t;
using ::ptrdiff_t;
)");

// Jitify-safe <stdint.h>/<cstdint>, generated by
// JITIFY_DEFINE_C_AND_CXX_HEADERS (defined elsewhere in this file).
// The first string holds the limit macros; the second holds the fixed-width
// integer typedefs.
// The pointer-sized types (intptr_t/uintptr_t) and their limits follow the
// host data model. Windows is LLP64 (long is 32 bits), so they must be
// long long there. The signed type, the unsigned type and all of their
// limit macros are selected by the same platform test, so they stay
// consistent with each other.
JITIFY_DEFINE_C_AND_CXX_HEADERS(stdint, R"(
#include <climits>
#define INT8_MIN SCHAR_MIN
#define INT16_MIN SHRT_MIN
#define INT32_MIN INT_MIN
#define INT64_MIN LLONG_MIN
#define INT8_MAX SCHAR_MAX
#define INT16_MAX SHRT_MAX
#define INT32_MAX INT_MAX
#define INT64_MAX LLONG_MAX
#define UINT8_MAX UCHAR_MAX
#define UINT16_MAX USHRT_MAX
#define UINT32_MAX UINT_MAX
#define UINT64_MAX ULLONG_MAX
#if defined _WIN32 || defined _WIN64
#define INTPTR_MIN LLONG_MIN
#define INTPTR_MAX LLONG_MAX
#define UINTPTR_MAX ULLONG_MAX
#else  // not Windows
#define INTPTR_MIN LONG_MIN
#define INTPTR_MAX LONG_MAX
#define UINTPTR_MAX ULONG_MAX
#endif
#define INTMAX_MIN LLONG_MIN
#define INTMAX_MAX LLONG_MAX
#define UINTMAX_MAX ULLONG_MAX
#define PTRDIFF_MIN INTPTR_MIN
#define PTRDIFF_MAX INTPTR_MAX
#define SIZE_MAX UINT64_MAX
#define _JITIFY_WCHAR_T_IS_UNSIGNED ((wchar_t)-1 >= 0)
#define WCHAR_MIN                                                      \
    (sizeof(wchar_t) == 2 ? _JITIFY_WCHAR_T_IS_UNSIGNED ? 0 : SHRT_MIN \
                          : _JITIFY_WCHAR_T_IS_UNSIGNED ? 0 : INT_MIN)
#define WCHAR_MAX                                                              \
    (sizeof(wchar_t) == 2 ? _JITIFY_WCHAR_T_IS_UNSIGNED ? USHRT_MAX : SHRT_MAX \
                          : _JITIFY_WCHAR_T_IS_UNSIGNED ? UINT_MAX : INT_MAX)
)",
                                R"(
typedef signed char int8_t;
typedef signed short int16_t;
typedef signed int int32_t;
typedef signed long long int64_t;
typedef signed char int_fast8_t;
typedef signed short int_fast16_t;
typedef signed int int_fast32_t;
typedef signed long long int_fast64_t;
typedef signed char int_least8_t;
typedef signed short int_least16_t;
typedef signed int int_least32_t;
typedef signed long long int_least64_t;
typedef signed long long intmax_t;
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long long uint64_t;
typedef unsigned char uint_fast8_t;
typedef unsigned short uint_fast16_t;
typedef unsigned int uint_fast32_t;
typedef unsigned long long uint_fast64_t;
typedef unsigned char uint_least8_t;
typedef unsigned short uint_least16_t;
typedef unsigned int uint_least32_t;
typedef unsigned long long uint_least64_t;
typedef unsigned long long uintmax_t;
#if defined _WIN32 || defined _WIN64
typedef signed long long intptr_t;  // optional
typedef unsigned long long uintptr_t;  // optional
#else  // not Windows
typedef signed long intptr_t;  // optional
typedef unsigned long uintptr_t;  // optional
#endif
)");

// Minimal jitify-safe <stdio.h>/<cstdio>: includes <cstddef>, aliases FILE to
// int, and declares (but does not define) fflush and fprintf.
JITIFY_DEFINE_C_AND_CXX_HEADERS(stdio, "#include <cstddef>", R"(
using FILE = int;
int fflush(FILE* stream);
int fprintf(FILE* stream, const char* format, ...);
)");

// Jitify-safe <stdlib.h>/<cstdlib>: only includes <cstddef>; no stdlib
// functions are declared.
JITIFY_DEFINE_C_AND_CXX_HEADERS(stdlib, "#include <cstddef>", "");

// Jitify-safe <string.h>/<cstring>: declares (but does not define) strcpy,
// strcmp and strerror. The #include <cstddef> inside the header text is
// commented out, and none of these declarations need size_t.
JITIFY_DEFINE_C_AND_CXX_HEADERS(string, "", R"(
//#include <cstddef>
char* strcpy(char* destination, const char* source);
int strcmp(const char* str1, const char* str2);
char* strerror(int errnum);
)");

// Jitify-safe <time.h>/<ctime>, generated by
// JITIFY_DEFINE_C_AND_CXX_HEADERS_EX (defined elsewhere in this file).
// The string arguments are:
//   1. Macros: NULL (as 0) and CLOCKS_PER_SEC (as 1000000).
//   2. Declarations: time_t (long), struct tm and, from C++17, struct timespec.
//   3. using-declarations re-exporting NVRTC's built-in ::size_t and ::clock_t.
// NOTE(review): the scope each part lands in is decided by the macro -- confirm
// against its definition.
JITIFY_DEFINE_C_AND_CXX_HEADERS_EX(time, R"(
#define NULL 0
#define CLOCKS_PER_SEC 1000000
)",
                                   R"(
typedef long time_t;
struct tm {
  int tm_sec;
  int tm_min;
  int tm_hour;
  int tm_mday;
  int tm_mon;
  int tm_year;
  int tm_wday;
  int tm_yday;
  int tm_isdst;
};
#if __cplusplus >= 201703L
struct timespec {
  time_t tv_sec;
  long tv_nsec;
};
#endif
)",
                                   R"(
// NVRTC provides built-in definitions of ::size_t and ::clock_t.
using ::size_t;
using ::clock_t;
)");

// The C/C++ header-pair generator macros are only needed for the definitions
// above; undefine them so they do not leak into the rest of the file.
#undef JITIFY_DEFINE_C_AND_CXX_HEADERS
#undef JITIFY_DEFINE_C_AND_CXX_HEADERS_EX

// Jitify-safe <algorithm>: provides only std::max and std::min, and only in
// C++11 and later. They are constexpr from C++14 on (matching the standard);
// under exactly C++11, JITIFY_CXX14_CONSTEXPR expands to nothing. The helper
// macro is #undef'd inside the header so it does not leak into user code.
static const char* const jitsafe_header_algorithm = R"(
#pragma once
#if __cplusplus >= 201103L
namespace std {
#if __cplusplus == 201103L
#define JITIFY_CXX14_CONSTEXPR
#else
#define JITIFY_CXX14_CONSTEXPR constexpr
#endif
template <class T>
JITIFY_CXX14_CONSTEXPR const T& max(const T& a, const T& b) {
  return (b > a) ? b : a;
}
template <class T>
JITIFY_CXX14_CONSTEXPR const T& min(const T& a, const T& b) {
  return (b < a) ? b : a;
}
#undef JITIFY_CXX14_CONSTEXPR
}  // namespace std
#endif  // __cplusplus >= 201103L
)";

// TODO: This is very incomplete.
// Jitify-safe <array>: a minimal std::array.
// The element storage is a public data member, as in real standard-library
// implementations, so std::array stays an aggregate and supports brace
// initialization (e.g. `std::array<int, 3> a = {1, 2, 3};`).
// <cstddef> is included so that std::size_t, used in the template
// parameter list, is declared.
// Only const member functions are constexpr: in C++11 a constexpr member
// function is implicitly const, which would clash with the non-const
// overloads.
static const char* const jitsafe_header_array = R"(
#pragma once
#include <cstddef>
namespace std {
template <class T, std::size_t N>
class array {
 public:
  // Public so that array remains an aggregate (brace-initializable).
  T data_[N];

  using value_type = T;
  using size_type = size_t;
  using difference_type = ptrdiff_t;
  using reference = T&;
  using const_reference = const T&;
  using pointer = T*;
  using const_pointer = const T*;
  using iterator = T*;
  using const_iterator = const T*;

  reference operator[](size_type pos) { return data_[pos]; }
  constexpr const_reference operator[](size_type pos) const {
    return data_[pos];
  }
  constexpr size_type size() const { return N; }
  constexpr size_type max_size() const { return N; }
  constexpr bool empty() const { return N == 0; }
  pointer data() { return data_; }
  constexpr const_pointer data() const { return data_; }
  iterator begin() { return data_; }
  constexpr const_iterator begin() const { return data_; }
  iterator end() { return data_ + N; }
  constexpr const_iterator end() const { return data_ + N; }
};
}  // namespace std
)";

// TODO: This is incomplete.
// Jitify-safe <complex>: a minimal std::complex with accessors, compound
// addition, binary addition and multiplication (complex*complex,
// complex*scalar, scalar*complex).
// The scalar multiplication overloads construct complex<T>; they previously
// named a non-existent `complexs<T>` and failed to compile when instantiated.
static const char* const jitsafe_header_complex = R"(
#pragma once
namespace std {
template <typename T>
class complex {
  T real_;
  T imag_;

 public:
  complex() : real_(0), imag_(0) {}
  complex(const T& real, const T& imag) : real_(real), imag_(imag) {}
  complex(const T& real) : real_(real), imag_(static_cast<T>(0)) {}
  const T& real() const { return real_; }
  T& real() { return real_; }
  void real(const T& r) { real_ = r; }
  const T& imag() const { return imag_; }
  T& imag() { return imag_; }
  void imag(const T& i) { imag_ = i; }
  complex<T>& operator+=(const complex<T>& z) {
    real_ += z.real();
    imag_ += z.imag();
    return *this;
  }
};
template <typename T>
complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
  return complex<T>(lhs.real() + rhs.real(), lhs.imag() + rhs.imag());
}
template <typename T>
complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
  return complex<T>(lhs.real() * rhs.real() - lhs.imag() * rhs.imag(),
                    lhs.real() * rhs.imag() + lhs.imag() * rhs.real());
}
template <typename T>
complex<T> operator*(const complex<T>& lhs, const T& rhs) {
  return complex<T>(lhs.real() * rhs, lhs.imag() * rhs);
}
template <typename T>
complex<T> operator*(const T& lhs, const complex<T>& rhs) {
  return complex<T>(rhs.real() * lhs, rhs.imag() * lhs);
}
}  // namespace std
)";

// Jitify-safe <initializer_list>: intentionally empty, because NVRTC already
// provides std::initializer_list.
static const char* const jitsafe_header_initializer_list = R"(
#pragma once
namespace std {
// NVRTC provides std::initializer_list by default.
}  // namespace std
)";

// Jitify-safe <iostream>: just includes the <istream> and <ostream> stand-ins;
// no stream objects (such as std::cout) are declared.
static const char* const jitsafe_header_iostream = R"(
#pragma once
#include <istream>
#include <ostream>
)";

// Jitify-safe <istream>: declares an empty std::basic_istream template (its
// Traits parameter defaults to void instead of std::char_traits) and the
// std::istream typedef, so code that merely names these types can compile.
static const char* const jitsafe_header_istream = R"(
#pragma once
namespace std {
template <class CharT, class Traits = void>  // = std::char_traits<CharT>>
struct basic_istream {};
typedef basic_istream<char> istream;
}  // namespace std
)";

// Jitify-safe <iterator>: the five iterator category tags and
// std::iterator_traits. The primary template forwards the member typedefs of
// the iterator type; the T* and T const* specializations report
// random_access_iterator_tag.
// NOTE(review): no header is included, so `ptrdiff_t` presumably resolves to
// NVRTC's built-in ::ptrdiff_t -- confirm.
static const char* const jitsafe_header_iterator = R"(
#pragma once
namespace std {
struct output_iterator_tag {};
struct input_iterator_tag {};
struct forward_iterator_tag {};
struct bidirectional_iterator_tag {};
struct random_access_iterator_tag {};
template<class Iterator>
struct iterator_traits {
  typedef typename Iterator::iterator_category iterator_category;
  typedef typename Iterator::value_type        value_type;
  typedef typename Iterator::difference_type   difference_type;
  typedef typename Iterator::pointer           pointer;
  typedef typename Iterator::reference         reference;
};
template<class T>
struct iterator_traits<T*> {
  typedef random_access_iterator_tag iterator_category;
  typedef T                          value_type;
  typedef ptrdiff_t                  difference_type;
  typedef T*                         pointer;
  typedef T&                         reference;
};
template<class T>
struct iterator_traits<T const*> {
  typedef random_access_iterator_tag iterator_category;
  typedef T                          value_type;
  typedef ptrdiff_t                  difference_type;
  typedef T const*                   pointer;
  typedef T const&                   reference;
};
}  // namespace std
)";

// Jitify-safe <limits>: std::numeric_limits for bool, the character and
// integer types, float and double.
// - min()/max() are available in every language mode and, from C++11 on, are
//   constexpr and noexcept like the standard versions. This allows uses such
//   as std::numeric_limits<int>::max() in constant expressions.
// - lowest() exists only from C++11 on, as in the standard.
// - The JITIFY_CXX11_* helper macros are #undef'd once the helper structs are
//   defined, so they do not leak into user code.
static const char* const jitsafe_header_limits = R"(
#pragma once
#include <cfloat>
#include <climits>
#include <cstdint>
// TODO: epsilon(), infinity(), etc
namespace std {
namespace __jitify_detail {
#if __cplusplus >= 201103L
#define JITIFY_CXX11_CONSTEXPR constexpr
#define JITIFY_CXX11_NOEXCEPT noexcept
#else
#define JITIFY_CXX11_CONSTEXPR
#define JITIFY_CXX11_NOEXCEPT
#endif

struct FloatLimits {
  static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ float min()
      JITIFY_CXX11_NOEXCEPT {
    return FLT_MIN;
  }
  static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ float max()
      JITIFY_CXX11_NOEXCEPT {
    return FLT_MAX;
  }
#if __cplusplus >= 201103L
  static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ float lowest()
      JITIFY_CXX11_NOEXCEPT {
    return -FLT_MAX;
  }
#endif  // __cplusplus >= 201103L
  enum {
    is_specialized = true,
    is_signed = true,
    is_integer = false,
    is_exact = false,
    has_infinity = true,
    has_quiet_NaN = true,
    has_signaling_NaN = true,
    has_denorm = 1,
    has_denorm_loss = true,
    round_style = 1,
    is_iec559 = true,
    is_bounded = true,
    is_modulo = false,
    digits = 24,
    digits10 = 6,
    max_digits10 = 9,
    radix = 2,
    min_exponent = -125,
    min_exponent10 = -37,
    max_exponent = 128,
    max_exponent10 = 38,
    tinyness_before = false,
    traps = false
  };
};
struct DoubleLimits {
  static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ double min()
      JITIFY_CXX11_NOEXCEPT {
    return DBL_MIN;
  }
  static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ double max()
      JITIFY_CXX11_NOEXCEPT {
    return DBL_MAX;
  }
#if __cplusplus >= 201103L
  static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ double lowest()
      JITIFY_CXX11_NOEXCEPT {
    return -DBL_MAX;
  }
#endif  // __cplusplus >= 201103L
  enum {
    is_specialized = true,
    is_signed = true,
    is_integer = false,
    is_exact = false,
    has_infinity = true,
    has_quiet_NaN = true,
    has_signaling_NaN = true,
    has_denorm = 1,
    has_denorm_loss = true,
    round_style = 1,
    is_iec559 = true,
    is_bounded = true,
    is_modulo = false,
    digits = 53,
    digits10 = 15,
    max_digits10 = 17,
    radix = 2,
    min_exponent = -1021,
    min_exponent10 = -307,
    max_exponent = 1024,
    max_exponent10 = 308,
    tinyness_before = false,
    traps = false
  };
};
template <class T, T Min, T Max, int Digits = -1>
struct IntegerLimits {
  static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ T min()
      JITIFY_CXX11_NOEXCEPT {
    return Min;
  }
  static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ T max()
      JITIFY_CXX11_NOEXCEPT {
    return Max;
  }
#if __cplusplus >= 201103L
  static JITIFY_CXX11_CONSTEXPR inline __host__ __device__ T lowest()
      JITIFY_CXX11_NOEXCEPT {
    return Min;
  }
#endif  // __cplusplus >= 201103L
  enum {
    is_specialized = true,
    digits = (Digits == -1) ? (int)(sizeof(T) * 8 - (Min != 0)) : Digits,
    digits10 = (digits * 30103) / 100000,
    is_signed = ((T)(-1) < 0),
    is_integer = true,
    is_exact = true,
    radix = 2,
    is_bounded = true,
    is_modulo = false
  };
};
#undef JITIFY_CXX11_NOEXCEPT
#undef JITIFY_CXX11_CONSTEXPR
}  // namespace __jitify_detail
template <typename T>
struct numeric_limits {
  enum { is_specialized = false };
};
template <>
struct numeric_limits<bool>
    : public __jitify_detail::IntegerLimits<bool, false, true, 1> {};
template <>
struct numeric_limits<char>
    : public __jitify_detail::IntegerLimits<char, CHAR_MIN, CHAR_MAX> {};
template <>
struct numeric_limits<signed char>
    : public __jitify_detail::IntegerLimits<signed char, SCHAR_MIN, SCHAR_MAX> {
};
template <>
struct numeric_limits<unsigned char>
    : public __jitify_detail::IntegerLimits<unsigned char, 0, UCHAR_MAX> {};
template <>
struct numeric_limits<wchar_t>
    : public __jitify_detail::IntegerLimits<wchar_t, WCHAR_MIN, WCHAR_MAX> {};
template <>
struct numeric_limits<short>
    : public __jitify_detail::IntegerLimits<short, SHRT_MIN, SHRT_MAX> {};
template <>
struct numeric_limits<unsigned short>
    : public __jitify_detail::IntegerLimits<unsigned short, 0, USHRT_MAX> {};
template <>
struct numeric_limits<int>
    : public __jitify_detail::IntegerLimits<int, INT_MIN, INT_MAX> {};
template <>
struct numeric_limits<unsigned int>
    : public __jitify_detail::IntegerLimits<unsigned int, 0, UINT_MAX> {};
template <>
struct numeric_limits<long>
    : public __jitify_detail::IntegerLimits<long, LONG_MIN, LONG_MAX> {};
template <>
struct numeric_limits<unsigned long>
    : public __jitify_detail::IntegerLimits<unsigned long, 0, ULONG_MAX> {};
template <>
struct numeric_limits<long long>
    : public __jitify_detail::IntegerLimits<long long, LLONG_MIN, LLONG_MAX> {};
template <>
struct numeric_limits<unsigned long long>
    : public __jitify_detail::IntegerLimits<unsigned long long, 0, ULLONG_MAX> {
};
template <>
struct numeric_limits<float> : public __jitify_detail::FloatLimits {};
template <>
struct numeric_limits<double> : public __jitify_detail::DoubleLimits {};
}  // namespace std
)";

// TODO: This is incomplete.
// Jitify-safe <mutex> (C++11 and later): a declaration-only std::mutex.
// lock/try_lock/unlock are declared but not defined here.
// NOTE(review): presumably this only lets code that names std::mutex compile;
// actually calling these members would need definitions -- confirm intended
// usage.
static const char* const jitsafe_header_mutex = R"(
#pragma once
#if __cplusplus >= 201103L
namespace std {
class mutex {
 public:
  void lock();
  bool try_lock();
  void unlock();
};
}  // namespace std
#endif  // __cplusplus >= 201103L
)";

// Jitify-safe <ostream>: an empty std::basic_ostream template (Traits defaults
// to void), the std::ostream typedef, and declaration-only std::endl and
// operator<< overloads. None of these functions are defined here.
// The generic `operator<<(basic_ostream&&, const T&)` (C++11 and later) binds
// only rvalue streams.
static const char* const jitsafe_header_ostream = R"(
#pragma once
namespace std {
template <class CharT, class Traits = void>  // = std::char_traits<CharT>>
struct basic_ostream {};
typedef basic_ostream<char> ostream;
ostream& endl(ostream& os);
ostream& operator<<(ostream&, ostream& (*f)(ostream&));
template <class CharT, class Traits>
basic_ostream<CharT, Traits>& endl(basic_ostream<CharT, Traits>& os);
template <class CharT, class Traits>
basic_ostream<CharT, Traits>& operator<<(basic_ostream<CharT, Traits>& os,
                                         const char* c);
#if __cplusplus >= 201103L
template <class CharT, class Traits, class T>
basic_ostream<CharT, Traits>& operator<<(basic_ostream<CharT, Traits>&& os,
                                         const T& value);
#endif  // __cplusplus >= 201103L
}  // namespace std
)";

// Jitify-safe <sstream>: only includes the <ostream> and <istream> stand-ins;
// no string-stream types are declared.
static const char* const jitsafe_header_sstream = R"(
#pragma once
#include <ostream>
#include <istream>
)";

// Jitify-safe <stdexcept>: a declaration-only std::runtime_error with
// constructors from std::string and const char* and a virtual what(). It has
// no std::exception base, and its members are not defined here.
static const char* const jitsafe_header_stdexcept = R"(
#pragma once
#include <string>
namespace std {
struct runtime_error {
  explicit runtime_error( const string& what_arg );
  explicit runtime_error( const char* what_arg );
  virtual const char* what() const;
};
}  // namespace std
)";

// Jitify-safe <string>: a declaration-only std::basic_string exposing a tiny
// subset of members (construction, c_str, empty, operator+=). Traits and
// Allocator default to void rather than the standard types. Also provides the
// std::string typedef.
static const char* const jitsafe_header_string = R"(
#pragma once
namespace std {
template <class CharT, class Traits = void, class Allocator = void>
struct basic_string {
  basic_string();
  basic_string(const CharT* s);  //, const Allocator& alloc = Allocator());
  const CharT* c_str() const;
  bool empty() const;
  void operator+=(const char*);
  void operator+=(const basic_string&);
};
typedef basic_string<char> string;
}  // namespace std
)";

// Jitify-safe <tuple> (C++11 and later): only a forward declaration of
// std::tuple, enough for code that names the type without using it.
static const char* const jitsafe_header_tuple = R"(
#pragma once
#if __cplusplus >= 201103L
namespace std {
template <class... Types> class tuple;
}  // namespace std
#endif  // c++11
)";

// TODO: This is incomplete.
// Jitify-safe <type_traits> (C++11 and later): a partial implementation of
// the standard type traits.
// Notes on conformance:
// - true_type/false_type expose `value` as an enumerator (no odr-use
//   definition needed) and a constexpr conversion to bool.
// - is_integral/is_floating_point/is_signed/is_unsigned see through const
//   and volatile qualifiers, as the standard requires.
// - is_reference covers both lvalue and rvalue references.
// - is_void covers all cv-qualified forms of void.
// - is_base_of is only a placeholder and has no `value`.
// - result_of relies on F::result_type.
static const char* const jitsafe_header_type_traits = R"(
#pragma once
#if __cplusplus >= 201103L
namespace std {

template <bool B, class T = void>
struct enable_if {};
template <class T>
struct enable_if<true, T> {
  typedef T type;
};
#if __cplusplus >= 201402L
template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
#endif

struct true_type {
  enum { value = true };
  typedef bool value_type;
  typedef true_type type;
  constexpr operator bool() const { return true; }
};
struct false_type {
  enum { value = false };
  typedef bool value_type;
  typedef false_type type;
  constexpr operator bool() const { return false; }
};

template <typename T>
struct is_floating_point : false_type {};
template <>
struct is_floating_point<float> : true_type {};
template <>
struct is_floating_point<double> : true_type {};
template <>
struct is_floating_point<long double> : true_type {};
// cv-qualified types have the same classification as the unqualified type.
template <typename T>
struct is_floating_point<const T> : is_floating_point<T> {};
template <typename T>
struct is_floating_point<volatile T> : is_floating_point<T> {};
template <typename T>
struct is_floating_point<const volatile T> : is_floating_point<T> {};

template <class T>
struct is_integral : false_type {};
template <>
struct is_integral<bool> : true_type {};
template <>
struct is_integral<char> : true_type {};
template <>
struct is_integral<signed char> : true_type {};
template <>
struct is_integral<unsigned char> : true_type {};
template <>
struct is_integral<wchar_t> : true_type {};
template <>
struct is_integral<char16_t> : true_type {};
template <>
struct is_integral<char32_t> : true_type {};
template <>
struct is_integral<short> : true_type {};
template <>
struct is_integral<unsigned short> : true_type {};
template <>
struct is_integral<int> : true_type {};
template <>
struct is_integral<unsigned int> : true_type {};
template <>
struct is_integral<long> : true_type {};
template <>
struct is_integral<unsigned long> : true_type {};
template <>
struct is_integral<long long> : true_type {};
template <>
struct is_integral<unsigned long long> : true_type {};
template <class T>
struct is_integral<const T> : is_integral<T> {};
template <class T>
struct is_integral<volatile T> : is_integral<T> {};
template <class T>
struct is_integral<const volatile T> : is_integral<T> {};

template <typename T>
struct is_signed : false_type {};
template <>
struct is_signed<float> : true_type {};
template <>
struct is_signed<double> : true_type {};
template <>
struct is_signed<long double> : true_type {};
template <>
struct is_signed<signed char> : true_type {};
template <>
struct is_signed<short> : true_type {};
template <>
struct is_signed<int> : true_type {};
template <>
struct is_signed<long> : true_type {};
template <>
struct is_signed<long long> : true_type {};
template <typename T>
struct is_signed<const T> : is_signed<T> {};
template <typename T>
struct is_signed<volatile T> : is_signed<T> {};
template <typename T>
struct is_signed<const volatile T> : is_signed<T> {};

template <typename T>
struct is_unsigned : false_type {};
template <>
struct is_unsigned<unsigned char> : true_type {};
template <>
struct is_unsigned<unsigned short> : true_type {};
template <>
struct is_unsigned<unsigned int> : true_type {};
template <>
struct is_unsigned<unsigned long> : true_type {};
template <>
struct is_unsigned<unsigned long long> : true_type {};
template <typename T>
struct is_unsigned<const T> : is_unsigned<T> {};
template <typename T>
struct is_unsigned<volatile T> : is_unsigned<T> {};
template <typename T>
struct is_unsigned<const volatile T> : is_unsigned<T> {};

template <typename T, typename U>
struct is_same : false_type {};
template <typename T>
struct is_same<T, T> : true_type {};

template <class T>
struct is_array : false_type {};
template <class T>
struct is_array<T[]> : true_type {};
template <class T, size_t N>
struct is_array<T[N]> : true_type {};

// This is a partial implementation only of is_function.
template <class>
struct is_function : false_type {};
template <class Ret, class... Args>
struct is_function<Ret(Args...)> : true_type {};  // regular
template <class Ret, class... Args>
struct is_function<Ret(Args......)> : true_type {};  // variadic

template <class>
struct result_of;
template <class F, typename... Args>
struct result_of<F(Args...)> {
  // TODO: This is a hack; a proper implem is quite complicated.
  typedef typename F::result_type type;
};

template <class T>
struct remove_reference {
  typedef T type;
};
template <class T>
struct remove_reference<T&> {
  typedef T type;
};
template <class T>
struct remove_reference<T&&> {
  typedef T type;
};
#if __cplusplus >= 201402L
template <class T>
using remove_reference_t = typename remove_reference<T>::type;
#endif

template <class T>
struct remove_extent {
  typedef T type;
};
template <class T>
struct remove_extent<T[]> {
  typedef T type;
};
template <class T, size_t N>
struct remove_extent<T[N]> {
  typedef T type;
};
#if __cplusplus >= 201402L
template <class T>
using remove_extent_t = typename remove_extent<T>::type;
#endif

template <class T>
struct remove_const {
  typedef T type;
};
template <class T>
struct remove_const<const T> {
  typedef T type;
};
template <class T>
struct remove_volatile {
  typedef T type;
};
template <class T>
struct remove_volatile<volatile T> {
  typedef T type;
};
template <class T>
struct remove_cv {
  typedef typename remove_volatile<typename remove_const<T>::type>::type type;
};
#if __cplusplus >= 201402L
template <class T>
using remove_cv_t = typename remove_cv<T>::type;
template <class T>
using remove_const_t = typename remove_const<T>::type;
template <class T>
using remove_volatile_t = typename remove_volatile<T>::type;
#endif

template <bool B, class T, class F>
struct conditional {
  typedef T type;
};
template <class T, class F>
struct conditional<false, T, F> {
  typedef F type;
};
#if __cplusplus >= 201402L
template <bool B, class T, class F>
using conditional_t = typename conditional<B, T, F>::type;
#endif

namespace __jitify_detail {
template <class T, bool is_function_type = false>
struct add_pointer {
  using type = typename remove_reference<T>::type*;
};
template <class T>
struct add_pointer<T, true> {
  using type = T;
};
template <class T, class... Args>
struct add_pointer<T(Args...), true> {
  using type = T (*)(Args...);
};
template <class T, class... Args>
struct add_pointer<T(Args..., ...), true> {
  using type = T (*)(Args..., ...);
};
}  // namespace __jitify_detail
template <class T>
struct add_pointer : __jitify_detail::add_pointer<T, is_function<T>::value> {};
#if __cplusplus >= 201402L
template <class T>
using add_pointer_t = typename add_pointer<T>::type;
#endif

template <class T>
struct decay {
 private:
  typedef typename remove_reference<T>::type U;

 public:
  typedef typename conditional<
      is_array<U>::value, typename remove_extent<U>::type*,
      typename conditional<is_function<U>::value, typename add_pointer<U>::type,
                           typename remove_cv<U>::type>::type>::type type;
};
#if __cplusplus >= 201402L
template <class T>
using decay_t = typename decay<T>::type;
#endif

template <class T, T v>
struct integral_constant {
  static constexpr T value = v;
  typedef T value_type;
  typedef integral_constant type;  // using injected-class-name
  constexpr operator value_type() const noexcept { return value; }
#if __cplusplus >= 201402L
  constexpr value_type operator()() const noexcept { return value; }
#endif
};
#if __cplusplus < 201703L
// Out-of-class definition so that odr-uses of `value` (e.g. binding it to a
// const reference) are valid before C++17 made such members implicitly inline.
template <class T, T v>
constexpr T integral_constant<T, v>::value;
#endif

template <class T>
struct is_lvalue_reference : false_type {};
template <class T>
struct is_lvalue_reference<T&> : true_type {};

template <class T>
struct is_rvalue_reference : false_type {};
template <class T>
struct is_rvalue_reference<T&&> : true_type {};

namespace __jitify_detail {
template <class T>
struct type_identity {
  using type = T;
};
template <class T>
auto add_lvalue_reference(int) -> type_identity<T&>;
template <class T>
auto add_lvalue_reference(...) -> type_identity<T>;
template <class T>
auto add_rvalue_reference(int) -> type_identity<T&&>;
template <class T>
auto add_rvalue_reference(...) -> type_identity<T>;
}  // namespace __jitify_detail

template <class T>
struct add_lvalue_reference
    : decltype(__jitify_detail::add_lvalue_reference<T>(0)) {};
template <class T>
struct add_rvalue_reference
    : decltype(__jitify_detail::add_rvalue_reference<T>(0)) {};
#if __cplusplus >= 201402L
template <class T>
using add_lvalue_reference_t = typename add_lvalue_reference<T>::type;
template <class T>
using add_rvalue_reference_t = typename add_rvalue_reference<T>::type;
#endif

template <typename T>
struct is_const : public false_type {};
template <typename T>
struct is_const<const T> : public true_type {};

template <typename T>
struct is_volatile : public false_type {};
template <typename T>
struct is_volatile<volatile T> : public true_type {};

template <typename T>
struct is_void : public false_type {};
template <>
struct is_void<void> : public true_type {};
template <>
struct is_void<const void> : public true_type {};
template <>
struct is_void<volatile void> : public true_type {};
template <>
struct is_void<const volatile void> : public true_type {};

template <typename T>
struct is_reference : public false_type {};
template <typename T>
struct is_reference<T&> : public true_type {};
template <typename T>
struct is_reference<T&&> : public true_type {};

template <typename _Tp,
          bool = (is_void<_Tp>::value || is_reference<_Tp>::value)>
struct __add_reference_helper {
  typedef _Tp& type;
};

template <typename _Tp>
struct __add_reference_helper<_Tp, true> {
  typedef _Tp type;
};
template <typename _Tp>
struct add_reference : public __add_reference_helper<_Tp> {};

namespace __jitify_detail {
template <typename T>
struct is_int_or_cref {
  typedef typename remove_reference<T>::type type_sans_ref;
  static const bool value =
      (is_integral<T>::value ||
       (is_integral<type_sans_ref>::value && is_const<type_sans_ref>::value &&
        !is_volatile<type_sans_ref>::value));
};
template <typename From, typename To>
struct is_convertible_sfinae {
 private:
  typedef char yes;
  typedef struct {
    char two_chars[2];
  } no;
  static inline yes test(To) { return yes(); }
  static inline no test(...) { return no(); }
  static inline typename remove_reference<From>::type& from() {
    typename remove_reference<From>::type* ptr = 0;
    return *ptr;
  }

 public:
  static const bool value = sizeof(test(from())) == sizeof(yes);
};
template <typename From, typename To>
struct is_convertible_needs_simple_test {
  static const bool from_is_void = is_void<From>::value;
  static const bool to_is_void = is_void<To>::value;
  static const bool from_is_float =
      is_floating_point<typename remove_reference<From>::type>::value;
  static const bool to_is_int_or_cref = is_int_or_cref<To>::value;
  static const bool value =
      (from_is_void || to_is_void || (from_is_float && to_is_int_or_cref));
};
template <typename From, typename To,
          bool = is_convertible_needs_simple_test<From, To>::value>
struct is_convertible {
  static const bool value = (is_void<To>::value || (is_int_or_cref<To>::value &&
                                                    !is_void<From>::value));
};
template <typename From, typename To>
struct is_convertible<From, To, false> {
  static const bool value =
      (is_convertible_sfinae<typename add_reference<From>::type, To>::value);
};
}  // namespace __jitify_detail
// Note: Implementation of is_convertible taken from Thrust's pre-C++11 path.
template <typename From, typename To>
struct is_convertible
    : public integral_constant<
          bool, __jitify_detail::is_convertible<From, To>::value> {};

template <class A, class B>
struct is_base_of {};

template <size_t len, size_t alignment>
struct aligned_storage {
  struct type {
    alignas(alignment) char data[len];
  };
};
template <class T>
struct alignment_of : integral_constant<size_t, alignof(T)> {};

template <typename T> struct make_unsigned;
template <> struct make_unsigned<signed char>        { typedef unsigned char type; };
template <> struct make_unsigned<signed short>       { typedef unsigned short type; };
template <> struct make_unsigned<signed int>         { typedef unsigned int type; };
template <> struct make_unsigned<signed long>        { typedef unsigned long type; };
template <> struct make_unsigned<signed long long>   { typedef unsigned long long type; };
template <> struct make_unsigned<unsigned char>      { typedef unsigned char type; };
template <> struct make_unsigned<unsigned short>     { typedef unsigned short type; };
template <> struct make_unsigned<unsigned int>       { typedef unsigned int type; };
template <> struct make_unsigned<unsigned long>      { typedef unsigned long type; };
template <> struct make_unsigned<unsigned long long> { typedef unsigned long long type; };
template <> struct make_unsigned<char>               { typedef unsigned char type; };
#if defined _WIN32 || defined _WIN64
template <> struct make_unsigned<wchar_t>            { typedef unsigned short type; };
#else
template <> struct make_unsigned<wchar_t>            { typedef unsigned int type; };
#endif
template <typename T> struct make_signed;
template <> struct make_signed<signed char>        { typedef signed char type; };
template <> struct make_signed<signed short>       { typedef signed short type; };
template <> struct make_signed<signed int>         { typedef signed int type; };
template <> struct make_signed<signed long>        { typedef signed long type; };
template <> struct make_signed<signed long long>   { typedef signed long long type; };
template <> struct make_signed<unsigned char>      { typedef signed char type; };
template <> struct make_signed<unsigned short>     { typedef signed short type; };
template <> struct make_signed<unsigned int>       { typedef signed int type; };
template <> struct make_signed<unsigned long>      { typedef signed long type; };
template <> struct make_signed<unsigned long long> { typedef signed long long type; };
template <> struct make_signed<char>               { typedef signed char type; };
#if defined _WIN32 || defined _WIN64
template <> struct make_signed<wchar_t>            { typedef signed short type; };
#else
template <> struct make_signed<wchar_t>            { typedef signed int type; };
#endif

}  // namespace std
#endif  // __cplusplus >= 201103L
)";

// Jitify-safe <utility>: a minimal std::pair, std::make_pair and the pair
// comparison operators (lexicographic, as in the standard).
// The default constructor value-initializes both members, as std::pair does.
// For example, pair<int, int>() yields {0, 0} rather than indeterminate
// values.
static const char* const jitsafe_header_utility = R"(
#pragma once
namespace std {
template <class T1, class T2>
struct pair {
  typedef T1 first_type;
  typedef T2 second_type;
  T1 first;
  T2 second;
  inline pair() : first(), second() {}
  inline pair(const T1& first_, const T2& second_)
      : first(first_), second(second_) {}
  template <class U1, class U2>
  inline pair(const pair<U1, U2>& p) : first(p.first), second(p.second) {}
  // TODO: Standard includes many more constructors...
};
template <class T1, class T2>
pair<T1, T2> make_pair(const T1& first, const T2& second) {
  return pair<T1, T2>(first, second);
}
template <class T1, class T2>
inline bool operator==(const pair<T1, T2>& x, const pair<T1, T2>& y) {
  return x.first == y.first && x.second == y.second;
}
template <class T1, class T2>
inline bool operator!=(const pair<T1, T2>& x, const pair<T1, T2>& y) {
  return !(x == y);
}
template <class T1, class T2>
inline bool operator<(const pair<T1, T2>& x, const pair<T1, T2>& y) {
  return x.first < y.first || (!(y.first < x.first) && x.second < y.second);
}
template <class T1, class T2>
inline bool operator>(const pair<T1, T2>& x, const pair<T1, T2>& y) {
  return y < x;
}
template <class T1, class T2>
inline bool operator<=(const pair<T1, T2>& x, const pair<T1, T2>& y) {
  return !(y < x);
}
template <class T1, class T2>
inline bool operator>=(const pair<T1, T2>& x, const pair<T1, T2>& y) {
  return !(x < y);
}
}  // namespace std
)";

// Jitify-safe <vector>: an empty std::vector placeholder (Allocator defaults
// to void), enough for code that only names the type.
static const char* const jitsafe_header_vector = R"(
#pragma once
namespace std {
template <class T, class Allocator = void>  // = std::allocator>
struct vector {};
}  // namespace std
)";

// Jitify-safe legacy <memory.h>: simply forwards to <string.h>.
static const char* const jitsafe_header_memory_h = R"(
#pragma once
#include <string.h>
)";

// WAR: These need to be pre-added as a workaround for NVRTC implicitly using
// /usr/include as an include path. The other built-in headers will be included
// lazily as needed.
//
// Returns the names of the jitify-safe headers that must be supplied up front.
// The set is a function-local static built on first call and owned by value,
// so the returned reference stays valid for the life of the program.
static const std::unordered_set<std::string>& get_workaround_system_headers() {
  static const std::unordered_set<std::string> workaround_system_header_names =
      {
          "assert.h", "limits.h", "math.h", "stdint.h", "stdio.h",
          "stdlib.h", "string.h", "time.h", "memory.h",
      };
  return workaround_system_header_names;
}

// Returns the table of built-in JIT-safe headers, keyed by the name used in
// #include directives and mapping to the header's source text. The table is a
// function-local static that is built on first use. load_header_impl() uses
// it as its last lookup fallback, and preprocess() uses it for the eagerly
// added workaround headers and "jitify_preinclude.h".
static const StringMap& get_jitsafe_headers_map() {
  static const StringMap jitsafe_headers_map = {
      {"jitify_preinclude.h", jitsafe_header_preinclude_h},
      {"assert.h", jitsafe_header_assert_h},
      {"cassert", jitsafe_header_cassert},
      {"float.h", jitsafe_header_float_h},
      {"cfloat", jitsafe_header_cfloat},
      {"limits.h", jitsafe_header_limits_h},
      {"climits", jitsafe_header_climits},
      {"math.h", jitsafe_header_math_h},
      {"cmath", jitsafe_header_cmath},
      {"stddef.h", jitsafe_header_stddef_h},
      {"cstddef", jitsafe_header_cstddef},
      {"stdint.h", jitsafe_header_stdint_h},
      {"cstdint", jitsafe_header_cstdint},
      {"stdio.h", jitsafe_header_stdio_h},
      {"cstdio", jitsafe_header_cstdio},
      {"stdlib.h", jitsafe_header_stdlib_h},
      {"cstdlib", jitsafe_header_cstdlib},
      {"string.h", jitsafe_header_string_h},
      {"cstring", jitsafe_header_cstring},
      {"time.h", jitsafe_header_time_h},
      {"ctime", jitsafe_header_ctime},
      {"algorithm", jitsafe_header_algorithm},
      {"array", jitsafe_header_array},
      {"complex", jitsafe_header_complex},
      {"initializer_list", jitsafe_header_initializer_list},
      {"iostream", jitsafe_header_iostream},
      {"istream", jitsafe_header_istream},
      {"iterator", jitsafe_header_iterator},
      {"limits", jitsafe_header_limits},
      {"mutex", jitsafe_header_mutex},
      {"ostream", jitsafe_header_ostream},
      {"sstream", jitsafe_header_sstream},
      {"stdexcept", jitsafe_header_stdexcept},
      {"string", jitsafe_header_string},
      {"tuple", jitsafe_header_tuple},
      {"utility", jitsafe_header_utility},
      {"type_traits", jitsafe_header_type_traits},
      {"vector", jitsafe_header_vector},
      {"memory.h", jitsafe_header_memory_h},
  };
  return jitsafe_headers_map;
}

// Parses a compiler log for a "could not open source file" / "cannot open
// source file" error. The expected error line format is
//   <parent>(<line_num>): <message> "<name>"
// On a match, this function sets the outputs and returns true:
//   *name     - the quoted header name (up to the end of the log if the
//               closing quote is missing);
//   *parent   - the file containing the failing #include;
//   *line_num - the line of that #include.
// If the error line has no "<parent>(<line_num>)" prefix before the message,
// *parent is set to "" and *line_num to 0. In that case no parenthesis from
// the message text or from later log lines is used.
// Returns false if the log contains no such error.
inline bool extract_include_info_from_compile_error(const std::string& log,
                                                    std::string* name,
                                                    std::string* parent,
                                                    int* line_num) {
  static const std::string kPatterns[] = {"could not open source file \"",
                                          "cannot open source file \""};
  for (const std::string& pattern : kPatterns) {
    const size_t pattern_pos = log.find(pattern);
    if (pattern_pos == std::string::npos) continue;

    // The header name runs from after the pattern up to the closing quote.
    const size_t name_beg = pattern_pos + pattern.size();
    const size_t name_end = log.find('"', name_beg);
    *name = log.substr(name_beg, name_end == std::string::npos
                                     ? std::string::npos
                                     : name_end - name_beg);

    // Locate the start of the line containing the error.
    size_t line_beg = log.rfind('\n', pattern_pos);
    line_beg = (line_beg == std::string::npos) ? 0 : line_beg + 1;

    // The "(<line_num>)" annotation must precede the message on this line.
    const size_t open_paren = log.find('(', line_beg);
    if (open_paren == std::string::npos || open_paren >= pattern_pos) {
      parent->clear();
      *line_num = 0;
      return true;
    }
    *parent = log.substr(line_beg, open_paren - line_beg);
    const size_t num_beg = open_paren + 1;
    const size_t close_paren = log.find(')', num_beg);
    *line_num = std::atoi(log.substr(num_beg, close_paren == std::string::npos
                                                  ? std::string::npos
                                                  : close_paren - num_beg)
                              .c_str());
    return true;
  }
  return false;
}

// Returns the offset of the beginning of the specified line, taking into
// account any #line directives.
// `line_num` is 1-based. Returns std::string::npos if the source ends before
// the requested line is reached.
// TODO: It's not clear what this should do when there is a #line directive
// that skips lines (e.g., line_num = 2 and there is a '#line 1' several lines
// into the source, resulting in an ambiguity).
inline size_t find_source_line(StringRef source, int line_num) {
  // TODO: This is not robust to `#line` inside comments, strings etc.
  size_t beg = 0;
  // HACK: This is a WAR for the ambiguity introduced by jitify's include guard
  // that is 2 lines followed by '#line 1', when line_num <= 3.
  // Counting starts just past the 8-character "#line 1\n" that
  // replace_pragma_once_with_ifndef inserts after the guard.
  // NOTE(review): if "#line 1\n" were absent, npos + 8 would wrap to 7.
  if (startswith(source, "#ifndef JITIFY_INCLUDE_GUARD_")) {
    beg = source.find("#line 1\n") + 8;
  }
  for (int i = 1; i < line_num; ++i) {
    // Jump to the next newline or '#', whichever comes first.
    beg = source.find_first_of("\n#", beg);
    if (beg == std::string::npos) return beg;
    if (source[beg] == '#' && source.substr(beg, 5) == "#line" &&
        std::isspace((unsigned char)source[beg + 5])) {
      // Found a #line directive, parse it and reset the line numbering.
      beg += 5;
      // Skip the whitespace after "#line" (the first char is known to be
      // whitespace from the check above).
      while (std::isspace((unsigned char)source[++beg]))
        ;
      size_t num_beg = beg;
      // Skip the remaining digits of the line number.
      while (std::isdigit((unsigned char)source[++beg]))
        ;
      size_t num_end = beg;
      int num = std::atoi(
          std::string(source.substr(num_beg, num_end - num_beg)).c_str());
      // After the loop increment, i == num while beg points at the start of
      // the line following the directive, which #line numbers as `num`.
      i = num - 1;
      beg = source.find_first_of("\n", beg);
      if (beg == std::string::npos) return beg;
    } else if (source[beg] == '#') {
      // This was just some other # token, don't count it as a new line.
      --i;
    }
    // Step past the newline (or '#') that was just found.
    ++beg;
  }
  return beg;
}

inline bool is_include_directive_with_quotes(StringRef source, int line_num,
                                             std::string* error = nullptr) {
  // Reports whether the #include directive on (1-based) line `line_num` uses
  // quotes ("foo.h") rather than angle brackets (<foo.h>). On a parse failure,
  // returns false and, if `error` is non-null, stores a description there.
  // TODO: This implementation does not handle things like
  // "#define INC <foo>\n #include INC", which Thrust does in some headers.
  auto fail = [error](const char* message) -> bool {
    if (error) *error = message;
    return false;
  };
  const size_t line_start = find_source_line(source, line_num);
  if (line_start == std::string::npos) {
    return fail("EOF reached before source line was found");
  }
  // TODO: This is not robust to inline comments, strings etc.
  const size_t keyword_pos = source.find("include", line_start);
  if (keyword_pos == std::string::npos) {
    return fail("Line does not contain 'include'");
  }
  // Search for the opening delimiter after the 7-character "include".
  const size_t delim_pos = source.find_first_of("\"<", keyword_pos + 7);
  if (delim_pos == std::string::npos) {
    return fail("Did not find expected '\"' or '<' character");
  }
  return source[delim_pos] == '"';
}

// Normalizes a '/'-separated path:
//   - repeated slashes are collapsed;
//   - "." components are dropped;
//   - each ".." removes the preceding (non-"..") component.
// A leading '/' (absolute path) and a trailing '/' are preserved. A relative
// path that simplifies to nothing becomes ".", and the root stays "/".
// Returns an empty string if `path` is empty or ill-formed, i.e. an absolute
// path whose ".." components would climb above the root.
// Note: backslash separators and drive letters are not given special meaning.
inline std::string path_simplify(StringRef path) {
  const size_t n = path.size();
  if (n == 0) return {};
  const bool is_absolute = path[0] == '/';
  const bool has_trailing_slash = path[n - 1] == '/';
  StringVec parts;  // Simplified components (the root is tracked separately).
  std::string component;
  // Iterate one past the end so the final component is processed exactly like
  // the components that are followed by a slash.
  for (size_t i = 0; i <= n; ++i) {
    if (i < n && path[i] != '/') {
      component.push_back(path[i]);
      continue;
    }
    if (component.empty() || component == ".") {
      // Empty (repeated/leading/trailing slash) or "." component: drop it.
    } else if (component == "..") {
      if (!parts.empty() && parts.back() != "..") {
        parts.pop_back();
      } else if (is_absolute) {
        return {};  // Bad path: back-traversals exceed depth of absolute path
      } else {
        parts.push_back(component);  // Leading ".." of a relative path.
      }
    } else {
      parts.push_back(component);
    }
    component.clear();
  }
  if (parts.empty()) {
    return is_absolute ? "/" : ".";
  }
  std::string result = is_absolute ? "/" : "";
  for (size_t i = 0; i < parts.size(); ++i) {
    if (i) result += '/';
    result += parts[i];
  }
  if (has_trailing_slash) result += '/';
  return result;
}

// Reads the whole file at `fullpath` (opened in text mode) into *content.
// Returns false, leaving *content untouched, if the file cannot be opened.
inline bool read_text_file(const std::string& fullpath, std::string* content) {
  std::ifstream input_file(fullpath.c_str());
  if (input_file.fail()) {
    return false;
  }
  std::ostringstream contents_stream;
  contents_stream << input_file.rdbuf();
  content->assign(contents_stream.str());
  return true;
}

// Pseudo-directory names that load_header_impl() joins with a header name to
// form the reported "full path" of headers that do not come from disk:
// built-in JIT-safe headers and headers supplied by a user header callback.
static const char* const kJitifyBuiltinHeaderPrefix = "__jitify_builtin";
static const char* const kJitifyCallbackHeaderPrefix = "__jitify_callback";

// Searches for the specified header and loads its contents into *source and its
// full path into *fullpath. Returns false if not found.
inline bool load_header_impl(const std::string& filename,
                             const StringVec& include_paths,
                             StringRef current_dir, bool search_current_dir,
                             bool search_builtin_headers,
                             FileCallback header_callback, std::string* source,
                             std::string* fullpath) {
  // Try loading from header callback.
  if (header_callback) {
    *fullpath = path_join(kJitifyCallbackHeaderPrefix, filename);
    if (header_callback(filename, source)) return true;
  }
  // Try loading from filesystem.
  if (search_current_dir) {
    *fullpath = path_join(current_dir, filename);
    if (read_text_file(*fullpath, source)) return true;
  }
  // Search include directories.
  for (const std::string& include_path : include_paths) {
    *fullpath = path_join(include_path, filename);
    if (read_text_file(*fullpath, source)) return true;
  }
  // Try loading from builtin headers.
  if (search_builtin_headers) {
    *fullpath = path_join(kJitifyBuiltinHeaderPrefix, filename);
    auto iter = get_jitsafe_headers_map().find(filename);
    if (iter != get_jitsafe_headers_map().end()) {
      *source = iter->second;
      return true;
    }
  }
  return false;
}

// Result of load_header().
enum class HeaderLoadStatus {
  FAILED = 0,          // The header was not found by any lookup.
  ALREADY_LOADED = 1,  // The sources map already had an entry for the name.
  NEWLY_LOADED = 2,    // The header was found and added to the sources map.
};

// Searches for the specified header via load_header_impl() and adds its
// contents to *sources and its simplified full path to *fullpaths (if
// provided), both keyed by `filename`. The full path may be empty if
// path_simplify() deems it ill-formed. Returns:
//   HeaderLoadStatus::ALREADY_LOADED if *sources already contains `filename`
//     (nothing is searched or modified);
//   HeaderLoadStatus::FAILED if the header could not be found;
//   HeaderLoadStatus::NEWLY_LOADED if it was found and added.
inline HeaderLoadStatus load_header(
    const std::string& filename, const StringVec& include_paths,
    StringRef current_dir, bool search_current_dir, bool search_builtin_headers,
    FileCallback header_callback, StringMap* sources, StringMap* fullpaths) {
  if (sources->count(filename)) {
    return HeaderLoadStatus::ALREADY_LOADED;
  }
  std::string source, fullpath;
  if (!load_header_impl(filename, include_paths, current_dir,
                        search_current_dir, search_builtin_headers,
                        header_callback, &source, &fullpath)) {
    return HeaderLoadStatus::FAILED;
  }
  // Move rather than copy: header sources can be large.
  sources->emplace(filename, std::move(source));
  if (fullpaths) {
    // Record the full file path corresponding to this include name.
    fullpaths->emplace(filename, path_simplify(fullpath));
  }
  return HeaderLoadStatus::NEWLY_LOADED;
}

// Rewrites every std-qualified name (std::, ::std::, cuda::std:: and
// ::cuda::std::) to the fully-qualified ::cuda::std:: form. The jit-safe
// libcudacxx implementations are then used instead of the unsafe standard
// ones. Only qualified names are rewritten; "namespace std {" blocks are not.
inline std::string replace_std_with_cuda_std(std::string source) {
  // Names already qualified with cuda::std are matched too, so every spelling
  // normalizes to the same ::cuda::std:: prefix.
  static const std::regex kStdQualifier(
      R"(::cuda::std::|\bcuda::std::|::std::|\bstd::)", std::regex::optimize);
  // TODO: Rewriting "namespace std {" into "namespace cuda::std {" (e.g. with
  // the regex R"(\bnamespace\s+std\s*\{)") isn't safe, because the code might
  // already be ns cuda { ns std { } }.
  return std::regex_replace(source, kStdQualifier, "::cuda::std::");
}

// Helper class for basic lexing of C++ source code. It scans a NUL-terminated
// character buffer that it does not own. The span-scanning methods advance
// the cursor past the span they recognize and return a pointer to its end.
// If a span is unterminated, scanning stops at the terminating NUL.
class CppLexer {
  const char* current_;

  bool isspace(char c) const {
    return std::isspace(static_cast<unsigned char>(c));
  }

 public:
  CppLexer(const char* str) : current_(str) {}
  const char* current() const { return current_; }
  char advance() { return *current_++; }
  void skip(int n) { current_ += n; }
  char peek(int i = 0) const { return *(current_ + i); }
  // Consumes `c` if it is the next character.
  bool match(char c) { return peek() == c ? advance() : false; }
  // Consumes `s` if the input continues with it.
  bool match(const char* s) {
    int i;
    for (i = 0; s[i]; ++i) {
      if (!peek(i) || peek(i) != s[i]) return false;
    }
    current_ += i;
    return true;
  }
  // Consumes a single whitespace character, or the backslash of a line
  // continuation.
  bool match_whitespace() {
    // Includes line continuations.
    return (isspace(peek()) || (peek() == '\\' && peek(1) == '\n')) ? advance()
                                                                    : false;
  }
  // Consumes a run of whitespace (including line continuations).
  const char* whitespace() {
    while (match_whitespace()) {
    }
    return current_;
  }
  // Consumes characters up to and including `delim`. A backslash escapes the
  // character that follows it: neither \" nor \' ends a literal, but the final
  // quote in "\\" (an escaped backslash) does.
  const char* escapable_char_delimited_span(char delim) {
    while (peek() && peek() != delim) {
      // Consume an escaped character together with its backslash, but never
      // step over the terminating NUL.
      if (peek() == '\\' && peek(1)) advance();
      advance();
    }
    if (peek() == delim) {
      skip(1);
    } else {
      // Error, unexpected end of string.
    }
    return current_;
  }
  // Consumes the rest of the line, including its terminating newline. A
  // backslash directly before a newline is a line splice, so the line
  // continues past it. Excludes the ending newline char from the returned end.
  // If the input ends without a newline, returns the end of the input.
  const char* line() {
    while (peek() && (peek() != '\n' || peek(-1) == '\\')) advance();
    if (!peek()) return current_;  // Unterminated final line.
    skip(1);
    return current_ - 1;
  }
  // These all include the ending delimiter chars.
  const char* string_literal() { return escapable_char_delimited_span('"'); }
  const char* char_literal() { return escapable_char_delimited_span('\''); }
  // Consumes characters up to and including the `delim_size`-character
  // sequence `delim` (no escapes).
  const char* delimited_span(const char* delim, int delim_size) {
    auto peek_equals_delimiter = [&] {
      for (int i = 0; i < delim_size; ++i) {
        if (peek(i) != delim[i]) return false;
      }
      return true;
    };
    while (peek() && !peek_equals_delimiter()) advance();
    if (peek() == delim[0]) {
      skip(delim_size);
    } else {
      // Error, unexpected end of string.
    }
    return current_;
  }
  const char* block_comment() { return delimited_span("*/", 2); }
  // Expects the cursor just past R" and consumes the raw string literal up to
  // and including its closing )<d-char-sequence>".
  const char* raw_string_literal() {
    const char* delim_beg = current_;
    while (peek() && peek() != '(') advance();
    std::string delim;
    delim.reserve(current_ - delim_beg + 2);
    delim += ')';
    delim.append(delim_beg, current_);
    delim += '"';
    return delimited_span(delim.c_str(), (int)delim.size());
  }
};

// Finds the first "#pragma once" directive in `source` that is not inside a
// string/char/raw-string literal or a comment. On success, stores the offset
// of its '#' in *begin_ptr and the offset just past "once" in *end_ptr, then
// returns true. Returns false if there is no such directive.
inline bool find_pragma_once(const std::string& source, size_t* begin_ptr,
                             size_t* end_ptr) {
  const char* const base = source.c_str();
  size_t pos = 0;
  // Stop at characters that may start a literal ('"', '\'', R"), a comment
  // ('/'), or a preprocessor directive ('#'); skip over literals and comments
  // as whole tokens so their contents are never mistaken for a directive.
  while ((pos = source.find_first_of("\"'R/#", pos)) != std::string::npos) {
    const char* const token = base + pos;
    CppLexer lexer(token);
    const char* token_end = nullptr;
    bool found = false;
    switch (lexer.advance()) {
      case '"':
        token_end = lexer.string_literal();
        break;
      case '\'':
        token_end = lexer.char_literal();
        break;
      case 'R':
        token_end =
            lexer.match('"') ? lexer.raw_string_literal() : lexer.current();
        break;
      case '/':
        if (lexer.match('/')) {
          token_end = lexer.line();
        } else if (lexer.match('*')) {
          token_end = lexer.block_comment();
        } else {
          token_end = lexer.current();
        }
        break;
      case '#':
        // "pragma", at least one whitespace char, then "once".
        found = lexer.match("pragma") && lexer.match_whitespace();
        if (found) {
          lexer.whitespace();
          found = lexer.match("once");
        }
        token_end = lexer.current();
        break;
      default:  // Should never be reached
        token_end = lexer.current();
        break;
    }
    if (found) {
      *begin_ptr = pos;
      *end_ptr = token_end - base;
      return true;
    }
    pos = token_end - base;
  }
  return false;
}

// Returns a copy of `source` with comments and backslash-newline line
// continuations removed:
//   - String, char and raw string literals are copied through unchanged, so
//     comment-like text inside them is preserved.
//   - A // comment is removed up to, but not including, its terminating
//     newline.
//   - A /* */ comment is replaced by a single space, as the C++ translation
//     phases specify. This keeps the tokens on either side of it separate
//     (e.g. `int/**/x` must not become `intx`).
inline std::string remove_cpp_comments_and_line_continuations(
    const std::string& source) {
  std::string result;
  result.reserve(source.size());
  size_t old_pos = 0, pos;
  // Match string literals, comments (forward slashes), and line continuations
  // (backslashes).
  const char* match_chars = "\"'R/\\";
  while ((pos = source.find_first_of(match_chars, old_pos)) !=
         std::string::npos) {
    // Copy the text preceding the matched character verbatim.
    result.append(source, old_pos, pos - old_pos);
    const char* beg = source.c_str() + pos;
    CppLexer lexer(beg);
    const char* end = [&] {
      // clang-format off
      switch (lexer.advance()) {
        case '"':  return lexer.string_literal();
        case '\'': return lexer.char_literal();
        case 'R':  return lexer.match('"') ? lexer.raw_string_literal() :
                          lexer.current();
        case '/':  return lexer.match('/') ? lexer.line() :
                          lexer.match('*') ? lexer.block_comment() :
                          lexer.current();
        // Match line continuation (escaped newline).
        // TODO: Line continuations inside string literals will not be matched
        // here. Would need to use a separate pass that only matches them and
        // raw strings.
        case '\\': return lexer.match('\n'), lexer.current();
        default:   return lexer.current(); // Should never be reached
      }
      // clang-format on
    }();
    old_pos = end - source.c_str();
    if (end - beg == 1 || *beg == '"' || *beg == '\'' || *beg == 'R') {
      // Keep single characters ('/') and string literals.
      result.append(beg, end);
    } else if (*beg == '/' && beg[1] == '*') {
      // Replace the block comment with one space so adjacent tokens stay
      // separate.
      result += ' ';
    } else {
      // Elide line comments and line continuations.
    }
  }
  result.append(source, old_pos);
  return result;
}

// This removes most but not all whitespace. Remaining whitespace is tricky to
// handle safely+efficiently.
// Behavior, as implemented below:
//  - String, char and raw string literals are copied verbatim.
//  - Whitespace that follows one of the characters . , ; ! | ~ ^ ( ) [ ] { }
//    is dropped.
//  - Any other whitespace run becomes a single space. It is dropped instead
//    when the output is empty or already ends with a newline, unless a
//    directive follows.
//  - A '#' starts a preprocessor directive. The directive ends at the first
//    whitespace run containing a newline, and that run is emitted as '\n'.
//  - A whitespace run (possibly after one of the characters above) that is
//    immediately followed by '#' outside a directive is emitted as '\n'. The
//    directive therefore starts on its own line.
// Comments are not recognized here (there is no '/' handling).
// minify_cpp_source strips them before calling this.
inline std::string remove_cpp_whitespace(const std::string& source) {
  std::string result;
  result.reserve(source.size());
  size_t old_pos = 0, pos;
  // Match string literals, preprocessor directives, whitespace, and chars that
  // can safely have whitespace after them removed.
  bool inside_directive = false;
  const char* match_chars = "\"'R# \f\n\r\t\v.,;!|~^()[]{}";
  while ((pos = source.find_first_of(match_chars, old_pos)) !=
         std::string::npos) {
    // Copy the text preceding the matched character verbatim.
    result.append(source, old_pos, pos - old_pos);
    const char* beg = source.c_str() + pos;
    CppLexer lexer(beg);
    bool end_of_directive = false;
    bool is_whitespace = false;
    const char* end = [&] {
      // clang-format off
      char c = lexer.advance();
      switch (c) {
        case '"':  return lexer.string_literal();
        case '\'': return lexer.char_literal();
        case 'R':  return lexer.match('"') ? lexer.raw_string_literal() :
                          lexer.current();
        case '#':  return inside_directive = true, lexer.current();
        // Whitespace or one of the punctuation chars: also consume the
        // whitespace run that follows it.
        default:   return is_whitespace = true, lexer.whitespace();
      }
      // clang-format on
    }();
    // A newline inside the consumed run terminates the current directive.
    if (inside_directive && is_whitespace && std::find(beg, end, '\n') != end) {
      inside_directive = false;
      end_of_directive = true;
    }
    old_pos = end - source.c_str();
    if ((end - beg == 1 && !std::isspace((unsigned char)*beg)) || *beg == '"' ||
        *beg == '\'' || *beg == 'R' || *beg == '#') {
      // Keep single characters ('R'), string literals, and preprocessor
      // directives.
      result.append(beg, end);
    } else {
      // Elide or replace whitespace.
      // True if a new directive starts right after this whitespace run.
      bool before_directive = !inside_directive && *end == '#';
      if (!std::isspace((unsigned char)*beg)) {
        // Remove whitespace after symbol.
        result += *beg;
        if (end_of_directive || before_directive) {
          result += '\n';
        }
      } else {
        if (end_of_directive) {
          result += '\n';
        } else {
          // A newline may already be present from a preprocessor directive.
          bool after_newline = result.empty() || result.back() == '\n';
          if (!after_newline || before_directive) {
            // Replace whitespace.
            result += before_directive ? '\n' : ' ';
          }
        }
      }
    }
  }
  result.append(source, old_pos);
  return result;
}

// WAR for #pragma once not working when there are multiple inclusions of the
// same header from different paths.
// Replaces the first "#pragma once" in `source` with an #ifndef/#define
// include guard. The guard is named from the SHA-256 of the original source,
// so identical contents share a guard whatever their path. The guard is
// followed by "#line 1" so that line numbering still matches the original
// file. Sources that already start with a Jitify guard, or that contain no
// "#pragma once", are returned unchanged.
inline std::string replace_pragma_once_with_ifndef(const std::string& source) {
  constexpr const char* const kJitifyIncludeGuardPrefix =
      "JITIFY_INCLUDE_GUARD_";
  const std::string existing_guard_start =
      std::string("#ifndef ") + kJitifyIncludeGuardPrefix;
  if (startswith(source, existing_guard_start)) {
    return source;  // Already been processed
  }
  size_t pragma_begin, pragma_end;
  if (!find_pragma_once(source, &pragma_begin, &pragma_end)) {
    return source;
  }
  // Guard macro name plus its terminating newline.
  const std::string guard_line =
      string_concat(kJitifyIncludeGuardPrefix, sha256(source), "\n");
  std::string result;
  result.reserve(source.size() + 3 * guard_line.size() + 32);
  result += "#ifndef ";
  result += guard_line;
  result += "#define ";
  result += guard_line;
  result += "#line 1\n";
  // Everything except the "#pragma once" directive itself.
  result.append(source, 0, pragma_begin);
  result.append(source, pragma_end, std::string::npos);
  result += "\n#endif  // ";
  result += guard_line;
  return result;
}

// Applies Jitify's source-level transformations to one program or header
// source and returns the result:
//  - if use_cuda_std, std:: qualifiers are rewritten via
//    replace_std_with_cuda_std;
//  - if replace_pragma_once, "#pragma once" is turned into an include guard
//    via replace_pragma_once_with_ifndef;
//  - a few hard-coded workarounds for CUB/Thrust sources (below) are applied.
//    Each one patches every occurrence of its pattern in place.
inline std::string patch_cuda_source(std::string source, bool use_cuda_std,
                                     bool replace_pragma_once) {
  if (use_cuda_std) {
    source = detail::replace_std_with_cuda_std(std::move(source));
  }
  if (replace_pragma_once) {
    source = detail::replace_pragma_once_with_ifndef(std::move(source));
  }
  // Overwrites the two characters at `offset` within every occurrence of
  // `pattern` with the two characters of `replacement`.
  auto patch_each = [&source](const std::string& pattern, size_t offset,
                              const char* replacement) {
    for (size_t pos = source.find(pattern); pos != std::string::npos;
         pos = source.find(pattern, pos + pattern.size())) {
      source[pos + offset] = replacement[0];
      source[pos + offset + 1] = replacement[1];
    }
  };
  // HACK This is a WAR for some CUB sources including a header they shouldn't.
  // Turning the leading "#i" into "//" comments out the line.
  patch_each("#include \"../util_device.cuh\"", 0, "//");
  // HACK This is a WAR for Thrust (pre-CUDA-11) using "#define A #pragma B".
  // Turning "#p" into "//" comments out the (rest of the) line.
  patch_each("#pragma nv_exec_check_disable", 0, "//");
  // HACK This is a WAR for Thrust using
  // "__has_cpp_attribute(gnu::warn_unused_result)". The "::" at offset 23 is
  // replaced with "__", so the scoped attribute name becomes a plain
  // identifier (presumably because the compiler could not handle the scoped
  // name there).
  patch_each("__has_cpp_attribute(gnu::warn_unused_result)", 23, "__");
  return source;
}

// Removes comments and most whitespace from C++ source code.
// Comments (and line continuations) are stripped first because
// remove_cpp_whitespace does not recognize comments itself.
inline std::string minify_cpp_source(const std::string& source) {
  const std::string without_comments =
      remove_cpp_comments_and_line_continuations(source);
  return remove_cpp_whitespace(without_comments);
}

// Moves every "-I<dir>" option out of *options and appends <dir> to
// *include_paths. The relative order of both the remaining options and the
// extracted paths is preserved. Only the attached form is recognized: a bare
// "-I" yields an empty path, and any argument after it stays in *options.
inline void extract_include_paths(StringVec* options,
                                  StringVec* include_paths) {
  StringVec kept_options;
  for (std::string& option : *options) {
    const bool is_include_flag = option.compare(0, 2, "-I") == 0;
    if (is_include_flag) {
      include_paths->push_back(option.substr(2));
    } else {
      kept_options.push_back(std::move(option));
    }
  }
  *options = std::move(kept_options);
}

}  // namespace detail

// Preprocesses a CUDA program for later compilation. It:
//  - consumes Jitify-specific flags from compiler_options: the
//    system-headers and pre-include workarounds, -m/--minify, -cuda-std, the
//    pragma-once and builtin-header toggles, and -remove-unused-globals
//    (which is re-added at the end);
//  - moves -I<dir> options into a separate include-path list;
//  - patches, and optionally minifies, the program and header sources;
//  - compiles the program once per explicitly specified architecture (or
//    once with the default). Each time the compiler reports a missing header,
//    that header is loaded and compilation is retried.
// Returns an Error in any of these cases:
//  - an architecture flag is invalid;
//  - a compilation error is not include-related;
//  - a header cannot be found;
//  - a header that is already loaded is still reported as missing.
inline PreprocessedProgram PreprocessedProgram::preprocess(
    std::string name, std::string source, StringMap header_sources,
    StringVec compiler_options, StringVec linker_options,
    FileCallback header_callback) {
  // Add pre-include built-in JIT-safe headers.
  bool use_system_headers_war =
      !detail::pop_flag(&compiler_options, "-no-system-headers-workaround",
                        "--no-system-headers-workaround");
#if CUDA_VERSION >= 11000
  // This issue with /usr/include always being searched is fixed in this NVRTC.
  use_system_headers_war = false;
#endif
  if (use_system_headers_war) {
    // Workaround for /usr/include always being searched by NVRTC.
    for (const std::string& header_name :
         detail::get_workaround_system_headers()) {
      const std::string& header_source =
          detail::get_jitsafe_headers_map().at(header_name);
      header_sources.emplace(header_name, header_source);
    }
  }
  if (!detail::pop_flag(&compiler_options, "-no-preinclude-workarounds",
                        "--no-preinclude-workarounds")) {
    header_sources.emplace(
        "jitify_preinclude.h",
        detail::get_jitsafe_headers_map().at("jitify_preinclude.h"));
    compiler_options.push_back("-include=jitify_preinclude.h");
  }
  detail::add_std_flag_if_not_specified(&compiler_options, "c++11");
  detail::add_default_device_flag_if_not_specified(&compiler_options);
  bool minify = detail::pop_flag(&compiler_options, "-m", "--minify");
  // TODO: This flag is experimental, because the implementation does not
  // support transformations of "namespace std {" (as used for specializations).
  bool use_cuda_std =
      detail::pop_flag(&compiler_options, "-cuda-std", "--cuda-std");
  bool replace_pragma_once = !detail::pop_flag(
      &compiler_options, "-no-replace-pragma-once", "--no-replace-pragma-once");
  bool use_builtin_headers = !detail::pop_flag(
      &compiler_options, "-no-builtin-headers", "--no-builtin-headers");

  // This is re-added to the remaining options below.
  bool should_remove_unused_globals = detail::pop_flag(
      &compiler_options, "-remove-unused-globals", "--remove-unused-globals");

  // Patch all given sources (moved in to avoid copying them).
  source = detail::patch_cuda_source(std::move(source), use_cuda_std,
                                     replace_pragma_once);
  for (auto& name_source : header_sources) {
    const std::string& header_name = name_source.first;
    std::string& header_source = name_source.second;
    bool is_jitify_preinclude = header_name == "jitify_preinclude.h";
    bool is_cuda_std_header =
        detail::get_workaround_system_headers().count(header_name);
    header_source = detail::patch_cuda_source(
        std::move(header_source),
        use_cuda_std && !is_jitify_preinclude && !is_cuda_std_header,
        replace_pragma_once);
  }

  if (minify) {
    source = detail::minify_cpp_source(source);
    for (auto& name_source : header_sources) {
      std::string* header_source = &name_source.second;
      *header_source = detail::minify_cpp_source(*header_source);
    }
  }

  // Temporarily add the program source to header_sources for easier processing.
  header_sources.emplace(name, source);

  StringVec include_paths;
  detail::extract_include_paths(&compiler_options, &include_paths);
  std::string include_paths_msg =
      detail::string_join(include_paths, "\n", "Include paths:\n", "\n");

  if (!nvrtc()) return Error(nvrtc().error());
  // Parse architecture flags for special handling. If specified here, the arch
  // must be explicit (no auto-detection), and it will not be passed through to
  // the compile phase.
  // We don't automatically add -arch here because this may be run on a
  // different system to the one that performs the final program compilation.
  // (Users can still manually specify an architecture here if needed).
  // This also avoids needing a dependency on libcuda in this function.
  struct ArchFlag {
    int cc;
    bool is_virtual;
    explicit operator std::string() const {
      return std::string("-arch=") + (is_virtual ? "compute_" : "sm_") +
             std::to_string(cc);
    }
    bool operator==(const ArchFlag& other) const {
      return cc == other.cc && is_virtual == other.is_virtual;
    }
    size_t hash() const { return detail::fasthash64(cc) ^ (is_virtual * ~0); }
    struct Hash {
      size_t operator()(const ArchFlag& x) const { return x.hash(); }
    };
  };
  // Extract all architecture flags from compiler_options.
  std::unordered_set<ArchFlag, ArchFlag::Hash> arch_flags;
  while (true) {
    std::string error;
    size_t beg_idx, end_idx;
    bool is_virtual = false;
    int given_cc = detail::parse_arch_flag(compiler_options, &is_virtual,
                                           &error, &beg_idx, &end_idx);
    if (!error.empty()) {
      return Error("Failed to parse architecture flag: " + error);
    }
    if (given_cc == -1) {
      return Error(
          "Architecture flags passed to preprocess() must be explicit.");
    }
    if (!given_cc) break;
    if (!nvrtc().GetCUBIN() && !is_virtual) {
      // This version of NVRTC does not support direct-to-CUBIN compilation.
      // Convert real arch flags to virtual arch to avoid error from NVRTC.
      given_cc =
          detail::limit_to_supported_compute_capability(given_cc, &error);
      if (!given_cc) {
        return Error("Failed to get supported compute capability: " + error);
      }
      is_virtual = true;
    }
    arch_flags.insert({given_cc, is_virtual});
    // Remove the parsed arch flag entries; they are replaced below.
    compiler_options.erase(compiler_options.begin() + beg_idx,
                           compiler_options.begin() + end_idx);
  }
  if (arch_flags.empty()) {
    // Push a placeholder entry so that preprocessing still runs (with the
    // default arch) when none was specified by the user.
    arch_flags.insert({0, false});
  }
  // Maps header include names to their full file paths.
  StringMap header_fullpaths;
  std::string compile_log, header_log;
  // Repeat preprocessing for each specified architecture.
  for (const ArchFlag& arch_flag : arch_flags) {
    if (arch_flag.cc) {
      // Temporarily add this arch flag.
      compiler_options.push_back(static_cast<std::string>(arch_flag));
    }

    std::string compiler_options_msg = detail::string_join(
        compiler_options, " ", "Compiler options: \"", "\"\n");
    std::string compile_error;
    while (!detail::compile_program(name, source, header_sources,
                                    compiler_options, &compile_error,
                                    &compile_log)) {
      std::string include_name, include_parent;
      int line_num = 0;
      if (!detail::extract_include_info_from_compile_error(
              compile_log, &include_name, &include_parent, &line_num)) {
        // There was a non include-related compilation error.
        return Error("Compilation failed: " + compile_error + "\n" +
                     compiler_options_msg + compile_log);
      }

      bool is_included_with_quotes = false;
      if (header_sources.count(include_parent)) {
        const std::string& parent_source = header_sources.at(include_parent);
        std::string parse_error;
        is_included_with_quotes = detail::is_include_directive_with_quotes(
            parent_source, line_num, &parse_error);
        if (!parse_error.empty()) {
          // TODO: This happens with at least one Thrust header due to our
          // parsing not being robust enough. For now we just ignore it instead.
          // return Error("Internal parsing error for " + include_parent + ":" +
          //             std::to_string(line_num) + ": " + parse_error);
          // TODO: Print a warning message, but only if the "-w" option is not
          // on. std::cerr << "Warning [jitify]: Internal parsing error for "
          //          << include_parent << ":" << line_num << ": " <<
          //          parse_error;
        }
      }

      // Try to load the new header.
      // Note: This fullpath lookup is needed because the compiler error
      // messages have the include name of the header instead of its full path.
      // find() is used instead of operator[] so that looking up a parent with
      // no recorded path (e.g. the program itself) does not insert an entry.
      auto parent_fullpath_iter = header_fullpaths.find(include_parent);
      std::string include_parent_fullpath =
          parent_fullpath_iter != header_fullpaths.end()
              ? parent_fullpath_iter->second
              : std::string();
      std::string include_path = detail::path_base(include_parent_fullpath);

      using detail::HeaderLoadStatus;
      HeaderLoadStatus load_status =
          detail::load_header(include_name, include_paths, include_path,
                              /*search_current_dir = */ is_included_with_quotes,
                              use_builtin_headers, header_callback,
                              &header_sources, &header_fullpaths);
      if (load_status == HeaderLoadStatus::FAILED) {
        // Missing header.
        std::string current_dir_msg =
            "Current path: \"" + include_path + "\"\n";
        return Error("Preprocessing failed: Header not found\n" + header_log +
                     include_paths_msg + compiler_options_msg +
                     current_dir_msg + include_parent + "(" +
                     std::to_string(line_num) + "): error: " + include_name +
                     ": [jitify] File not found");
      }
      if (load_status == HeaderLoadStatus::ALREADY_LOADED) {
        // The header is already among the sources passed to the compiler, yet
        // the compiler still reports it as missing. Recompiling with unchanged
        // inputs would fail identically forever, so report an error instead.
        return Error(
            "Preprocessing failed: Header already loaded but still not found "
            "by the compiler\n" +
            header_log + include_paths_msg + compiler_options_msg +
            include_parent + "(" + std::to_string(line_num) +
            "): error: " + include_name +
            ": [jitify] Header already loaded\n" + compile_log);
      }

      // The header was newly loaded: patch it.
      const std::string& header_fullpath = header_fullpaths.at(include_name);
      bool is_cuda_std_header =
          header_fullpath.find(detail::kJitifyBuiltinHeaderPrefix) == 0 ||
          // TODO: More robust way to detect this?
          header_fullpath.find(detail::path_join(
              detail::path_join("cuda", "std"), "")) != std::string::npos;
      std::string* header_source = &header_sources.at(include_name);
      if (!is_cuda_std_header) {
        *header_source = detail::patch_cuda_source(
            std::move(*header_source), use_cuda_std, replace_pragma_once);
      }
      if (minify) {
        *header_source = detail::minify_cpp_source(*header_source);
      }
      // Log where the header was found.
      header_log += detail::string_join(
          {"Found #include ", include_name, " from ", include_parent, ":",
           std::to_string(line_num), " [", include_parent_fullpath, "]",
           " at:\n  ", header_fullpath, "\n"},
          "");
    }

    if (arch_flag.cc) {
      compiler_options.pop_back();  // Remove the temporary arch flag we added
    }
  }

  // Remove the program source from header_sources now that processing is done.
  header_sources.erase(name);

  // Re-add the -remove-unused-globals flag if it was provided.
  if (should_remove_unused_globals) {
    compiler_options.push_back("-remove-unused-globals");
  }

  return PreprocessedProgram(
      std::move(name), std::move(source), std::move(header_sources),
      std::move(compiler_options), std::move(linker_options),
      std::move(header_log), std::move(compile_log));
}

/*! An object containing CUDA source and header strings and associated metadata.
 *  The name, source and header sources are serialized via
 *  JITIFY_DEFINE_SERIALIZABLE_MEMBERS.
 */
class ProgramData : public serialization::Serializable<ProgramData> {
  std::string name_;
  std::string source_;
  StringMap header_sources_;

  JITIFY_DEFINE_SERIALIZABLE_MEMBERS(ProgramData, name_, source_,
                                     header_sources_)

 public:
  /*! Construct an uninitialized ProgramData object.
   */
  ProgramData() = default;
  /*! Construct a ProgramData object from CUDA source code.
   *  \param name The name of the program.
   *  \param source The CUDA source code of the program.
   *  \param header_sources (optional) A map of header names (the names by which
   *  they are `#include`d) to their source code.
   */
  ProgramData(std::string name, std::string source,
              StringMap header_sources = {})
      : name_(std::move(name)),
        source_(std::move(source)),
        header_sources_(std::move(header_sources)) {}

  /*! Get the name of the program. */
  const std::string& name() const { return name_; }
  /*! Get the CUDA source code of the program. */
  const std::string& source() const { return source_; }
  /*! Get the header sources map. */
  const StringMap& header_sources() const { return header_sources_; }

  /*! Preprocess the program to find header dependencies and apply source
   *  transformations.
   *  \param compiler_options (optional) Options to pass to the compiler.
   *  \param linker_options (optional) Options to pass to the linker (not used
   *    here, stored for when the linker is invoked).
   *  \param header_callback (optional) Callback function to obtain header
   *    sources. The function should return true if the header was obtained, or
   *    false to fall back to other means of loading the header.
   *  \return A PreprocessedProgram object that contains either a valid
   *    PreprocessedProgramData object or an error state.
   */
  PreprocessedProgram preprocess(StringVec compiler_options = {},
                                 StringVec linker_options = {},
                                 FileCallback header_callback = nullptr) const {
    // The option arguments are by-value locals that are not used again here,
    // so hand them over instead of copying them a second time.
    return PreprocessedProgram::preprocess(name_, source_, header_sources_,
                                           std::move(compiler_options),
                                           std::move(linker_options),
                                           std::move(header_callback));
  }
};

/*! A ProgramData wrapped in detail::FallibleObjectBase (presumably holding
 *  either program data or an error state — see FallibleObjectBase).
 */
class Program : public detail::FallibleObjectBase<Program, ProgramData> {
  using super_type = detail::FallibleObjectBase<Program, ProgramData>;
  // Re-expose every base-class constructor on Program.
  using super_type::super_type;

 public:
  /*! Construct an uninitialized Program object.
   */
  Program() = default;

  /*! Build a Program directly from CUDA source text.
   *  \param name Name that identifies the program.
   *  \param source CUDA source code of the program.
   *  \param header_sources (optional) Map from each header's `#include` name
   *    to that header's source code.
   */
  Program(std::string name, std::string source, StringMap header_sources = {})
      : super_type(std::move(name),
                   std::move(source),
                   std::move(header_sources)) {}
};

namespace detail {

#if defined _WIN32 || defined _WIN64
using mode_t = int;
// Permission modes are not applied on Windows (_mkdir takes no mode and files
// are opened with fixed _S_IREAD | _S_IWRITE); these only exist so that the
// shared signatures that default to them compile.
static constexpr mode_t kDefaultDirectoryMode = 0;
static constexpr mode_t kDefaultFileMode = 0;
#else
// rwxr-xr-x: owner has full access; group and others may read and traverse.
static constexpr mode_t kDefaultDirectoryMode =
    S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH;
// rw-r--r--: owner may read and write; group and others may only read.
static constexpr mode_t kDefaultFileMode =
    S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH;
#endif

// Returns true if `filename` names an existing filesystem entry (i.e. ::stat
// succeeds on it). If `is_dir` is non-null it receives whether the entry is a
// directory; it is set to false when stat fails.
inline bool path_exists(const char* filename, bool* is_dir = nullptr) {
  struct stat stats;
  const bool ret = ::stat(filename, &stats) == 0;
#define JITIFY_S_ISDIR(mode) (((mode)&S_IFMT) == S_IFDIR)
  // Only inspect st_mode when stat succeeded; otherwise `stats` is
  // uninitialized and reading it is undefined behavior.
  if (is_dir) *is_dir = ret && JITIFY_S_ISDIR(stats.st_mode);
#undef JITIFY_S_ISDIR
  return ret;
}

// Opens a file for reading and writing, creating it if necessary. The file
// descriptor is owned by this object and closed on destruction. Each operation
// records a human-readable status message retrievable via error().
class NewFile {
 private:
  int fd_ = -1;
  std::string filename_;
  std::string error_ = "Success";

  // Builds the status message for `operation`, using the current errno on
  // failure. Must be called immediately after the failing system call.
  std::string get_error_msg(bool success, const std::string& operation) const {
    if (success) return "Success";
    // Capture errno before building strings: allocation is not guaranteed to
    // leave errno untouched.
    const int err = errno;
    return "Failed to " + operation + " " + filename_ + ": (" +
           std::to_string(err) + ") " + ::strerror(err);
  }

 public:
  NewFile() = default;
  NewFile(const char* filename) { open(filename); }
  ~NewFile() { close(); }
  NewFile(const NewFile&) = delete;
  NewFile& operator=(const NewFile&) = delete;
  NewFile(NewFile&& other) noexcept
      : fd_(other.fd_),
        filename_(std::move(other.filename_)),
        error_(std::move(other.error_)) {
    other.fd_ = -1;
  }
  NewFile& operator=(NewFile&& other) noexcept {
    if (this != &other) {
      // Release any file we currently own first; overwriting fd_ directly
      // would leak the descriptor (and any lock held through it).
      close();
      fd_ = other.fd_;
      other.fd_ = -1;
      filename_ = std::move(other.filename_);
      error_ = std::move(other.error_);
    }
    return *this;
  }

  // Opens (creating if necessary) `filename` for reading and writing, after
  // closing any file already held. Returns true on success; error() describes
  // the outcome.
  bool open(const char* filename) {
    close();
    filename_ = filename;
    // Note that opening the file works even if it is locked.
#if defined _WIN32 || defined _WIN64
    ::_sopen_s(&fd_, filename, _O_RDWR | _O_CREAT | _O_BINARY, _SH_DENYNO,
               _S_IREAD | _S_IWRITE);
#else
    fd_ = ::open(filename, O_RDWR | O_CREAT, kDefaultFileMode);
#endif
    error_ = get_error_msg(static_cast<bool>(*this), "open");
    return static_cast<bool>(*this);
  }

  // Closes the file if open; records the result in error().
  void close() {
    if (fd_ != -1) {
#if defined _WIN32 || defined _WIN64
      bool success = ::_close(fd_) == 0;
#else
      // Note: Closing the file releases any lock on it held by this process.
      bool success = ::close(fd_) == 0;
#endif
      fd_ = -1;
      error_ = get_error_msg(success, "close");
    }
  }

  // True while a file descriptor is held.
  operator bool() const { return fd_ != -1; }

  int fd() const { return fd_; }

  // Blocks until lock on file is acquired. Returns false on error.
  // NOTE(review): per the MSVC CRT docs, _locking with _LK_LOCK retries only a
  // limited number of times (about 10 attempts, 1 s apart) before failing, so
  // on Windows this may return false rather than block indefinitely.
  bool lock() {
#if defined _WIN32 || defined _WIN64
    bool success = ::_locking(fd_, _LK_LOCK, 1) == 0;
#else
    flock fl = {};
    fl.l_type = F_WRLCK;     // Exclusive lock for writing
    fl.l_whence = SEEK_SET;  // Start at beginning of file
    // Note: The Open File Descriptor (OFD) version of this call ensures that
    // the lock is per-descriptor not per-process (and so is thread-safe).
    bool success = ::fcntl(fd_, F_OFD_SETLKW, &fl) == 0;
#endif
    error_ = get_error_msg(success, "lock");
    return success;
  }

  // Status message of the most recent open/close/lock operation.
  const std::string& error() const { return error_; }
};

// Opens or creates a file and locks it for exclusive write access. The file is
// deleted when closed. The implementation is safe for NFS, and robust against
// race conditions and sudden process termination.
// Note: This is a per-process lock, not per-thread.
class FileLock {
  NewFile file_;
  std::string filename_;
  std::string error_;

  bool acquire_lock() {
    // Note: Local instance to ensure file is not held open if locking fails.
    NewFile file;
#if defined _WIN32 || defined _WIN64
    if (!file.open(filename_.c_str())) return error_ = file.error(), false;
    if (!file.lock()) return error_ = file.error(), false;
#else
    // Despite file.lock() blocking until the lock is acquired, a loop is still
    // required here due to the possibility of the file being deleted by the
    // previous lock-holder (and possibly re-opened by someone else) between the
    // calls to open() and lock() (an unlikely race condition).
    struct stat fd_stats, file_stats;
    do {
      if (!file.open(filename_.c_str())) return error_ = file.error(), false;
      if (!file.lock()) return error_ = file.error(), false;
    } while (
        ::fstat(file.fd(), &fd_stats) != 0 ||
        ::stat(filename_.c_str(), &file_stats) != 0 ||  // File must still exist
        fd_stats.st_dev != file_stats.st_dev ||
        fd_stats.st_ino != file_stats.st_ino);  // File must still be the same
#endif
    // Success, we now exclusively own the locked file.
    file_ = std::move(file);
    return true;
  }

 public:
  FileLock() = default;
  FileLock(std::string filename) { open(std::move(filename)); }
  ~FileLock() { close(); }
  FileLock(const FileLock&) = delete;
  FileLock& operator=(const FileLock&) = delete;
  FileLock(FileLock&&) = default;
  FileLock& operator=(FileLock&&) = default;

  // Returns true if the file is open and ready for writing.
  explicit operator bool() const noexcept { return static_cast<bool>(file_); }

  const std::string& error() const { return error_; }

  int fd() const noexcept { return file_.fd(); }
  const std::string& filename() const { return filename_; }

  // Blocks until the lock filed is acquired. Returns false on error.
  bool open(std::string filename) {
    close();
    filename_ = std::move(filename);
    return acquire_lock();
  }

  void close() {
    if (file_) {
#if defined _WIN32 || defined _WIN64
      // Delete the file after releasing the lock.
      file_.close();
      std::remove(filename_.c_str());
#else
      // Delete the file before releasing the lock.
      std::remove(filename_.c_str());
      file_.close();
#endif
    }
  }
};

// Creates the single directory `path` (parents must already exist) with
// permissions `mode` (not applied on Windows, where _mkdir takes no mode).
// Returns true if the directory was created or already exists as a directory;
// returns false on error, including when `path` exists but is not a directory.
inline bool make_directory(const char* path,
                           mode_t mode = kDefaultDirectoryMode) {
  bool is_dir = false;
  if (path_exists(path, &is_dir)) return is_dir;
#if defined _WIN32 || defined _WIN64
  const bool created = ::_mkdir(path) == 0;
#else
  const bool created = ::mkdir(path, mode) == 0;
#endif
  if (created) return true;
  // Someone else may have created `path` concurrently (EEXIST). Accept that
  // only if what now exists is a directory, consistent with the check above.
  // errno is tested first, before path_exists() can modify it.
  return errno == EEXIST && path_exists(path, &is_dir) && is_dir;
}

// Creates `path` and any missing ancestor directories (like `mkdir -p`).
// Returns false as soon as any component cannot be created.
inline bool make_directories(std::string path,
                             mode_t mode = kDefaultDirectoryMode) {
  // Characters that separate path components on this platform.
#if defined _WIN32 || defined _WIN64
  // Note that Windows supports both forward and backslash path separators.
  const char* const separators = "\\/";
#else
  const char* const separators = "/";
#endif
  // Create each ancestor in turn by temporarily NUL-terminating the string at
  // every separator. Approach based on https://stackoverflow.com/a/675193/7228843
  std::size_t component_begin = 0;
  for (std::size_t i = 0; path[i] != '\0'; ++i) {
    if (std::strchr(separators, path[i]) == nullptr) continue;
    if (i != component_begin) {
      // Non-empty component (i.e. not the root and not a doubled separator).
      path[i] = '\0';
      if (!make_directory(path.c_str(), mode)) return false;
      path[i] = separators[0];
    }
    component_begin = i + 1;
  }
  // Finally create the full path itself.
  return make_directory(path.c_str(), mode);
}

// Invokes `func(const char* filename)` for every entry directly inside `path`
// (no recursion and no filtering of the names the OS enumeration returns).
// Iteration stops as soon as `func` returns false. Returns false if the
// directory cannot be opened for enumeration, true otherwise.
template <typename Func>
inline bool for_each_file_in(const std::string& path, Func func) {
#if defined(_WIN32) || defined(_WIN64)
  _WIN32_FIND_DATAA entry;
  HANDLE search = ::FindFirstFileA(path_join(path, "*").c_str(), &entry);
  if (search == INVALID_HANDLE_VALUE) {
    // Nothing matching is not an error; any other failure is.
    return ::GetLastError() == ERROR_FILE_NOT_FOUND;
  }
  bool keep_going = true;
  while (keep_going) {
    keep_going = func(entry.cFileName) && ::FindNextFileA(search, &entry);
  }
  ::FindClose(search);
#else
  DIR* const raw_dir = ::opendir(path.c_str());
  if (raw_dir == nullptr) return false;
  // Note: Using `decltype(::closedir)*` gives a compiler warning.
  std::unique_ptr<DIR, int (*)(DIR*)> dir_guard(raw_dir, ::closedir);
  for (;;) {
    struct dirent* const entry = ::readdir(dir_guard.get());
    if (entry == nullptr || !func(entry->d_name)) break;
  }
#endif
  return true;
}

// Maps an arbitrary string to one that is safe to use as a single filename
// component. Characters that are path separators or are disallowed on common
// filesystems are replaced by a URL-style percent encoding ("%XX", uppercase
// hex). The escape character '%' is encoded too, so distinct inputs always
// yield distinct filenames (e.g. "a/b" -> "a%2Fb" but "a%2Fb" -> "a%252Fb").
inline std::string sanitize_filename(const std::string& filename) {
  // '%' must be in this set because it is the escape character; otherwise the
  // mapping is not injective and different cache keys could share a file.
  static const std::string bad_filename_chars = R"(\/:*?|"<>%)";
  std::stringstream result;
  size_t beg = 0;
  while (true) {
    size_t end = filename.find_first_of(bad_filename_chars, beg);
    result << filename.substr(beg, end - beg);
    if (end == std::string::npos) break;
    // Use HTML URL encoding scheme for unsupported filename characters. All
    // escaped characters are printable ASCII (>= 0x20), so two hex digits.
    result << "%" << std::hex << std::uppercase
           << static_cast<int>(static_cast<unsigned char>(filename[end]))
           << std::nouppercase << std::dec;
    beg = end + 1;
  }
  return result.str();
}

// A size-bounded, least-recently-used cache of serialized objects stored as
// files in a directory. Cache files are named
// <path>/<prefix><sanitized name><suffix>; writers serialize through the lock
// file <path>/<prefix>lock, while readers open cache files without locking
// because new entries are published by renaming a fully-written temp file.
// An empty path disables file caching.
class LRUFileCache {
  std::string path_;
  size_t max_size_;
  std::string file_prefix_;
  std::string file_suffix_;
  std::string lock_file_name_;

  // Deletes the least-recently-accessed cache files so that (at most)
  // max(max_size_, 1) - 1 remain, leaving room for one new entry. Files with
  // identical access times are deleted all-or-nothing, so more may remain.
  // Both callers in this class hold the lock file while calling this.
  // Returns false if the cache directory cannot be enumerated.
  bool delete_lru_files_if_full() const {
    if (path_.empty()) return true;
    // We need to avoid max_size_ == 0 because this function leaves
    // max_size_ - 1 files.
    size_t max_size_not_zero = std::max(max_size_, size_t(1));

    std::multimap<double, std::string> time_sorted_cache_files;
    if (!for_each_file_in(path_, [&](const char* filename_c) {
          std::string filename(filename_c);
          if (startswith(filename, file_prefix_) &&
              endswith(filename, file_suffix_)) {
            filename = path_join(path_, filename);
            // Never treat the lock file as a cache entry. It always matches the
            // prefix (and the suffix when the suffix is empty or ends "lock"),
            // and deleting it while held would let another process create and
            // lock a fresh lock file concurrently.
            if (filename == lock_file_name_) return true;
            struct stat file_stats;
            // Skip file if error.
            if (::stat(filename.c_str(), &file_stats)) return true;
            double accessed_time =
#if defined(_WIN32) || defined(_WIN64)
                // Note: Some Windows filesystems only update the accessed time
                // hourly or even daily.
                (double)file_stats.st_atime;
#else
                (double)file_stats.st_atim.tv_sec * 1e9 +
                (double)file_stats.st_atim.tv_nsec;
#endif
            time_sorted_cache_files.emplace(accessed_time, std::move(filename));
          }
          return true;
        })) {
      return false;
    }
    // Note: This leaves (max_size_not_zero - 1) files.
    while (time_sorted_cache_files.size() >= max_size_not_zero) {
      auto iter = time_sorted_cache_files.begin();
      // We treat deletion of files with the same access time as all-or-nothing.
      // This ensures we don't rely on access times being unique (which is
      // unlikely on Windows due to slow/quantized updates to the access time).
      auto range = time_sorted_cache_files.equal_range(iter->first);
      ptrdiff_t num_remaining = (ptrdiff_t)time_sorted_cache_files.size() -
                                std::distance(range.first, range.second);
      // Leave at least max_size_not_zero - 1 files.
      if (num_remaining < ptrdiff_t(max_size_not_zero - 1)) break;
      while (iter != range.second) {
        const std::string& filename = iter->second;
        std::remove(filename.c_str());
        iter = time_sorted_cache_files.erase(iter);
      }
    }
    return true;
  }

 public:
  // Empty path disables file caching.
  LRUFileCache(std::string path, size_t max_size,
               const std::string& file_prefix, const std::string& file_suffix)
      : path_(std::move(path)),
        max_size_(max_size),
        file_prefix_(sanitize_filename(file_prefix)),
        file_suffix_(sanitize_filename(file_suffix)),
        lock_file_name_(path_join(path_, file_prefix_ + "lock")) {}

  // Loads the object cached under `name` into *result, or builds it.
  //  - Caching disabled (empty path or max_size == 0): *result = construct().
  //  - Cache hit: *result = deserialize(istream).
  //  - Cache miss: takes the lock and re-checks; if still missing, calls
  //    construct(). A result whose ok() is false is returned uncached.
  //    Otherwise it is written with serialize(obj, ostream) to a temp file,
  //    LRU files are pruned, the temp file is renamed into place, and
  //    `cuda_source` is written to the file `source_name`.
  // Returns an empty string on success, or an error message.
  // Note: decltype/declval is used instead of std::result_of, which was
  // deprecated in C++17 and removed in C++20.
  template <class Construct, class Serialize, class Deserialize>
  std::string get(const std::string& name,
                  decltype(std::declval<Construct>()())* result,
                  const std::string& source_name, const std::string& cuda_source,
                  Construct construct, Serialize serialize,
                  Deserialize deserialize) const {
    if (path_.empty() || max_size_ == 0) {
      *result = construct();
      return {};
    }
    bool is_dir = false;
    // Create the cache directory if necessary.
    if (!path_exists(path_.c_str(), &is_dir)) {
      if (!make_directories(path_)) {
        return "Failed to create cache directory \"" + path_ + "\"";
      }
    } else if (!is_dir) {
      return "Failed to access file cache: cache path is a file: \"" + path_ +
             "\"";
    }
    std::string filename = path_join(
        path_, file_prefix_ + sanitize_filename(name) + file_suffix_);
    // Try to open the cache file for reading (no lock needed, see class note).
    std::ifstream istream(filename.c_str(), std::ios::binary);
    if (istream) {
      // Found in cache, load it.
      *result = deserialize(istream);
      return {};
    }
    // Not found in cache, acquire a file lock for exclusive access.
    FileLock file_lock(lock_file_name_.c_str());
    if (!file_lock) return file_lock.error();
    // Check for the file again in case it was created while waiting on the
    // lock.
    istream.open(filename.c_str(), std::ios::binary);
    if (istream) {
      // Found in cache now, just load it.
      file_lock.close();
      *result = deserialize(istream);
      return {};
    }
    // We must construct the object and write it to the cache.
    auto result_tmp = construct();
    if (!result_tmp.ok()) {
      *result = std::move(result_tmp);
      return {};
    }
    // Serialize to a temp file and rename it after writing so that readers
    // don't need to obtain the lock, and also so that sudden termination
    // doesn't leave incomplete data.
    std::string temp_filename = filename + ".tmp";
    bool temp_written = false;
    {
      std::ofstream ostream(temp_filename.c_str(), std::ios::binary);
      if (!ostream) {
        return "Failed to open cache file for writing: \"" + temp_filename +
               "\"";
      }
      serialize(result_tmp, ostream);
      ostream.flush();
      temp_written = static_cast<bool>(ostream);
    }
    if (!temp_written) {
      // Never publish a partially-written cache file (e.g. disk full); the
      // object is still returned, it just isn't cached.
      std::remove(temp_filename.c_str());
    } else {
      if (!delete_lru_files_if_full()) {
        std::remove(temp_filename.c_str());
        return "Failed to run LRU file deletion on cache directory";
      }
      // Atomically make the new cache file visible to readers. If the rename
      // fails the object is simply not cached; drop the orphaned temp file.
      if (std::rename(temp_filename.c_str(), filename.c_str()) != 0) {
        std::remove(temp_filename.c_str());
      }
    }
    *result = std::move(result_tmp);

    // Written while still holding the lock (declared after file_lock, so it is
    // closed first).
    std::ofstream source_stream(source_name.c_str());
    if (!source_stream) {
      return "Failed to open source file for writing: \"" + source_name + "\"";
    }
    source_stream << cuda_source;
    return {};
  }

  size_t max_size() const { return max_size_; }

  // Changes max_size and deletes files in the cache if necessary.
  // Returns false on file deletion error or if the lock cannot be acquired
  // (max size will still be changed).
  bool resize(size_t max_size) {
    size_t old_max_size = max_size_;
    max_size_ = max_size;
    // Caching disabled: there is no directory to lock or prune (and locking
    // would otherwise create a lock file relative to the current directory).
    if (path_.empty()) return true;
    if (max_size < old_max_size) {
      FileLock file_lock(lock_file_name_.c_str());
      // Pruning without the lock could race with a concurrent writer.
      if (!file_lock) return false;
      if (!delete_lru_files_if_full()) return false;
    }
    return true;
  }

  // Deletes all files in the cache (does not change max_size).
  // Returns false on file deletion error.
  bool clear() {
    size_t old_max_size = max_size_;
    bool result = resize(0);
    max_size_ = old_max_size;
    return result;
  }

  // Returns the cache directory (empty if file caching is disabled).
  std::string get_path() const { return path_; }
};

// This implements a LRU cache with O(1) average lookup time complexity.
// Once the cache is full, a miss recycles the least-recently-used entry.
template <typename Key, typename Value, typename Hash = std::hash<Key>,
          typename KeyEqual = std::equal_to<Key>>
class LRUCache {
 public:
  // Two data structures are used:
  //   cache_: Unordered map of keys to values and rank iterators.
  //   ranks_: Ordered list of cache iterators, most recently used first.
  // (i.e., the iterators each refer to the other data structure).
  using key_type = Key;
  using value_type = Value;
  using hasher = Hash;
  using key_equal = KeyEqual;

 private:
  struct cache_iterator_workaround;  // See definition below
  using ranks_type = std::list<cache_iterator_workaround>;

 public:
  using rank_iterator = typename ranks_type::iterator;
  using const_rank_iterator = typename ranks_type::const_iterator;
  struct value_and_rank_iter {
    value_type value;
    rank_iterator rank_iter;
  };

 private:
  using cache_type =
      std::unordered_map<key_type, value_and_rank_iter, hasher, key_equal>;

 public:
  using cache_iterator = typename cache_type::iterator;
  using const_cache_iterator = typename cache_type::const_iterator;

 private:
  // This allows the type to be forward-declared, avoiding circular typedefs.
  struct cache_iterator_workaround : cache_iterator {
    cache_iterator_workaround(const cache_iterator& other)
        : cache_iterator(other) {}
  };
  cache_type cache_;
  ranks_type ranks_;
  size_t max_size_;
  value_type nocache_value_;  // This is used if max_size_ == 0

  // Marks the entry at `iter` as most recently used (moves it to the front).
  void touch(rank_iterator iter) {
    if (iter != ranks_.begin()) {
      // Move iter to begin.
      ranks_.splice(ranks_.begin(), ranks_, iter, std::next(iter));
    }
  }

 public:
  LRUCache(size_t max_size) : max_size_(max_size) {
    // Ensure that no iterators will be invalidated by insertions.
    cache_.reserve(max_size);
  }

  size_t max_size() const { return max_size_; }
  size_t size() const { return cache_.size(); }
  // True when size() has reached max_size().
  bool full() const { return cache_.size() == max_size_; }

  // Returns a reference to a value in the cache along with a bool that is set
  // to true iff the key was found in the cache. If the key was not found, the
  // value will refer either to the LRU value (which has been appropriated), or
  // a new default-constructed value.
  std::pair<value_type&, bool> operator[](const Key& key) {
    if (max_size_ == 0) return {nocache_value_, false};
    auto iter = cache_.find(key);
    if (iter != cache_.end()) {
      // Cache hit.
      auto& cache_value = iter->second;
      touch(cache_value.rank_iter);
      return {cache_value.value, true};
    } else if (cache_.size() == max_size_) {
      // Cache miss, and the cache is full, so appropriate the LRU entry.
      rank_iterator rank_iter = std::prev(ranks_.end());
      touch(rank_iter);
      iter = *rank_iter;
      // Change the key of the LRU entry to the new key.
#if __cplusplus >= 201703L
      auto node_handle = cache_.extract(iter);
      node_handle.key() = key;
      iter = cache_.insert(std::move(node_handle)).position;
#else
      auto cache_value = std::move(iter->second);
      cache_.erase(iter);
      iter = cache_.emplace(key, std::move(cache_value)).first;
#endif
      // Update the rank entry.
      *rank_iter = iter;
      return {iter->second.value, false};
    } else {
      // Cache miss, and the cache is not full, so insert a new entry.
      ranks_.push_front(cache_.end());  // Initialize with placeholder iterator
      // Insert a new default-constructed value.
      iter = cache_.emplace(key, value_and_rank_iter{{}, ranks_.begin()}).first;
      ranks_.front() = iter;  // Replace placeholder iterator with real one
      return {iter->second.value, false};
    }
  }

  cache_iterator begin() { return cache_.begin(); }
  cache_iterator end() { return cache_.end(); }
  const_cache_iterator begin() const { return cache_.begin(); }
  const_cache_iterator end() const { return cache_.end(); }
  rank_iterator ranks_begin() { return ranks_.begin(); }
  rank_iterator ranks_end() { return ranks_.end(); }
  const_rank_iterator ranks_begin() const { return ranks_.begin(); }
  const_rank_iterator ranks_end() const { return ranks_.end(); }

  void clear() {
    cache_.clear();
    ranks_.clear();
  }

  // Changes the capacity. Growing clears the cache (reserving for a larger
  // capacity may rehash, invalidating the iterators stored in ranks_).
  // Shrinking evicts least-recently-used entries until size() <= max_size.
  void resize(size_t max_size) {
    if (max_size > max_size_) {
      clear();  // Must clear due to rehash invalidating iterators.
      cache_.reserve(max_size);
    } else {
      // Evict from the LRU end. Using back()/pop_back() avoids stepping an
      // iterator before begin() once every entry is gone (max_size == 0),
      // which the previous std::prev-based loop did (undefined behavior).
      while (ranks_.size() > max_size) {
        cache_iterator victim = ranks_.back();
        cache_.erase(victim);
        ranks_.pop_back();
      }
    }
    max_size_ = max_size;
  }
};

// Function object that renders a value as text via its `operator<<` overload.
template <typename T>
struct StreamToString {
  std::string operator()(const T& value) const {
    std::ostringstream out;
    out << value;
    return out.str();
  }
};

// A key paired with an extra 64-bit integer; both parts take part in equality
// and hashing. ProgramCache uses this as the key of its in-memory cache.
template <typename Key, typename Hash_ = std::hash<Key>,
          typename KeyEqual = std::equal_to<Key>>
class KeyWithUInt64 {
  Key key_;
  uint64_t extra_;
  Hash_ hash_;
  KeyEqual equal_;

 public:
  KeyWithUInt64(Key key, uint64_t extra, const Hash_& hash = Hash_(),
                const KeyEqual& key_equal = KeyEqual())
      : key_(std::move(key)), extra_(extra), hash_(hash), equal_(key_equal) {}

  // Equal when the integers match and the wrapped keys compare equal under
  // KeyEqual; the cheap integer comparison is done first.
  bool operator==(const KeyWithUInt64& rhs) const {
    if (extra_ != rhs.extra_) return false;
    return equal_(key_, rhs.key_);
  }

  // Combines the wrapped key's hash with a hash of the integer.
  size_t hash() const {
    const auto key_hash = hash_(key_);
    const auto extra_hash = fasthash64(extra_);
    return hash_combine(key_hash, extra_hash);
  }

  // Hash functor adapter for use with unordered containers.
  struct Hash {
    size_t operator()(const KeyWithUInt64& x) const { return x.hash(); }
  };
};

// Default key type for ProgramCache. It represents the arguments passed to the
// get_program() method.
class AutoKey {
  StringVec name_expressions_;
  StringMap extra_header_sources_;
  StringVec extra_compiler_options_;
  StringVec extra_linker_options_;

 public:
  AutoKey(StringVec name_expressions, StringMap extra_header_sources,
          StringVec extra_compiler_options, StringVec extra_linker_options)
      : name_expressions_(std::move(name_expressions)),
        extra_header_sources_(std::move(extra_header_sources)),
        extra_compiler_options_(std::move(extra_compiler_options)),
        extra_linker_options_(std::move(extra_linker_options)) {}

  // Keys are equal when all four argument collections are equal.
  bool operator==(const AutoKey& rhs) const {
    return name_expressions_ == rhs.name_expressions_ &&
           extra_header_sources_ == rhs.extra_header_sources_ &&
           extra_compiler_options_ == rhs.extra_compiler_options_ &&
           extra_linker_options_ == rhs.extra_linker_options_;
  }

  // Combines the hashes of all four argument collections.
  size_t hash() const {
    using htype = uint64_t;
    return hash_combine(
        hash_value<htype>(name_expressions_),
        hash_combine(hash_value<htype>(extra_header_sources_),
                     hash_combine(hash_value<htype>(extra_compiler_options_),
                                  hash_value<htype>(extra_linker_options_))));
  }

  struct Hash {
    size_t operator()(const AutoKey& x) const { return x.hash(); }
  };

  // This is used by to_filename().
  // TODO: This should really be a custom to_filename() method instead.
  // Writes sha256() of a canonical serialization of all fields: each section
  // is prefixed by its element count and every entry is NUL-terminated. Header
  // sources are visited in sorted-name order, so the output does not depend on
  // the map's iteration order.
  friend std::ostream& operator<<(std::ostream& stream, const AutoKey& key) {
    // We write a 256-bit hash value instead of the full data because filenames
    // are limited in length.
    auto sorted_iters = [](const StringMap& m) {
      std::vector<StringMap::const_iterator> iters;
      iters.reserve(m.size());
      for (StringMap::const_iterator it = m.begin(); it != m.end(); ++it) {
        iters.push_back(it);
      }
      std::sort(
          iters.begin(), iters.end(),
          [](StringMap::const_iterator lhs, StringMap::const_iterator rhs) {
            return lhs->first < rhs->first;
          });
      return iters;
    };
    std::string key_str;
    key_str += std::to_string(key.extra_header_sources_.size());
    key_str += '\0';
    for (const auto iter : sorted_iters(key.extra_header_sources_)) {
      key_str += iter->first;
      key_str += '\0';
      key_str += iter->second;
      key_str += '\0';
    }
    key_str += '\0';
    // Iterate over pointers: a braced list of the vectors themselves would
    // deep-copy every string into the std::initializer_list.
    for (const StringVec* vec :
         {&key.name_expressions_, &key.extra_compiler_options_,
          &key.extra_linker_options_}) {
      key_str += std::to_string(vec->size());
      key_str += '\0';
      for (const std::string& str : *vec) {
        key_str += str;
        key_str += '\0';
      }
      key_str += '\0';
    }
    return stream << sha256(key_str);
  }
};

// Chooses the hash functor ProgramCache uses for a key type: std::hash<T> in
// general, and AutoKey's own nested Hash for AutoKey.
template <typename T>
struct default_hasher {
  typedef std::hash<T> type;
};
template <>
struct default_hasher<AutoKey> {
  typedef AutoKey::Hash type;
};

}  // namespace detail

/*! A cache of compiled, linked, and loaded variants of one preprocessed
 *  program.
 *
 *  Loaded programs are kept in memory, keyed on (Key, current CUDA context).
 *  Linked programs may additionally be kept in a file cache, keyed on the
 *  key's filename, the target SM architecture, and the serialization version.
 *  \tparam Key Type that identifies a program variant.
 *  \tparam Hash Hash functor for Key.
 *  \tparam KeyEqual Equality functor for Key.
 *  \tparam KeyToFilename Functor that converts a Key to a filename component.
 */
template <typename Key = detail::AutoKey,
          typename Hash = typename detail::default_hasher<Key>::type,
          typename KeyEqual = std::equal_to<Key>,
          typename KeyToFilename = detail::StreamToString<Key>>
class ProgramCache {
 public:
  using key_type = Key;
  using hasher = Hash;
  using key_equal = KeyEqual;
  using key_to_filename = KeyToFilename;
  using value_type = LoadedProgram;

 private:
  using combined_key_type = detail::KeyWithUInt64<Key, Hash, KeyEqual>;
  using combined_hasher = typename combined_key_type::Hash;
  using combined_key_equal = std::equal_to<combined_key_type>;

  // Name of the extra header source that must hold the generated kernel code.
  static constexpr const char* kGeneratedKernelName = "GeneratedKernel.cu";

  PreprocessedProgramData preprog_;
  const StringMap* shared_headers_ref_;
  detail::LRUCache<combined_key_type, value_type, combined_hasher,
                   combined_key_equal>
      mem_cache_;
  detail::LRUFileCache file_cache_;
  hasher hash_;
  key_equal equal_;
  key_to_filename to_filename_;
  JITIFY_IF_THREAD_SAFE(mutable std::mutex mutex_;)
  size_t num_hits_ = 0;
  size_t num_misses_ = 0;

  // Returns the preprocessed program's remaining compiler options followed by
  // extra_compiler_options.
  StringVec merge_compiler_options(StringVec extra_compiler_options) const {
    extra_compiler_options.insert(extra_compiler_options.begin(),
                                  preprog_.remaining_compiler_options().begin(),
                                  preprog_.remaining_compiler_options().end());
    return extra_compiler_options;
  }

  // Returns the preprocessed program's remaining linker options followed by
  // extra_linker_options.
  StringVec merge_linker_options(StringVec extra_linker_options) const {
    extra_linker_options.insert(extra_linker_options.begin(),
                                preprog_.remaining_linker_options().begin(),
                                preprog_.remaining_linker_options().end());
    return extra_linker_options;
  }

  // Combines the shared headers (if any), the preprocessed program's headers,
  // and extra_header_sources. The result may be stored in
  // *tmp_merged_header_sources, so the returned reference must not outlive it.
  const StringMap& merge_header_sources(
      const StringMap& extra_header_sources,
      StringMap* tmp_merged_header_sources) const {
    // Static rather than local: detail::merge returns a reference that may
    // (depending on its implementation) alias one of its inputs. A
    // function-local map would then leave the caller with a dangling
    // reference.
    static const StringMap empty_shared_headers;
    const StringMap& shared_headers =
        shared_headers_ref_ ? *shared_headers_ref_ : empty_shared_headers;
    return detail::merge(
        detail::merge(shared_headers, preprog_.header_sources(),
                      tmp_merged_header_sources),
        extra_header_sources, tmp_merged_header_sources);
  }

  // Compiles and links the preprocessed program with the merged headers and
  // options. Returns a LinkedProgram holding either the result or the first
  // error encountered.
  LinkedProgram build_linked_program(const StringVec& name_expressions,
                                     const StringMap& extra_header_sources,
                                     StringVec extra_compiler_options,
                                     StringVec extra_linker_options) const {
    StringMap tmp_all_header_sources;
    const StringMap& all_header_sources =
        merge_header_sources(extra_header_sources, &tmp_all_header_sources);
    StringVec all_compiler_options =
        merge_compiler_options(std::move(extra_compiler_options));
    StringVec all_linker_options =
        merge_linker_options(std::move(extra_linker_options));

    auto compiled = CompiledProgram::compile(
        preprog_.name(), preprog_.source(), all_header_sources,
        name_expressions, std::move(all_compiler_options));
    if (!compiled) return LinkedProgram::Error(compiled.error());
    return compiled->link(std::move(all_linker_options));
  }

 public:
  /*! Construct a program cache.
   *
   * This class provides a way to cache pre-linked and pre-loaded programs in
   *   the filesystem and in memory (respectively), avoiding the cost of
   *   re-compiling/linking/loading programs and kernels when they are reused.
   *  \param max_in_mem The maximum number of loaded programs to keep in memory.
   *  \param preprog The preprocessed program to cache.
   *  \param shared_headers_ref (optional) Pointer to a map of headers that
   *    should be added to the preprocessed program. If provided, the pointed-to
   *    object must exist for the lifetime of this class.
   *  \param file_cache_path (optional) Path in which to store cached linked
   *    programs. If not specified, file caching is not used.
   *  \param max_files (optional) The maximum number of linked programs to keep
   *    in the file cache. Defaults to the same value as \p max_in_mem.
   *  \param hash (optional) The object to use to compute hashes of cache keys.
   *  \param equal (optional) The object to use to compare cache keys.
   *  \param to_filename (optional) The object to use to convert keys to
   *    filenames.
   *  \param file_suffix (optional) The suffix to add to files in the file
   *    cache. This is used (in combination with the program name) to uniquely
   *    identify files that are part of the cache.
   *  \see get_kernel
   */
  ProgramCache(size_t max_in_mem, PreprocessedProgramData preprog,
               const StringMap* shared_headers_ref = nullptr,
               std::string file_cache_path = {}, size_t max_files = 0,
               const hasher& hash = {}, const key_equal& equal = {},
               const key_to_filename& to_filename = {},
               const std::string& file_suffix = ".jitify")
      : preprog_(std::move(preprog)),
        shared_headers_ref_(shared_headers_ref),
        mem_cache_(max_in_mem),
        // Note: preprog_ is declared (and thus initialized) before file_cache_,
        // so using preprog_.name() here is safe.
        file_cache_(std::move(file_cache_path),
                    max_files ? max_files : max_in_mem,
                    /*file_prefix=*/preprog_.name() + ".", file_suffix),
        hash_(hash),
        equal_(equal),
        to_filename_(to_filename) {}

  /*! Get or build a LoadedProgram object from the cache.
   *
   * If not already in the cache, the requested program is built by compiling,
   *   linking, and loading the preprocessed program. The returned object may
   *   contain errors from any of these stages. If a file cache path was
   *   specified, the linked program may be obtained from the file cache,
   *   avoiding recompilation.
   *  \param key A value that uniquely identifies the requested program.
   *  \param name_expressions List of name expressions to include during
   *    compilation (e.g.,
   *    `{&quot;my_namespace::my_kernel<123, float>&quot;, &quot;v<7>&quot;}`).
   *  \param extra_header_sources List of additional header names and sources to
   *    include during compilation. These are added to those already specified
   *    in the preprocessed program, replacing them if names match. It must
   *    contain an entry named "GeneratedKernel.cu". That source is registered
   *    under a path in the file-cache directory and force-included into the
   *    compilation via `-include=`; otherwise an error is returned.
   *  \param extra_compiler_options List of additional compiler options. The
   *    first element (if any) is replaced by the `-include=` option for the
   *    generated kernel source. If the list is empty, that option is added.
   *  \param extra_linker_options List of additional linker options.
   *  \return A LoadedProgram object that contains either a valid
   *    LoadedProgramData object or an error state.
   *  \see get_kernel
   */
  LoadedProgram get_program(const key_type& key,
                            const StringVec& name_expressions,
                            const StringMap& extra_header_sources = {},
                            StringVec extra_compiler_options = {},
                            StringVec extra_linker_options = {}) {
    // Add the current CUDA context to the key, as modules are context-specific.
    CUcontext context;
    CUresult cuda_ret = cuCtxGetCurrent(&context);
    if (cuda_ret != CUDA_SUCCESS) {
      return LoadedProgram::Error(detail::get_cuda_error_string(cuda_ret));
    }
    combined_key_type mem_cache_key(key, reinterpret_cast<uintptr_t>(context),
                                    hash_, equal_);

    JITIFY_IF_THREAD_SAFE(std::lock_guard<std::mutex> lock(mutex_);)
    // NOTE(review): the memory-cache slot is obtained before the program is
    // built. If building fails below, the slot is not filled, so a later
    // lookup may report a hit for it. Confirm against LRUCache::operator[].
    auto value_and_found = mem_cache_[mem_cache_key];
    value_type* value = &value_and_found.first;
    bool found = value_and_found.second;
    if (found) {
      ++num_hits_;
    } else {
      ++num_misses_;
      // Add the SM architecture to the key, as cubins are arch-specific.
      StringVec all_compiler_options =
          merge_compiler_options(extra_compiler_options);
      std::string error;
      bool is_virtual = false;
      int given_cc =
          detail::parse_arch_flag(all_compiler_options, &is_virtual, &error);
      if (!error.empty()) {
        return LoadedProgram::Error("Failed to parse architecture flag: " +
                                    error);
      }
      int compute_capability;
      if (given_cc > 0 && !is_virtual) {
        compute_capability = given_cc;
      } else {
        compute_capability =
            detail::get_current_device_compute_capability(&error);
        if (!error.empty()) {
          return LoadedProgram::Error("Failed to detect device architecture: " +
                                      error);
        }
      }
      std::stringstream filename_ss;
      filename_ss << to_filename_(key) << ".sm" << compute_capability << ".v"
                  << std::hex << serialization::kSerializationVersion;
      const std::string filename = filename_ss.str();

      // The caller must supply the generated kernel source. Previously a
      // missing entry dereferenced end() (undefined behavior).
      const auto generated_it = extra_header_sources.find(kGeneratedKernelName);
      if (generated_it == extra_header_sources.end()) {
        return LoadedProgram::Error(
            std::string("Missing required header source \"") +
            kGeneratedKernelName + "\" in extra_header_sources");
      }
      const std::string& cuda_source = generated_it->second;
      const std::string source_name =
          detail::path_join(file_cache_.get_path(), filename + ".cu");

      // Keep the caller's other extra headers (as documented), and register
      // the generated source under its file-cache path instead of its
      // original name.
      StringMap extra_header_sources2 = extra_header_sources;
      extra_header_sources2.erase(kGeneratedKernelName);
      extra_header_sources2[source_name] = cuda_source;

      std::string include_option = "-include=" + source_name;
      if (extra_compiler_options.empty()) {
        // Indexing [0] on an empty vector was undefined behavior.
        extra_compiler_options.push_back(std::move(include_option));
      } else {
        // NOTE(review): this replaces the caller's first option. Callers are
        // presumably expected to reserve slot 0 for it; confirm.
        extra_compiler_options[0] = std::move(include_option);
      }

      LinkedProgram linked;
      error = file_cache_.get(
          filename, &linked, source_name, cuda_source,
          [&] {
            return build_linked_program(name_expressions, extra_header_sources2,
                                        extra_compiler_options,
                                        extra_linker_options);
          },
          [&](const LinkedProgram& _linked, std::ostream& ostream) {
            if (_linked) _linked->serialize(ostream);
          },
          [&](std::istream& istream) {
            return LinkedProgram::deserialize(istream);
          });
      if (!error.empty()) return LoadedProgram::Error(error);
      if (!linked) return LoadedProgram::Error(linked.error());
      *value = linked->load();
      if (!*value) return LoadedProgram::Error(value->error());
    }
    return *value;
  }

  // Note: This overload is only enabled when Key = AutoKey.
  /*! Get or build a LoadedProgram object from the cache using an
   *  automatically-computed key.
   *  \see get_program
   */
  template <typename U = Key,
            typename std::enable_if<std::is_same<U, detail::AutoKey>::value,
                                    int>::type = 0>
  LoadedProgram get_program(const StringVec& name_expressions,
                            const StringMap& extra_header_sources = {},
                            StringVec extra_compiler_options = {},
                            StringVec extra_linker_options = {}) {
    return get_program(
        detail::AutoKey(name_expressions, extra_header_sources,
                        extra_compiler_options, extra_linker_options),
        name_expressions, extra_header_sources,
        std::move(extra_compiler_options), std::move(extra_linker_options));
  }

  /*! Get or build a Kernel object from the cache.
   *
   * If not already in the cache, the requested kernel is built by compiling,
   *   linking, and loading the preprocessed program. The returned object may
   *   contain errors from any of these stages. If a file cache path was
   *   specified, the linked program may be obtained from the file cache,
   *   avoiding recompilation.
   *  \param key A value that uniquely identifies the requested kernel.
   *  \param name The full name of the instantiated kernel (e.g.,
   *    `&quot;my_namespace::my_kernel<123, float>&quot;`).
   *  \param other_name_expressions List of other name expressions to
   *    include during compilation (e.g., global variable template
   *    instantiations).
   *  \param extra_header_sources List of additional header names and sources to
   *    include during compilation. These are added to those already specified
   *    in the preprocessed program, replacing them if names match.
   *  \param extra_compiler_options List of additional compiler options.
   *  \param extra_linker_options List of additional linker options.
   *  \return A Kernel object that contains either a valid KernelData object or
   *    an error state.
   *  \see get_program
   */
  Kernel get_kernel(const key_type& key, std::string name,
                    StringVec other_name_expressions = {},
                    const StringMap& extra_header_sources = {},
                    StringVec extra_compiler_options = {},
                    StringVec extra_linker_options = {}) {
    // The kernel itself must be one of the compiled name expressions.
    other_name_expressions.push_back(name);
    LoadedProgram program = get_program(
        key, other_name_expressions, extra_header_sources,
        std::move(extra_compiler_options), std::move(extra_linker_options));
    if (!program) return Kernel::Error(program.error());
    return Kernel::get_kernel(std::move(*program), std::move(name));
  }

  // Note: This overload is only enabled when Key = AutoKey.
  /*! Get or build a Kernel object from the cache using an
   *  automatically-computed key.
   *  \see get_kernel
   */
  template <typename U = Key,
            typename std::enable_if<std::is_same<U, detail::AutoKey>::value,
                                    int>::type = 0>
  Kernel get_kernel(std::string name, StringVec other_name_expressions = {},
                    const StringMap& extra_header_sources = {},
                    StringVec extra_compiler_options = {},
                    StringVec extra_linker_options = {}) {
    other_name_expressions.push_back(name);
    LoadedProgram program = get_program(
        detail::AutoKey(other_name_expressions, extra_header_sources,
                        extra_compiler_options, extra_linker_options),
        other_name_expressions, extra_header_sources,
        std::move(extra_compiler_options), std::move(extra_linker_options));
    if (!program) return Kernel::Error(program.error());
    return Kernel::get_kernel(std::move(*program), std::move(name));
  }

  /*! Get the maximum size of the memory cache. */
  size_t max_in_mem() const {
    JITIFY_IF_THREAD_SAFE(std::lock_guard<std::mutex> lock(mutex_);)
    return mem_cache_.max_size();
  }
  /*! Get the maximum size of the file cache. */
  size_t max_files() const {
    JITIFY_IF_THREAD_SAFE(std::lock_guard<std::mutex> lock(mutex_);)
    return file_cache_.max_size();
  }

  /*! Clear the memory and file caches.
   *  \return false on file deletion error.
   */
  bool clear() {
    JITIFY_IF_THREAD_SAFE(std::lock_guard<std::mutex> lock(mutex_);)
    mem_cache_.clear();
    return file_cache_.clear();
  }

  /*! Change the max size of the memory and file caches.
   *  \param max_in_mem The new maximum size for the memory cache.
   *  \param max_files The new maximum size for the file cache.
   *  \return false on file deletion error (max size is still changed).
   *  \note Resizing to a maximum size of 0 causes caching to be disabled.
   */
  bool resize(size_t max_in_mem, size_t max_files) {
    JITIFY_IF_THREAD_SAFE(std::lock_guard<std::mutex> lock(mutex_);)
    mem_cache_.resize(max_in_mem);
    return file_cache_.resize(max_files);
  }

  /*! Change the max size of the memory and file caches.
   *  \param max_size The new maximum size for both the memory and file cache.
   *  \return false on file deletion error (max size is still changed).
   *  \note Resizing to a maximum size of 0 causes caching to be disabled.
   */
  bool resize(size_t max_size) { return resize(max_size, max_size); }

  /*! Get the total number of cache hits and misses.
   *  \param num_hits Pointer to value where the total number of cache hits will
   *    be stored.
   *  \param num_misses Pointer to value where the total number of cache misses
   *    will be stored.
   *  \see reset_stats
   */
  void get_stats(size_t* num_hits, size_t* num_misses) const {
    JITIFY_IF_THREAD_SAFE(std::lock_guard<std::mutex> lock(mutex_);)
    *num_hits = num_hits_;
    *num_misses = num_misses_;
  }

  /*! Reset the cache hit and miss statistics to zero.
   *  \see get_stats
   */
  void reset_stats() {
    JITIFY_IF_THREAD_SAFE(std::lock_guard<std::mutex> lock(mutex_);)
    num_hits_ = 0;
    num_misses_ = 0;
  }
};

#endif  // not JITIFY_SERIALIZATION_ONLY

}  // namespace jitify2

#ifndef JITIFY_SERIALIZATION_ONLY

#undef JITIFY_PATH_MAX
#undef JITIFY_THROW_OR_RETURN_IF_CUDA_ERROR
#undef JITIFY_THROW_OR_RETURN
#undef JITIFY_THROW_OR_TERMINATE

#if defined(_WIN32) || defined(_WIN64)
#pragma pop_macro("max")
#pragma pop_macro("min")
#pragma pop_macro("strtok_r")
#endif

#endif  // not JITIFY_SERIALIZATION_ONLY

#endif  // JITIFY2_HPP_INCLUDE_GUARD
