%{
    @author: Fujun Luan
    @email : fl356@cornell.edu
%}

function plies = SimulatePly(params)
%SIMULATEPLY Generate fiber-level geometry for every ply of a yarn.
%   PLIES = SIMULATEPLY(PARAMS) returns a 1-by-PARAMS.ply_num cell array.
%   Each cell holds a ply struct with fields:
%     base_theta  - placement angle 2*pi*(plyId-1)/ply_num
%     base_center - [x y 0] at distance yarn_radius/2 from the yarn axis
%     fibers      - cell array; entries 1..fiber_num are helical fibers
%                   (fields init_radius, init_theta, init_migration_theta,
%                   vertices), optionally followed by flyaway hair fibers
%                   (field vertices only). Every vertex is a 1-by-3 [x y z].
%   Fiber vertices are in the ply's local frame: base_center is recorded
%   but not added to them here.
%
%   PARAMS fields read here: ply_num, fiber_num, yarn_radius, z_step_num,
%   z_step_size, fiber_clock_wise, alpha, use_migration, rho_max, rho_min,
%   s_i, ellipse_long, ellipse_short, use_flyaways and, when use_flyaways
%   is set, aabb_min, aabb_max, flyaway_loop_density, flyaway_loop_r1,
%   flyaway_hair_density, flyaway_hair_ze, flyaway_hair_r0,
%   flyaway_hair_re, flyaway_hair_pe. sampleR additionally reads R_max,
%   epsilon and beta.
%
%   NORMRND (used for flyaways) is part of the Statistics and Machine
%   Learning Toolbox.
    plies = cell(1, params.ply_num);
    plyNum = double(params.ply_num);
    fiberNum = double(params.fiber_num);
    ns = double(params.z_step_num);
    ss = params.z_step_size;

    % Radial jitter applied along every regular fiber.
    PERTURB_FIBER_PROB = 0.9;
    PERTURB_FIBER_RATIO = 0.25;
    PERTURB_FIBER_SMOOTHING = 3;

    % Quantities that do not change across plies, fibers or steps.
    balance_radius = sqrt(params.ellipse_long * params.ellipse_short);
    twist = iif(params.fiber_clock_wise, -1.0, 1.0);

    for plyId = 1 : plyNum
        angle = 2.0 * pi * (plyId - 1) / plyNum;
        ply = struct();
        ply.base_theta = angle;
        ply.base_center = [params.yarn_radius/2.0 * cos(angle), params.yarn_radius/2.0 * sin(angle), 0];
        ply.fibers = cell(1, params.fiber_num);

        % Unperturbed radial distance of each vertex, per fiber of this ply;
        % the flyaway-loop pass looks for local maxima in these.
        rVals = cell(1, fiberNum);

        for fiberId = 1 : fiberNum
            % Each fiber draws its own perturbation profile (previously the
            % profile of the last fiber was shared by all of them).
            ratios = perturbRatiosProfile(ns, PERTURB_FIBER_PROB, ...
                PERTURB_FIBER_RATIO, PERTURB_FIBER_SMOOTHING);

            fiber = struct();
            fiber.init_radius = sampleR(params);
            fiber.init_theta = 2.0 * pi * rand();
            fiber.init_migration_theta = 2.0 * pi * rand();
            fiber.vertices = cell(1, ns);

            rValsVec = zeros(1, ns);
            for step_id = 1 : ns
                % Steps are centred around z = 0.
                z = ss * (step_id - 1 - ns / 2.0);
                fiber_theta = twist * (z * 2.0 * pi / params.alpha);
                [local_x, local_y] = helixXYZ(fiber.init_radius, fiber.init_theta, fiber_theta, ...
                    params.use_migration, params.rho_max, params.rho_min, params.s_i, ...
                    fiber.init_migration_theta);

                local_x = local_x * balance_radius;
                local_y = local_y * balance_radius;

                % Radius is recorded before the perturbation is applied.
                rValsVec(step_id) = sqrt(local_x * local_x + local_y * local_y);

                fiber.vertices{step_id} = [local_x * ratios(step_id), local_y * ratios(step_id), z];
            end

            ply.fibers{fiberId} = fiber;
            rVals{fiberId} = rValsVec;
        end

        if params.use_flyaways
            zextent = params.aabb_max(3) - params.aabb_min(3);

            nloop = floor(params.flyaway_loop_density * zextent + 0.5);
            if nloop > 0
                ply.fibers = addFlyawayLoops(ply.fibers, rVals, nloop, params);
            end

            nhair = floor(params.flyaway_hair_density * zextent + 0.5);
            if nhair > 0
                ply.fibers = addFlyawayHairs(ply.fibers, nhair, zextent, params);
            end
        end

        plies{plyId} = ply;
    end
