import { loadShaderProgram } from "../shader";
import {
  createQuadMesh,
  enablePositionVertexAttribute,
  enableTexCoordsVertexAttribute,
} from "../mesh";
import {
  pieceHeightInPx,
  pieceWidthInPx,
  pieceXToClipSpaceX,
  pieceYToClipSpaceY,
  Puzzle,
  PuzzlePiece,
} from "../../puzzle";
import { mat4, vec3 } from "gl-matrix";
import { JigsawMaskShader, MASKS_PER_PAGE } from "./jigsaw-mask-shader";
import { convertRemToPixels } from "../../style";
import { groupByProperty } from "../../utils";

/**
 * Fraction of a piece's size added to each piece's quad, split evenly between
 * both sides. `render` applies it to both the clip-space quad and the quad's
 * texture coordinates.
 * NOTE(review): presumably this leaves room for jigsaw tabs that extend past
 * the piece's grid cell — confirm against the mask generator.
 */
const OVERDRAW = 0.66;

/**
 * Fragment shader for the outline pass.
 *
 * Sampling: it reads `uPageTexture` (the mask page) at +/-`uMaskProbeOffset`
 * horizontally and vertically around `vMaskTexCoords`.
 *
 * Edge test: it compares the red channel across the horizontal pair and the
 * green channel across the vertical pair. The fragment is discarded when both
 * differences are at or below their thresholds, so only fragments where the
 * mask changes across the probe offset survive.
 *
 * Colour: a surviving fragment gets the fixed tone (0.95, 0.9, 0.85) mixed 25%
 * toward the inverted `uTexture` colour, with alpha 0.5.
 *
 * NOTE(review): presumably the mask page's R/G channels vary across the piece
 * boundary — confirm against jigsaw-mask-shader. The thresholds 0.0125 (x) and
 * 0.0133 (y) differ per axis; confirm whether that asymmetry is intentional.
 */
const fragmentShaderSource = `
    precision mediump float;

    uniform sampler2D uPageTexture;
    uniform sampler2D uTexture;
    uniform float uMaskProbeOffset;

    varying vec2 vTexCoords;
    varying vec2 vMaskTexCoords;

    void main() {
        vec4 maskColorX0 = texture2D(uPageTexture, vMaskTexCoords + vec2(-uMaskProbeOffset, 0.0));
        vec4 maskColorX1 = texture2D(uPageTexture, vMaskTexCoords + vec2(uMaskProbeOffset, 0.0));
        
        vec4 maskColorY0 = texture2D(uPageTexture, vMaskTexCoords + vec2(0.0 , -uMaskProbeOffset));
        vec4 maskColorY1 = texture2D(uPageTexture, vMaskTexCoords + vec2(0.0 , uMaskProbeOffset));
        
        float diffX = abs(maskColorX0.x - maskColorX1.x);
        float diffY = abs(maskColorY0.y - maskColorY1.y);
        
        if(diffX <= 0.0125 && diffY <= 0.0133) {
          discard;
        }
        
        vec3 originalColor = texture2D(uTexture, vTexCoords).rgb;
        vec3 invertedColor =  vec3(1.0) - originalColor;
        vec3 sepiaColor = vec3(0.95, 0.9, 0.85);
        
        gl_FragColor = vec4(mix(sepiaColor, invertedColor, 0.25), 0.5);
    }
`;

/**
 * Vertex shader for the outline pass.
 *
 * It positions the quad with `uProjectionMatrix` and derives two
 * texture-coordinate sets from `aTexCoords`:
 * - `vTexCoords`, via `uTexCoordProjectionMatrix`, samples the source texture.
 * - `vMaskTexCoords`, via `uMaskTexCoordProjectionMatrix`, samples the piece's
 *   region of the mask page.
 */
const vertexShaderSource = `
    attribute vec2 aPosition;
    attribute vec2 aTexCoords;

    varying vec2 vTexCoords;
    varying vec2 vMaskTexCoords;

    uniform mat4 uTexCoordProjectionMatrix;
    uniform mat4 uMaskTexCoordProjectionMatrix;
    uniform mat4 uProjectionMatrix;

    void main() {
        vTexCoords = (uTexCoordProjectionMatrix * vec4(aTexCoords, 1.0, 1.0)).xy;
        vMaskTexCoords = (uMaskTexCoordProjectionMatrix * vec4(aTexCoords, 1.0, 1.0)).xy;
        gl_Position = uProjectionMatrix * vec4(aPosition, 1.0, 1.0);
    }
`;

