#include <mitsuba/core/bitmap.h>
#include <mitsuba/core/fresolver.h>
#include <mitsuba/core/fstream.h>
#include <mitsuba/core/mstream.h>
#include <mitsuba/core/plugin.h>
#include <mitsuba/core/sched.h>
#include <mitsuba/core/shvector.h>
#include <mitsuba/core/warp.h>
#include <mitsuba/core/pmf.h>
#include <mitsuba/render/texture.h>
#include <mitsuba/render/mipmap.h>
#include <mitsuba/render/bsdf.h>
#include <mitsuba/render/scene.h>
#include <mitsuba/hw/basicshader.h>

MTS_NAMESPACE_BEGIN

/**
 * \brief Wraps a single nested BSDF and multiplies its value by up to two
 * tabulated scale factors.
 *
 * - Angular scale (\c angularScaleFilename): a bitmap whose rows correspond
 *   to incident-direction grid nodes and whose columns correspond to
 *   outgoing-direction grid nodes. Directions are expressed in the
 *   intersection's \c baseFrame, mapped to the unit square with
 *   \c warp::uniformHemisphereToSquareConcentric, and interpolated
 *   piecewise-bilinearly in all four grid coordinates. When
 *   \c wiUseFullSphere / \c woUseFullSphere is set, the corresponding axis
 *   stores two grids back to back: the first for the upper hemisphere
 *   (z > 0), the second for the mirrored lower hemisphere.
 * - Spatial scale (\c spatialScaleFilename): a bitmap indexed by the surface
 *   UV coordinates (scaled by \c uscale / \c vscale and wrapped into
 *   [0, 1)), looked up with "nearest" or periodic "bilinear" filtering
 *   (\c spatialInterp).
 *
 * Sampling and pdf evaluation are delegated unchanged to the nested BSDF;
 * the sampled weight is multiplied by the same scale factors as \c eval().
 */
class TabulatedScaledBSDF : public BSDF {
public:
	TabulatedScaledBSDF(const Properties &props) : BSDF(props) {
		m_angularScaleFilename = props.getString("angularScaleFilename", "");
		m_spatialScaleFilename = props.getString("spatialScaleFilename", "");
		m_spatialInterp = props.getString("spatialInterp", "nearest");

		m_wiUseFullSphere = props.getBoolean("wiUseFullSphere", false);
		m_woUseFullSphere = props.getBoolean("woUseFullSphere", false);

		Float uvscale = props.getFloat("uvscale", 1.0f);
		m_uvScale = Vector2(
			props.getFloat("uscale", uvscale),
			props.getFloat("vscale", uvscale));

		// Set properly by configure(); initialized here so they never hold garbage.
		m_wiResolution = m_woResolution = 0;
		m_spatialBilinear = false;
	}

	TabulatedScaledBSDF(Stream *stream, InstanceManager *manager)
		: BSDF(stream, manager) {
		// The read order must mirror serialize() exactly.
		m_angularScaleFilename = stream->readString();
		m_spatialScaleFilename = stream->readString();
		m_spatialInterp = stream->readString();

		m_wiUseFullSphere = stream->readBool();
		m_woUseFullSphere = stream->readBool();

		m_uvScale = Vector2(stream);

		m_bsdf = static_cast<BSDF *>(manager->getInstance(stream));

		m_wiResolution = m_woResolution = 0;
		m_spatialBilinear = false;

		// NOTE(review): the scale bitmaps are not serialized; they are reloaded
		// from their filenames, so those files must exist wherever this runs.
		configure();
	}

	void serialize(Stream *stream, InstanceManager *manager) const {
		BSDF::serialize(stream, manager);

		stream->writeString(m_angularScaleFilename);
		stream->writeString(m_spatialScaleFilename);
		// Without this the interpolation mode would be empty after unserialization.
		stream->writeString(m_spatialInterp);

		stream->writeBool(m_wiUseFullSphere);
		stream->writeBool(m_woUseFullSphere);

		m_uvScale.serialize(stream);

		// The nested BSDF is required by eval()/sample()/pdf().
		manager->serialize(stream, m_bsdf.get());
	}

