import { loadShaderProgram } from "../shader";
import {
  createQuadMesh,
  enablePositionVertexAttribute,
  enableTexCoordsVertexAttribute,
} from "../mesh";
import { Puzzle, PuzzlePiece } from "../../puzzle";
import { mat4, vec3 } from "gl-matrix";
import { JigsawMaskShader } from "./jigsaw-mask-shader";
import { groupByProperty } from "../../utils";

/**
 * Extra fraction of a grid cell added to every piece's quad.
 *
 * Both the position scale and the texture-coordinate window of a piece are
 * multiplied by `1 + OVERDRAW`, and the texture window is shifted back by half
 * of the extra amount (`OVERDRAW * 0.5` of a cell) on each axis.
 * NOTE(review): presumably this leaves room for jigsaw tabs that extend past
 * the piece's own cell — confirm against the mask generator.
 */
const OVERDRAW = 0.66;

// Fragment shader (GLSL ES 1.00 source, kept as a runtime string).
//
// Samples the mask texture `uPageTexture` at `vMaskTexCoords` and discards the
// fragment when the mask's red channel is below 0.5, or when `vTexCoords` lies
// outside [0, 1] on either axis. Surviving fragments are coloured with
// `uTexture` sampled at `vTexCoords`.
const fragmentShaderSource = `
    precision mediump float;

    uniform sampler2D uTexture;
    uniform sampler2D uPageTexture;

    varying vec2 vTexCoords;
    varying vec2 vMaskTexCoords;

    void main() {
        vec4 maskColor = texture2D(uPageTexture, vMaskTexCoords);
        
        if(maskColor.r < 0.5 || vTexCoords.x < 0.0 || vTexCoords.x > 1.0 || vTexCoords.y < 0.0 || vTexCoords.y > 1.0) {
          discard;
        }
        
        gl_FragColor = texture2D(uTexture, vTexCoords);
    }
`;

// Vertex shader (GLSL ES 1.00 source, kept as a runtime string).
//
// - `gl_Position` is `aPosition` transformed by `uProjectionMatrix`.
// - `vMaskTexCoords` is `aTexCoords` transformed by
//   `uMaskTexCoordProjectionMatrix`.
// - `vTexCoords` is `aTexCoords` transformed by `uTexCoordProjectionMatrix`
//   and then aspect-corrected with `r = uCanvasAspect / uVideoAspect`:
//     r >= 1: y is scaled by 1/r and offset by (1 - 1/r) / 2;
//     r <  1: x is scaled by r   and offset by (1 - r)   / 2.
//   NOTE(review): this looks like a centred crop of the video to the canvas
//   aspect ratio — confirm that is the intended fit mode.
const vertexShaderSource = `
    attribute vec2 aPosition;
    attribute vec2 aTexCoords;

    varying vec2 vTexCoords;
    varying vec2 vMaskTexCoords;
    
    uniform mat4 uTexCoordProjectionMatrix;
    uniform mat4 uMaskTexCoordProjectionMatrix;
    uniform mat4 uProjectionMatrix;    
    uniform float uCanvasAspect;
    uniform float uVideoAspect;

    void main() {
        vec2 pieceTexCoords = (uTexCoordProjectionMatrix * vec4(aTexCoords, 1.0, 1.0)).xy;
        
        float aspectRatio = uCanvasAspect / uVideoAspect;
        float inverseAspectRatio = 1.0 / aspectRatio;
        
        if(aspectRatio >= 1.0) {
          float remainingAspect = 1.0 - inverseAspectRatio;
          vTexCoords = pieceTexCoords * vec2(1.0, inverseAspectRatio) + vec2(0.0, remainingAspect / 2.0);
        } else {
          float remainingAspect = 1.0 - aspectRatio;
          vTexCoords = pieceTexCoords * vec2(aspectRatio, 1.0) + vec2(remainingAspect / 2.0, 0.0);
        }
    
        vMaskTexCoords = (uMaskTexCoordProjectionMatrix * vec4(aTexCoords, 1.0, 1.0)).xy;
        gl_Position = uProjectionMatrix * vec4(aPosition, 1.0, 1.0);
    }
`;