end

function ratios = perturbRatiosProfile(n, prob, amplitude, smoothing)
%PERTURBRATIOSPROFILE Random, smooth multiplicative profile of length N.
%   Each of the N samples becomes an "event" with probability PROB and
%   gets ratio 1 + AMPLITUDE*(U - 0.5), U ~ U(0,1). Samples between two
%   events blend them with a sine-shaped weight; samples before the first
%   (after the last) event copy its ratio. The profile is then filtered
%   SMOOTHING times with the [1/4 1/2 1/4] kernel, endpoints untouched.
%   With no events the profile is all ones.
    ratios = ones(1, n);
    events = find(rand(1, n) < prob);
    nEvents = numel(events);
    if nEvents == 0
        return;
    end
    vals = 1.0 + amplitude * (rand(1, nEvents) - 0.5);

    ratios(1 : events(1)) = vals(1);
    for i = 2 : nEvents
        a = events(i - 1);
        b = events(i);
        ratios(b) = vals(i);
        j = a + 1 : b - 1;
        w = sin(0.5 * pi * ((b - j) / (b - a)));
        ratios(j) = vals(i - 1) * w + vals(i) * (1.0 - w);
    end
    ratios(events(end) : n) = vals(end);

    for s = 1 : smoothing
        % The right-hand side is evaluated in full before assignment, so
        % every pass filters the previous pass's values.
        ratios(2 : n - 1) = 0.25 * ratios(1 : n - 2) + 0.5 * ratios(2 : n - 1) + 0.25 * ratios(3 : n);
    end
end

function fibers = addFlyawayLoops(fibers, rVals, nloop, params)
%ADDFLYAWAYLOOPS Pull up to NLOOP randomly chosen radial maxima outward.
%   Candidate sites are interior vertices k of fiber f where rVals{f}(k)
%   is strictly greater than both neighbours. For a chosen site a target
%   radius r1 > 1.05*rVals{f}(k) is drawn, and the fiber's x/y coordinates
%   are scaled by a factor that eases (sine) from 1 at k0 up to
%   r1/rVals{f}(k) at k and back to 1 at k1. The window [k0, k1] starts at
%   k -/+ min_loop_span (clipped to the fiber) and is widened while rVals
%   keeps decreasing away from k.
    sig_scale_loop = 0.5;
    min_loop_span = 10;

    % Collect [fiberId, vertexIdx] for every strict local maximum.
    locs = zeros(0, 2);
    for fiberId = 1 : numel(rVals)
        r = rVals{fiberId};
        peaks = find(r(2 : end - 1) > r(1 : end - 2) & r(2 : end - 1) > r(3 : end)) + 1;
        locs = [locs; repmat(fiberId, numel(peaks), 1), peaks(:)]; %#ok<AGROW>
    end

    % size(locs, 1), not length(locs): length of a 1x2 list would be 2.
    nLocs = size(locs, 1);
    if nLocs == 0
        return;
    end

    for j = randperm(nLocs, min(nloop, nLocs))
        fid = locs(j, 1);
        k = locs(j, 2);
        curRs = rVals{fid};
        fiber = fibers{fid};
        totVtx = numel(fiber.vertices);

        k0 = max(k - min_loop_span, 1);
        k1 = min(k + min_loop_span, totVtx);
        while k0 > 1 && curRs(k0 - 1) < curRs(k0)
            k0 = k0 - 1;
        end
        while k1 < totVtx && curRs(k1 + 1) < curRs(k1)
            k1 = k1 + 1;
        end

        % NOTE(review): this rejection loop can spin for a long time if
        % flyaway_loop_r1 puts little mass above 1.05*curRs(k).
        r1 = normrnd(params.flyaway_loop_r1(1), sig_scale_loop * params.flyaway_loop_r1(2));
        while r1 <= 1.05 * curRs(k)
            r1 = normrnd(params.flyaway_loop_r1(1), sig_scale_loop * params.flyaway_loop_r1(2));
        end
        ratio = r1 / curRs(k);

        % Rising side k0 < t <= k; the peak itself is scaled by exactly ratio.
        for t = k0 + 1 : k
            v = (t - k0) / (k - k0);
            v = 1.0 + (ratio - 1.0) * sin(0.5 * pi * v);
            fiber.vertices{t}(1:2) = fiber.vertices{t}(1:2) * v;
        end
        % Falling side k < t < k1; starts at k+1 so the peak is not scaled twice.
        for t = k + 1 : k1 - 1
            v = (k1 - t) / (k1 - k);
            v = 1.0 + (ratio - 1.0) * sin(0.5 * pi * v);
            fiber.vertices{t}(1:2) = fiber.vertices{t}(1:2) * v;
        end

        % Structs and cells are value types: store the modified copy back.
        fibers{fid} = fiber;
    end