	/**
	 * Validates the parameters, loads the scale bitmaps and derives the
	 * table resolutions. Raises an error (Log(EError, ...)) when no nested
	 * BSDF was given, when \c spatialInterp is unsupported, or when the
	 * angular table is too small to interpolate.
	 */
	void configure() {
		if (m_bsdf == NULL)
			Log(EError, "TabulatedScaledBSDF requires a nested BSDF!");

		// Resolve the interpolation mode once instead of string-comparing per lookup.
		if (m_spatialInterp == "nearest")
			m_spatialBilinear = false;
		else if (m_spatialInterp == "bilinear")
			m_spatialBilinear = true;
		else
			Log(EError, "Unsupported spatialInterp \"%s\" (expected \"nearest\" or \"bilinear\")",
				m_spatialInterp.c_str());

		m_components.clear();
		m_components.push_back(EGlossyReflection | EFrontSide | ESpatiallyVarying);

		// The scale lookups need no ray differentials, but the nested BSDF
		// (e.g. one with textured parameters) may.
		m_usesRayDifferentials = m_bsdf->usesRayDifferentials();

		// load angular scales
		m_angularScales = loadScaleBitmap(m_angularScaleFilename);
		if (m_angularScales != NULL) {
			m_lobeSize = m_angularScales->getSize();
			m_wiResolution = lobeResolution(m_lobeSize.y, m_wiUseFullSphere, "wi (rows)");
			m_woResolution = lobeResolution(m_lobeSize.x, m_woUseFullSphere, "wo (columns)");
			Log(EInfo, "wiRes = %d, woRes = %d", m_wiResolution, m_woResolution);
		}

		// load spatial scales
		m_spatialScales = loadScaleBitmap(m_spatialScaleFilename);
		if (m_spatialScales != NULL) {
			m_xyResolution = m_spatialScales->getSize();
			Log(EInfo, "xRes = %d, yRes = %d", m_xyResolution.x, m_xyResolution.y);
		}

		BSDF::configure();
	}

	/**
	 * Spatial scale factor at the intersection's (scaled, wrapped) UV
	 * coordinates, using nearest-texel or periodic bilinear filtering.
	 * Requires m_spatialScales to be loaded.
	 */
	Spectrum evalSpatialScale(const BSDFSamplingRecord &bRec) const {
		const Point2 uv = transformUV(bRec.its.uv);
		const int xRes = m_xyResolution.x;
		const int yRes = m_xyResolution.y;

		if (!m_spatialBilinear) {
			// uv lies in [0, 1), but uv * res may still round up to res.
			int xIdx = math::clamp(math::floorToInt(uv.x * xRes), 0, xRes - 1);
			int yIdx = math::clamp(math::floorToInt(uv.y * yRes), 0, yRes - 1);
			return m_spatialScales->getPixel(Point2i(xIdx, yIdx));
		}

		Float x = uv.x * xRes;
		Float y = uv.y * yRes;

		// Texel (i, j) is centred at (i + 0.5, j + 0.5); xIdx/yIdx are the
		// upper of the two neighbouring texel indices along each axis.
		int xIdx = math::floorToInt(x + 0.5);
		int yIdx = math::floorToInt(y + 0.5);

		Spectrum res(0.f);
		for (int dy = -1; dy <= 0; dy++) {
			Float wv = 1.0 - std::abs(y - (yIdx + dy + 0.5));
			// Periodic wrap in v (also maps yIdx == yRes back to 0).
			int yNow = (yIdx + yRes + dy) % yRes;
			for (int dx = -1; dx <= 0; dx++) {
				Float wu = 1.0 - std::abs(x - (xIdx + dx + 0.5));
				int xNow = (xIdx + xRes + dx) % xRes;
				Spectrum tmpValue = m_spatialScales->getPixel(Point2i(xNow, yNow));
				res += tmpValue * wu * wv;
			}
		}
		return res;
	}

