import * as THREE from 'three';

import shaderUtils from '@three-extra/util/ShaderUtils';

/**
 * Resolves the optional `shaderCode` argument into the GLSL pieces spliced into
 * the fragment shader. Missing members become empty strings so the text
 * "undefined" is never concatenated into GLSL source.
 *
 * @param {Object} [shaderCode] - `{ uniforms, functions, code }`; every member optional.
 * @returns {{ uniforms: Object, uniformDecl: string, functions: string, code: string }}
 */
function resolveShaderCode( shaderCode ) {
    if ( ! shaderCode ) {
        return { uniforms: {}, uniformDecl: '', functions: '', code: '' };
    }
    const uniforms = shaderCode.uniforms || {};
    return {
        uniforms,
        uniformDecl: shaderUtils.uniformsToGLSLDecl( uniforms ),
        functions: shaderCode.functions || '',
        code: shaderCode.code || '',
    };
}

/**
 * Points material that draws every point as one cell of a sprite-sheet texture
 * (`map`), cross-fading between two animation frames per point.
 *
 * The sheet is treated as a square grid of sqrt(frameCount) x sqrt(frameCount) cells.
 *
 * Per-vertex attributes read by the vertex shader:
 *  - `frameIndexA` / `frameIndexB` (float): cells to blend, numbered row by row
 *    starting at the top-left cell.
 *  - `perc` (float): blend factor passed to `mix()` (0 = frame A, 1 = frame B).
 *
 * Optional `shaderCode` extends the fragment shader:
 *  - `uniforms`: uniform objects added to the program and declared through
 *    `shaderUtils.uniformsToGLSLDecl`.
 *  - `functions`: GLSL source inserted before the built-in fragment shader.
 *  - `code`: GLSL statements appended inside `main()` after the sprite texel is
 *    applied; they can read `vUv`, the sheet UV of frame A.
 */
export default class PointSpritesMaterial extends THREE.PointsMaterial {
    /**
     * @param {Object} [params] - THREE.PointsMaterial parameters plus `frameCount`
     *   (total number of cells in the sheet, default 16).
     * @param {Object} [shaderCode] - optional fragment-shader extension (see class docs).
     */
    constructor( params = {}, shaderCode ) {
        // `frameCount` is not a PointsMaterial property; passing it to the base
        // class would make Material#setValues warn about an unknown key.
        const materialParams = Object.assign( {}, params );
        delete materialParams.frameCount;
        super( materialParams );

        // Sprites are blended and must not occlude each other through the depth buffer.
        this.transparent = true;
        this.depthWrite = false;
        this.frameCount = params.frameCount || 16;

        this.params = params;
        this.shaderCode = shaderCode;

        if ( params.map ) {
            // Row offsets can leave [0, 1] for frame indices past the last cell;
            // repeat wrapping makes such indices cycle back through the sheet.
            params.map.wrapT = THREE.RepeatWrapping;
            params.map.wrapS = THREE.RepeatWrapping;
        }
    }

    /**
     * Copies the sprite-sheet settings in addition to the PointsMaterial state.
     * three.js `clone()` constructs a new instance without arguments and then calls
     * `copy()`, so without this override a clone would lose `frameCount` and `shaderCode`.
     *
     * @param {PointSpritesMaterial} source
     * @returns {this}
     */
    copy( source ) {
        super.copy( source );
        this.frameCount = source.frameCount;
        this.params = source.params;
        this.shaderCode = source.shaderCode;
        return this;
    }

    /**
     * three.js reuses a compiled program when cache keys match, and the default key
     * is the source text of `onBeforeCompile`, which is identical for every instance
     * of this class. Including the injected GLSL keeps materials with different
     * `shaderCode` from sharing one program.
     *
     * @returns {string}
     */
    customProgramCacheKey() {
        const { uniformDecl, functions, code } = resolveShaderCode( this.shaderCode );
        return [ this.onBeforeCompile.toString(), uniformDecl, functions, code ].join( '\n' );
    }

    /**
     * Patches the built-in points shaders. The vertex stage turns the two frame
     * indices into sprite-sheet UV offsets. The fragment stage replaces
     * `#include <map_particle_fragment>` with a two-frame sheet lookup followed by
     * the optional custom `shaderCode.code`.
     *
     * @param {Object} shader - three.js shader object (`uniforms`, `vertexShader`, `fragmentShader`).
     */
    onBeforeCompile( shader ) {
        const gridSize = Math.sqrt( this.frameCount );
        shader.uniforms.frameCount = {
            value: new THREE.Vector2( gridSize, gridSize ),
        };

        const { uniforms, uniformDecl, functions, code } = resolveShaderCode( this.shaderCode );
        Object.assign( shader.uniforms, uniforms );

        const vertexDeclarations = `
            attribute float frameIndexA;
            attribute float frameIndexB;
            attribute float perc;

            varying vec2 offsetA;
            varying vec2 offsetB;
            varying float aniPerc;

            uniform vec2 frameCount;

            // UV offset of the bottom-left corner of a sheet cell. Cells are numbered
            // row by row from the top-left. Assuming the texture is uploaded with
            // flipY = true (the three.js default), the top image row starts at
            // v = (rows - 1) / rows and each following row is one cell height lower.
            vec2 spriteFrameOffset( float frameIndex ) {
                float column = mod( frameIndex, frameCount.x );
                float row = floor( frameIndex / frameCount.x );
                return vec2( column / frameCount.x, ( frameCount.y - 1. - row ) / frameCount.y );
            }
        `;

        const vertexMainPrologue = `void main() {
                aniPerc = perc;
                offsetA = spriteFrameOffset( frameIndexA );
                offsetB = spriteFrameOffset( frameIndexB );
            `;

        // Function replacers keep any `$` sequences from being read as replacement patterns.
        shader.vertexShader = vertexDeclarations +
            shader.vertexShader.replace( 'void main() {', () => vertexMainPrologue );

        // Based on the three.js ShaderChunk `map_particle_fragment`.
        const mapParticleFragment = `
            vec2 uvI = ( uvTransform * vec3( gl_PointCoord.x, 1.0 - gl_PointCoord.y, 1 ) ).xy;

            // Scale the point-sprite UV down to the size of one sheet cell.
            uvI.x *= 1. / frameCount.x;
            uvI.y *= 1. / frameCount.y;

            vec2 suvA = uvI + offsetA;
            vec2 suvB = uvI + offsetB;

            vec4 mapTexelA = texture2D( map, suvA );
            vec4 mapTexelB = texture2D( map, suvB );

            // Cross-fade between the two animation frames.
            diffuseColor *= mix( mapTexelA, mapTexelB, aniPerc );

            // Exposed to custom shaderCode: sheet UV of frame A.
            vec2 vUv = suvA;
        ` + code;

        const fragmentDeclarations = `
            varying vec2 offsetA;
            varying vec2 offsetB;
            varying float aniPerc;
            uniform vec2 frameCount;
        `;

        shader.fragmentShader = [
            fragmentDeclarations,
            uniformDecl,
            functions,
            shader.fragmentShader.replace( '#include <map_particle_fragment>', () => mapParticleFragment ),
        ].join( '\n' );
    }
}