/**
 * Renderer returned by `createPuzzleOutlineShader`: draws an outline pass over
 * a puzzle's pieces.
 */
export interface PuzzleOutlineShader {
  /**
   * Draws the outline for the puzzle's pieces.
   *
   * @param texture - Texture sampled (as `uTexture`) for the outline tint.
   * @param puzzle - Puzzle whose pieces are drawn.
   * @param grid - When true, every piece is drawn at its target grid cell.
   *   Otherwise pieces are drawn at their current positions.
   * @param draggingPuzzlePiece - Piece currently being dragged, or `null`.
   *   When non-null and `grid` is false, only locked pieces and this piece are
   *   drawn.
   * @param jigsawMaskShader - Source of each piece's mask region and page
   *   texture.
   */
  render: (
    texture: WebGLTexture,
    puzzle: Puzzle,
    grid: boolean,
    draggingPuzzlePiece: PuzzlePiece | null,
    jigsawMaskShader: JigsawMaskShader,
  ) => void;
}

/**
 * Compiles the puzzle-outline shader program and returns a renderer for it.
 *
 * `render` draws one quad per visible piece, enlarged by `OVERDRAW`. Each quad
 * samples its piece's region of the jigsaw mask page. The fragment shader keeps
 * only fragments where the mask changes across a small probe offset, producing
 * an outline tinted from the given texture.
 *
 * @param gl - WebGL context used to compile the program and issue draw calls.
 * @returns A {@link PuzzleOutlineShader} bound to `gl`.
 */