/** Handle returned by `createPuzzleShader`; draws puzzle pieces to the canvas. */
export interface PuzzleShader {
  /**
   * Draws the selected pieces of `puzzle`, one quad per piece.
   *
   * @param texture - Texture sampled for the piece colours (bound to unit 0).
   * @param textureAspect - Aspect ratio of `texture`, uploaded as `uVideoAspect`.
   * @param puzzle - Supplies the pieces and the `columns` / `rows` grid size.
   * @param locked - Only pieces whose `locked` flag equals this value are drawn.
   * @param draggingPuzzlePiece - When non-null, the candidate set is the locked
   *   pieces plus this piece instead of all pieces (still filtered by `locked`).
   * @param jigsawMaskShader - Provides, via `getMask(piece)`, the mask page
   *   texture and the mask region (`x`, `y`, `width`, `height`) for each piece.
   */
  render: (
    texture: WebGLTexture,
    textureAspect: number,
    puzzle: Puzzle,
    locked: boolean,
    draggingPuzzlePiece: PuzzlePiece | null,
    jigsawMaskShader: JigsawMaskShader,
  ) => void;
}

/**
 * Creates the shader that draws puzzle pieces as textured quads clipped by
 * their jigsaw masks.
 *
 * The program is loaded, its attribute/uniform locations are looked up and the
 * quad mesh is created once here; the returned `render` reuses them on every
 * call.
 *
 * @param gl - WebGL context that owns the program, the mesh and the draw calls.
 * @returns An object whose `render` draws the selected pieces of a puzzle.
 */