	/**
	 * Angular scale factor for (bRec.wi, bRec.wo), both taken relative to
	 * bRec.its.baseFrame. Returns zero when a direction lies in the lower
	 * hemisphere (z <= 0) and the full sphere is not enabled for it.
	 * Requires m_angularScales to be loaded.
	 */
	Spectrum evalScale(const BSDFSamplingRecord &bRec) const {
		Vector wiMacro = bRec.its.baseFrame.toLocal(bRec.its.toWorld(bRec.wi));
		Vector woMacro = bRec.its.baseFrame.toLocal(bRec.its.toWorld(bRec.wo));

		// Lower-hemisphere directions are mirrored and looked up in the
		// second grid, which starts one full grid further along the axis.
		int wiRowOffset = 0;
		if (wiMacro.z <= 0) {
			if (!m_wiUseFullSphere)
				return Spectrum(0.0f);
			wiRowOffset = m_wiResolution;
			wiMacro.z = -wiMacro.z;
		}

		int woRowOffset = 0;
		if (woMacro.z <= 0) {
			if (!m_woUseFullSphere)
				return Spectrum(0.0f);
			woRowOffset = m_woResolution;
			woMacro.z = -woMacro.z;
		}

		const Point2 wiTex = warp::uniformHemisphereToSquareConcentric(wiMacro);
		const Point2 woTex = warp::uniformHemisphereToSquareConcentric(woMacro);

		// Piecewise bilinear interpolation on a grid of (resolution - 1) cells.
		const int wiNumCells = m_wiResolution - 1;
		const int woNumCells = m_woResolution - 1;

		const Float wiU = wiTex.x * wiNumCells, wiV = wiTex.y * wiNumCells;
		const Float woU = woTex.x * woNumCells, woV = woTex.y * woNumCells;

		// Cell indices, clamped so that the cell's upper node stays in the grid.
		const int c1 = math::clamp(math::floorToInt(wiU), 0, wiNumCells - 1);
		const int r1 = math::clamp(math::floorToInt(wiV), 0, wiNumCells - 1);
		const int c2 = math::clamp(math::floorToInt(woU), 0, woNumCells - 1);
		const int r2 = math::clamp(math::floorToInt(woV), 0, woNumCells - 1);

		// Weights of the lower [0] and upper [1] node along each coordinate.
		const Float wv1[2] = { std::abs(1 - (wiV - r1)), std::abs(wiV - r1) };
		const Float wu1[2] = { std::abs(1 - (wiU - c1)), std::abs(wiU - c1) };
		const Float wv2[2] = { std::abs(1 - (woV - r2)), std::abs(woV - r2) };
		const Float wu2[2] = { std::abs(1 - (woU - c2)), std::abs(woU - c2) };

		Spectrum res(0.f);
		for (int dr1 = 0; dr1 < 2; dr1++) {
			for (int dc1 = 0; dc1 < 2; dc1++) {
				const int wiIdx = (wiRowOffset + r1 + dr1) * m_wiResolution + (c1 + dc1);
				for (int dr2 = 0; dr2 < 2; dr2++) {
					for (int dc2 = 0; dc2 < 2; dc2++) {
						const int woIdx = (woRowOffset + r2 + dr2) * m_woResolution + (c2 + dc2);
						Spectrum tmpValue = m_angularScales->getPixel(Point2i(woIdx, wiIdx));
						res += tmpValue * wv1[dr1] * wu1[dc1] * wv2[dr2] * wu2[dc2];
					}
				}
			}
		}

		return res;
	}

	/// Nested BSDF value multiplied by the angular and spatial scales (each 1 if absent).
	Spectrum eval(const BSDFSamplingRecord &bRec, EMeasure measure) const {
		Spectrum s(1.f), t(1.f);
		if (m_angularScales != NULL)
			s = evalScale(bRec);
		if (m_spatialScales != NULL)
			t = evalSpatialScale(bRec);
		if (s.isZero() || t.isZero())
			return Spectrum(0.f);
		Spectrum spec = m_bsdf->eval(bRec, measure);
		return spec * s * t;
	}

	/// Sampling density is that of the nested BSDF (the scales do not affect sampling).
	Float pdf(const BSDFSamplingRecord &bRec, EMeasure measure) const {
		return m_bsdf->pdf(bRec, measure);
	}

	/// Samples the nested BSDF and multiplies its sample weight by the scales.
	Spectrum sample(BSDFSamplingRecord &bRec, Float &pdf, const Point2 &sample) const {
		Spectrum spec = m_bsdf->sample(bRec, pdf, sample);
		if (spec.isZero())
			return Spectrum(0.f);

		Spectrum s(1.f), t(1.f);
		if (m_angularScales != NULL)
			s = evalScale(bRec);
		if (m_spatialScales != NULL)
			t = evalSpatialScale(bRec);

		return spec * s * t;
	}

	Spectrum sample(BSDFSamplingRecord &bRec, const Point2 &sp) const {
		Float pdf;
		return sample(bRec, pdf, sp);
	}

	/// Accepts exactly one nested BSDF; all other children go to the base class.
	void addChild(const std::string &name, ConfigurableObject *child) {
		if (child->getClass()->derivesFrom(MTS_CLASS(BSDF))) {
			if (m_bsdf != NULL)
				Log(EError, "TabulatedScaledBSDF supports only a single nested BSDF!");
			m_bsdf = static_cast<BSDF *>(child);
		} else {
			BSDF::addChild(name, child);
		}
	}