end

function fibers = addFlyawayHairs(fibers, nhair, zextent, params)
%ADDFLYAWAYHAIRS Append up to NHAIR flyaway hair fibers to FIBERS.
%   Each hair is a straight segment in (radius, z, azimuth) space, sampled
%   at 100 points with a light radial perturbation. It is converted to
%   [r*cos(phi), r*sin(phi), z] and truncated at the first sample whose z
%   leaves [aabb_min(3), aabb_max(3)]. Hairs keeping fewer than two
%   vertices are discarded.
    sig_scale_hair = 0.75;
    nstep = 100;
    PERTURB_FIBER_PROB_HAIR = 0.2;
    PERTURB_FIBER_RATIO_HAIR = 0.1;
    PERTURB_FIBER_SMOOTHING_HAIR = 3;
    steps = (1 : nstep)';

    for j = 1 : nhair
        z0 = params.aabb_min(3) + zextent * rand();
        ze = normrnd(params.flyaway_hair_ze(1), sig_scale_hair * params.flyaway_hair_ze(2));
        r0 = normrnd(params.flyaway_hair_r0(1), sig_scale_hair * params.flyaway_hair_r0(2));
        re = normrnd(params.flyaway_hair_re(1), sig_scale_hair * params.flyaway_hair_re(2));
        p0 = 2.0 * pi * rand();
        pe = normrnd(params.flyaway_hair_pe(1), sig_scale_hair * params.flyaway_hair_pe(2));

        % Extend the segment (r0, z0, p0) -> (r0+re, z0+ze, p0+pe) linearly
        % back to radius 0; its far end is unchanged.
        r0_e = 0.0;  re_e = r0 + re;
        z0_e = z0 - ze*r0/re;  ze_e = ze + ze*r0/re;
        p0_e = p0 - pe*r0/re;  pe_e = pe + pe*r0/re;

        radius = r0_e + re_e * steps / nstep;
        height = z0_e + ze_e * steps / nstep;
        phi    = p0_e + pe_e * steps / nstep;

        radius = radius .* perturbRatiosProfile(nstep, PERTURB_FIBER_PROB_HAIR, ...
            PERTURB_FIBER_RATIO_HAIR, PERTURB_FIBER_SMOOTHING_HAIR).';

        hair_fiber = struct();
        hair_fiber.vertices = {};
        for k = 1 : nstep
            if height(k) < params.aabb_min(3) || height(k) > params.aabb_max(3)
                break;
            end
            hair_fiber.vertices{k} = [radius(k) * cos(phi(k)), radius(k) * sin(phi(k)), height(k)];
        end

        if numel(hair_fiber.vertices) > 1
            fibers{end + 1} = hair_fiber; %#ok<AGROW>
        end
    end