export const createPuzzleShader = (gl: WebGLRenderingContext): PuzzleShader => {
  // Load shader program
  const shaderProgram = loadShaderProgram(
    gl,
    vertexShaderSource,
    fragmentShaderSource,
  );

  const shaderLocations = {
    attrib: {
      position: gl.getAttribLocation(shaderProgram, "aPosition"),
      texCoords: gl.getAttribLocation(shaderProgram, "aTexCoords"),
    },
    uniform: {
      texture: gl.getUniformLocation(shaderProgram, "uTexture"),
      pageTexture: gl.getUniformLocation(shaderProgram, "uPageTexture"),
      projectionMatrix: gl.getUniformLocation(
        shaderProgram,
        "uProjectionMatrix",
      ),
      texCoordProjectionMatrix: gl.getUniformLocation(
        shaderProgram,
        "uTexCoordProjectionMatrix",
      ),
      videoAspect: gl.getUniformLocation(shaderProgram, "uVideoAspect"),
      canvasAspect: gl.getUniformLocation(shaderProgram, "uCanvasAspect"),
      maskTexCoordProjectionMatrix: gl.getUniformLocation(
        shaderProgram,
        "uMaskTexCoordProjectionMatrix",
      ),
    },
  };

  // Create quad mesh
  const quadMesh = createQuadMesh(gl);

  // Scratch objects shared by every piece and every frame so the draw loop
  // does not allocate. `identityMatrix` is only ever read (it is the source
  // operand of `mat4.translate`) and must never be used as an output.
  const identityMatrix = mat4.create();
  const projectionMatrix = mat4.create();
  const texCoordProjectionMatrix = mat4.create();
  const maskTexCoordProjectionMatrix = mat4.create();
  const scratchVector = vec3.create();

  /**
   * Draws every piece of `puzzle` selected by `locked` / `draggingPuzzlePiece`.
   *
   * Pieces are grouped by the mask page returned from `jigsawMaskShader` so
   * each page texture is bound once per group rather than once per piece.
   * On return both texture units used here (0 and 1) are unbound and unit 0 is
   * the active unit again.
   */
  const render = (
    texture: WebGLTexture,
    textureAspect: number,
    puzzle: Puzzle,
    locked: boolean,
    draggingPuzzlePiece: PuzzlePiece | null,
    jigsawMaskShader: JigsawMaskShader,
  ) => {
    const canvasWidth = gl.canvas.width;
    const canvasHeight = gl.canvas.height;

    gl.viewport(0, 0, canvasWidth, canvasHeight);

    // Draw quad
    gl.useProgram(shaderProgram);
    enablePositionVertexAttribute(
      gl,
      shaderLocations.attrib.position,
      quadMesh,
    );
    enableTexCoordsVertexAttribute(
      gl,
      shaderLocations.attrib.texCoords,
      quadMesh,
    );

    // Bind video texture
    gl.activeTexture(gl.TEXTURE0);
    gl.bindTexture(gl.TEXTURE_2D, texture);
    gl.uniform1i(shaderLocations.uniform.texture, 0);

    // While a piece is being dragged the candidates are the locked pieces plus
    // the dragged one; otherwise every piece. Either way only pieces whose
    // `locked` flag matches the requested pass are kept.
    const piecesToDraw = (
      draggingPuzzlePiece
        ? [...puzzle.pieces.filter((p) => p.locked), draggingPuzzlePiece]
        : puzzle.pieces
    ).filter((p) => p.locked === locked);

    const piecesByPage = groupByProperty(
      piecesToDraw,
      (p) => jigsawMaskShader.getMask(p).page,
    );

    // Aspect uniforms are the same for every piece, so upload them once.
    gl.uniform1f(
      shaderLocations.uniform.canvasAspect,
      canvasWidth / canvasHeight,
    );
    gl.uniform1f(shaderLocations.uniform.videoAspect, textureAspect);

    // Per-axis factors that depend only on the grid size.
    const pieceWidth = 1 / puzzle.columns;
    const pieceHeight = 1 / puzzle.rows;
    const quadScaleX = (1 + OVERDRAW) / puzzle.columns;
    const quadScaleY = (1 + OVERDRAW) / puzzle.rows;
    const texWindowWidth = pieceWidth * (1 + OVERDRAW);
    const texWindowHeight = pieceHeight * (1 + OVERDRAW);
    const texOffsetX = pieceWidth * OVERDRAW * 0.5;
    const texOffsetY = pieceHeight * OVERDRAW * 0.5;

    piecesByPage.forEach((pieces, page) => {
      // Bind page texture
      gl.activeTexture(gl.TEXTURE1);
      gl.bindTexture(gl.TEXTURE_2D, page.pageTexture);
      gl.uniform1i(shaderLocations.uniform.pageTexture, 1);

      pieces.forEach((piece) => {
        const mask = jigsawMaskShader.getMask(piece);

        // Position: map the piece's canvas coordinates from [0, size] to clip
        // space [-1, 1], then scale the quad to one (overdrawn) grid cell.
        mat4.translate(
          projectionMatrix,
          identityMatrix,
          vec3.set(
            scratchVector,
            (piece.canvasX / canvasWidth) * 2 - 1,
            (piece.canvasY / canvasHeight) * 2 - 1,
            0,
          ),
        );
        mat4.scale(
          projectionMatrix,
          projectionMatrix,
          vec3.set(scratchVector, quadScaleX, quadScaleY, 1.0),
        );
        gl.uniformMatrix4fv(
          shaderLocations.uniform.projectionMatrix,
          false,
          projectionMatrix,
        );

        // Texture coordinates: window over the piece's target cell, widened
        // by the overdraw and shifted back by half of the extra amount.
        mat4.translate(
          texCoordProjectionMatrix,
          identityMatrix,
          vec3.set(
            scratchVector,
            piece.targetColumn / puzzle.columns - texOffsetX,
            piece.targetRow / puzzle.rows - texOffsetY,
            0,
          ),
        );
        mat4.scale(
          texCoordProjectionMatrix,
          texCoordProjectionMatrix,
          vec3.set(scratchVector, texWindowWidth, texWindowHeight, 1.0),
        );
        gl.uniformMatrix4fv(
          shaderLocations.uniform.texCoordProjectionMatrix,
          false,
          texCoordProjectionMatrix,
        );

        // Mask texture coordinates: the piece's region inside the mask page.
        mat4.translate(
          maskTexCoordProjectionMatrix,
          identityMatrix,
          vec3.set(scratchVector, mask.x, mask.y, 0),
        );
        mat4.scale(
          maskTexCoordProjectionMatrix,
          maskTexCoordProjectionMatrix,
          vec3.set(scratchVector, mask.width, mask.height, 1.0),
        );
        gl.uniformMatrix4fv(
          shaderLocations.uniform.maskTexCoordProjectionMatrix,
          false,
          maskTexCoordProjectionMatrix,
        );

        gl.drawArrays(gl.TRIANGLES, 0, 6);
      });
    });

    // Clean up. `bindTexture` only affects the active unit, so each unit used
    // above is unbound explicitly; unit 0 is handled last so it is left as the
    // active unit for whatever renders next.
    gl.activeTexture(gl.TEXTURE1);
    gl.bindTexture(gl.TEXTURE_2D, null);
    gl.activeTexture(gl.TEXTURE0);
    gl.bindTexture(gl.TEXTURE_2D, null);
  };

  return {
    render,
  };
};