	std::string toString() const {
		std::ostringstream oss;
		oss << "TabulatedScaledBSDF[" << endl
			<< "  angularScaleFilename = \"" << m_angularScaleFilename << "\"," << endl
			<< "  spatialScaleFilename = \"" << m_spatialScaleFilename << "\"," << endl
			<< "  spatialInterp = \"" << m_spatialInterp << "\"," << endl
			<< "  wiUseFullSphere = " << (m_wiUseFullSphere ? "true" : "false") << "," << endl
			<< "  woUseFullSphere = " << (m_woUseFullSphere ? "true" : "false") << "," << endl
			<< "  uvScale = " << m_uvScale.toString() << "," << endl
			<< "  bsdf = " << indent(m_bsdf.toString()) << endl
			<< "]";
		return oss.str();
	}

	/// Scales the UV coordinates by m_uvScale and wraps them into [0, 1).
	inline Point2 transformUV(const Point2 &_uv) const {
		Point2 uv(_uv.x * m_uvScale.x, _uv.y * m_uvScale.y);
		// std::floor avoids the int overflow of floorToInt for very large uv.
		uv.x -= std::floor(uv.x);
		uv.y -= std::floor(uv.y);
		return uv;
	}

	Float getRoughness(const Intersection &its, int component) const {
		return m_bsdf->getRoughness(its, component);
	}

	/// Power-heuristic (exponent 2) MIS weight. Not used within this class.
	inline Float miWeight(Float pdfA, Float pdfB) const {
		pdfA *= pdfA;
		pdfB *= pdfB;
		return pdfA / (pdfA + pdfB);
	}

	MTS_DECLARE_CLASS()
private:
	/**
	 * Loads a scale bitmap, resolving relative paths through the thread's
	 * FileResolver so that they work relative to the scene file. Returns a
	 * null reference when \c filename is empty.
	 */
	static ref<Bitmap> loadScaleBitmap(const std::string &filename) {
		if (filename.empty())
			return ref<Bitmap>();
		fs::path path = Thread::getThread()->getFileResolver()->resolve(fs::path(filename));
		return new Bitmap(path);
	}

	/**
	 * Per-hemisphere grid resolution for an angular-table axis with
	 * \c extent entries. The axis holds res*res entries, or 2*res*res when
	 * \c fullSphere is set. Errors if res < 2 (no cell to interpolate in);
	 * warns if the extent is not an exact fit.
	 */
	int lobeResolution(int extent, bool fullSphere, const char *label) const {
		int res = fullSphere
			? math::floorToInt(std::sqrt((Float) extent * 0.5))
			: math::floorToInt(std::sqrt((Float) extent));
		if (res < 2)
			Log(EError, "Angular scale table too small along %s: %d entries give resolution %d (need >= 2)",
				label, extent, res);
		int expected = res * res * (fullSphere ? 2 : 1);
		if (expected != extent)
			Log(EWarn, "Angular scale table extent %d along %s is not an exact fit for resolution %d; "
				"only the first %d entries are used", extent, label, res, expected);
		return res;
	}

public:
	ref<Bitmap> m_angularScales;
	ref<Bitmap> m_spatialScales;
	ref<BSDF> m_bsdf;
	std::string m_angularScaleFilename;
	std::string m_spatialScaleFilename;
	std::string m_spatialInterp;
	Vector2i m_lobeSize;

	Vector2i m_xyResolution;
	Vector2 m_uvScale;

	int m_wiResolution;
	int m_woResolution;

	bool m_wiUseFullSphere;
	bool m_woUseFullSphere;
	/// Cached result of parsing m_spatialInterp (true for "bilinear").
	bool m_spatialBilinear;

	ref_vector<Sampler> m_samplers;
};

// Register TabulatedScaledBSDF with Mitsuba's class system as a subclass of
// BSDF. NOTE(review): per Mitsuba's class macros, the `_S` variant
// presumably also registers the Stream/InstanceManager constructor used
// for unserialization, and `false` marks the class as non-abstract. Confirm
// against mitsuba/core/class.h.
MTS_IMPLEMENT_CLASS_S(TabulatedScaledBSDF, false, BSDF)
// Export the class as a loadable plugin with the given human-readable description.
MTS_EXPORT_PLUGIN(TabulatedScaledBSDF, "Tabulated scaled BSDF");
MTS_NAMESPACE_END