export const createPuzzleOutlineShader = (
  gl: WebGLRenderingContext,
): PuzzleOutlineShader => {
  // Load shader program
  const shaderProgram = loadShaderProgram(
    gl,
    vertexShaderSource,
    fragmentShaderSource,
  );

  const shaderLocations = {
    attrib: {
      position: gl.getAttribLocation(shaderProgram, "aPosition"),
      texCoords: gl.getAttribLocation(shaderProgram, "aTexCoords"),
    },
    uniform: {
      texture: gl.getUniformLocation(shaderProgram, "uTexture"),
      pageTexture: gl.getUniformLocation(shaderProgram, "uPageTexture"),
      maskProbeOffset: gl.getUniformLocation(shaderProgram, "uMaskProbeOffset"),
      projectionMatrix: gl.getUniformLocation(
        shaderProgram,
        "uProjectionMatrix",
      ),
      texCoordProjectionMatrix: gl.getUniformLocation(
        shaderProgram,
        "uTexCoordProjectionMatrix",
      ),
      maskTexCoordProjectionMatrix: gl.getUniformLocation(
        shaderProgram,
        "uMaskTexCoordProjectionMatrix",
      ),
    },
  };

  // Create quad mesh
  const quadMesh = createQuadMesh(gl);

  // Scratch matrices and vectors, reused for every piece so the per-piece loop
  // does not allocate. Each matrix is fully rebuilt (fromTranslation + scale)
  // before it is uploaded, so no state leaks between pieces.
  const projectionMatrix = mat4.create();
  const texCoordProjectionMatrix = mat4.create();
  const maskTexCoordProjectionMatrix = mat4.create();
  const translation = vec3.create();
  const maskScale = vec3.create();

  const render = (
    texture: WebGLTexture,
    puzzle: Puzzle,
    grid: boolean,
    draggingPuzzlePiece: PuzzlePiece | null,
    jigsawMaskShader: JigsawMaskShader,
  ) => {
    // Value for uMaskProbeOffset: the distance between the neighbouring mask
    // samples that the fragment shader compares.
    const maskProbeOffset =
      convertRemToPixels(0.05 / MASKS_PER_PAGE) /
      Math.max(pieceWidthInPx(puzzle, gl), pieceHeightInPx(puzzle, gl));

    // Per-render constants: size of one piece in texture space and the
    // overdrawn quad / texture-region scales, identical for every piece.
    const pieceWidth = 1 / puzzle.columns;
    const pieceHeight = 1 / puzzle.rows;
    const quadScale = vec3.fromValues(
      (1 + OVERDRAW) / puzzle.columns,
      (1 + OVERDRAW) / puzzle.rows,
      1.0,
    );
    const texCoordScale = vec3.fromValues(
      pieceWidth * (1 + OVERDRAW),
      pieceHeight * (1 + OVERDRAW),
      1.0,
    );

    // Draw quad
    gl.useProgram(shaderProgram);
    gl.bindBuffer(gl.ARRAY_BUFFER, quadMesh.vertexBuffer);
    enablePositionVertexAttribute(
      gl,
      shaderLocations.attrib.position,
      quadMesh,
    );
    enableTexCoordsVertexAttribute(
      gl,
      shaderLocations.attrib.texCoords,
      quadMesh,
    );

    gl.uniform1f(shaderLocations.uniform.maskProbeOffset, maskProbeOffset);

    // Bind video texture to unit 0
    gl.activeTexture(gl.TEXTURE0);
    gl.bindTexture(gl.TEXTURE_2D, texture);
    gl.uniform1i(shaderLocations.uniform.texture, 0);

    gl.viewport(0, 0, gl.canvas.width, gl.canvas.height);

    // In grid mode every piece is drawn. Otherwise, while a piece is being
    // dragged, only locked pieces and the dragged piece are drawn.
    const piecesToDraw = puzzle.pieces.filter(
      (piece) =>
        grid ||
        piece.locked ||
        draggingPuzzlePiece === null ||
        piece === draggingPuzzlePiece,
    );

    // Group by mask page so each page texture is bound once.
    const piecesByPage = groupByProperty(
      piecesToDraw,
      (p) => jigsawMaskShader.getMask(p).page,
    );

    piecesByPage.forEach((pieces, page) => {
      // Bind mask texture to unit 1
      gl.activeTexture(gl.TEXTURE1);
      gl.bindTexture(gl.TEXTURE_2D, page.pageTexture);
      gl.uniform1i(shaderLocations.uniform.pageTexture, 1);

      pieces.forEach((piece) => {
        const mask = jigsawMaskShader.getMask(piece);

        // Quad centre in clip space: the piece's target cell in grid mode,
        // otherwise its current position.
        const x = grid
          ? ((piece.targetColumn + 0.5) / puzzle.columns) * 2 - 1
          : pieceXToClipSpaceX(piece, gl);
        const y = grid
          ? ((piece.targetRow + 0.5) / puzzle.rows) * 2 - 1
          : pieceYToClipSpaceY(piece, gl);

        // Position projection matrix
        mat4.fromTranslation(projectionMatrix, vec3.set(translation, x, y, 0));
        mat4.scale(projectionMatrix, projectionMatrix, quadScale);
        gl.uniformMatrix4fv(
          shaderLocations.uniform.projectionMatrix,
          false,
          projectionMatrix,
        );

        // Tex coords projection matrix: the piece's cell, grown by OVERDRAW
        // and shifted back by half the growth (assuming aTexCoords span 0..1,
        // this keeps the cell centred in the quad).
        mat4.fromTranslation(
          texCoordProjectionMatrix,
          vec3.set(
            translation,
            piece.targetColumn / puzzle.columns - pieceWidth * OVERDRAW * 0.5,
            piece.targetRow / puzzle.rows - pieceHeight * OVERDRAW * 0.5,
            0,
          ),
        );
        mat4.scale(
          texCoordProjectionMatrix,
          texCoordProjectionMatrix,
          texCoordScale,
        );
        gl.uniformMatrix4fv(
          shaderLocations.uniform.texCoordProjectionMatrix,
          false,
          texCoordProjectionMatrix,
        );

        // Mask tex coords projection matrix: this piece's region of its page.
        mat4.fromTranslation(
          maskTexCoordProjectionMatrix,
          vec3.set(translation, mask.x, mask.y, 0),
        );
        mat4.scale(
          maskTexCoordProjectionMatrix,
          maskTexCoordProjectionMatrix,
          vec3.set(maskScale, mask.width, mask.height, 1.0),
        );
        gl.uniformMatrix4fv(
          shaderLocations.uniform.maskTexCoordProjectionMatrix,
          false,
          maskTexCoordProjectionMatrix,
        );

        gl.drawArrays(gl.TRIANGLES, 0, 6);
      });
    });

    // Clean up: unbind both texture units used above (previously only the
    // last-active unit was cleared, leaving `texture` bound on unit 0), and
    // leave unit 0 active.
    gl.activeTexture(gl.TEXTURE1);
    gl.bindTexture(gl.TEXTURE_2D, null);
    gl.activeTexture(gl.TEXTURE0);
    gl.bindTexture(gl.TEXTURE_2D, null);
  };

  return {
    render,
  };
};