end

function locs_out = random_shuffle(locs_in)
%RANDOM_SHUFFLE Return the rows of LOCS_IN in a uniformly random order.
%   Rows are moved as whole units. The row count is taken from
%   size(LOCS_IN, 1); length() would report the largest dimension and
%   break on 1-row or 0-row inputs such as [7 8] or zeros(0, 2).
%   Empty input yields empty output of the same width.
    locs_out = locs_in(randperm(size(locs_in, 1)), :);
end

function dot_v = dot(v1, v2)
%DOT Inner product of row vectors V1 and V2, computed as V1 * V2'.
%   ' is the conjugate transpose, so complex V2 is conjugated. As a local
%   function this shadows MATLAB's built-in DOT for calls made in this file.
    dot_v = mtimes(v1, ctranspose(v2));
end

function norm_v = normalize(v)
%NORMALIZE Divide V element-wise by NORM(V).
%   A zero vector yields NaN entries (0/0). As a local function this
%   shadows MATLAB's built-in NORMALIZE for calls made in this file.
    len = norm(v);
    norm_v = rdivide(v, len);
end

function out = iif(cond, a, b)
%IIF Inline conditional: return A when COND == 1, otherwise B.
%   Both A and B are evaluated by the caller before the call; there is no
%   short-circuiting. Values of COND other than 1 (e.g. 2) select B.
    out = b;
    if cond == 1
        out = a;
    end
end

function r = helixRadius(init_r, theta, ...
    use_migration, rho_max, rho_min, s_i, init_migration_theta)
%HELIXRADIUS Fiber distance from the ply centre at helix angle THETA.
%   Without migration (USE_MIGRATION ~= 1) the radius is the constant
%   INIT_R. With migration it oscillates between RHO_MIN*INIT_R and
%   RHO_MAX*INIT_R following cos(S_I*THETA + INIT_MIGRATION_THETA).
    if use_migration == 1
        r_lo = rho_min * init_r;
        r_hi = rho_max * init_r;
        phase = cos(s_i * theta + init_migration_theta) + 1;
        r = r_lo + (r_hi - r_lo) * 0.5 * phase;
    else
        r = init_r;
    end
end

function [x, y] = helixXYZ(init_r, init_theta, theta,...
    use_migration, rho_max, rho_min, s_i, init_migration_theta)
%HELIXXYZ Planar (x, y) position of a helical fiber at helix angle THETA.
%   The point lies at angle THETA + INIT_THETA, at the radius returned by
%   helixRadius for the given migration settings.
    phi = theta + init_theta;
    radius = helixRadius(init_r, theta, use_migration, rho_max, rho_min, s_i, init_migration_theta);
    x = radius * cos(phi);
    y = radius * sin(phi);
end

function pR = fiberDistrib(r, params)
%FIBERDISTRIB Unnormalized fiber density at distance R from the ply centre.
%   pR = (1 - 2*epsilon) * ((e - exp(R/R_max)) / (e - 1))^beta + epsilon,
%   which is 1 - epsilon at R = 0 and epsilon at R = R_max. Reads
%   PARAMS.R_max, PARAMS.epsilon and PARAMS.beta.
    e = exp(1);
    decay = (e - exp(r / params.R_max)) / (e - 1);
    weight = 1 - 2 * params.epsilon;
    pR = weight * (decay.^params.beta) + params.epsilon;
end

function R = sampleR(params)
%SAMPLER Draw an initial fiber radius in [0, R_max] by rejection sampling.
%   Candidates R = sqrt(U) * R_max (U ~ U(0,1)) are uniform over the disk
%   of radius R_max. Each candidate is accepted when a second uniform draw
%   falls below fiberDistrib(R, params). Presumably fiberDistrib stays
%   within [0, 1] so it acts as an acceptance probability -- TODO confirm.
    accepted = false;
    while ~accepted
        R = sqrt(rand()) * params.R_max;
        accepted = rand() < fiberDistrib(R, params);
    end
end