
"The abstract texture class."

from __future__ import with_statement

import OpenGL.GL as gl

import glitch
from glitch.multicontext import MultiContext

class TextureBinding(object):
    """Context manager that binds a texture to a texture unit for a scope.

    Instances are returned by C{Texture.bind}. Entering the context records
    the texture in C{ctx['textures']} under its unit index, makes that unit
    active, pushes the GL enable and texture attribute groups, enables 2D
    texturing and binds the texture's GL id for C{ctx}. Leaving the context
    undoes all of that. When the texture's C{width} is C{None}, entering and
    leaving do nothing and the context value is C{None}.
    """

    def __init__(self, ctx, unit, texture):
        """
        @param ctx: Rendering context (supports push/pop and item access).
        @param unit: Texture unit index (0 for GL_TEXTURE0, 1 for
            GL_TEXTURE1, ...).
        @param texture: The texture to bind.
        """
        self.ctx = ctx
        self.unit = unit
        self.texture = texture
        # True only while a successful bind from __enter__ is outstanding.
        self.bound = False

    def _try_bind(self):
        """Bind the texture if it has data.

        @return: The GL texture id, or C{None} if the texture's width is
            C{None} (nothing is pushed or bound in that case).

        If a step fails part-way, everything pushed so far is popped again
        before the exception propagates: the C{with} statement does not call
        C{__exit__} when C{__enter__} raises, so nothing else would undo it.
        """
        if self.texture.width is None:
            return None

        self.ctx.push('textures')

        try:
            # Install a new mapping instead of updating the current one in
            # place, so an outer binding's mapping is never modified,
            # regardless of whether ctx.push() saved a copy of it. The key
            # is absent when no outer binding has set it.
            textures = dict(self.ctx.get('textures') or {})
            textures[self.unit] = self.texture
            self.ctx['textures'] = textures

            gl.glActiveTexture(gl.GL_TEXTURE0 + self.unit)
            gl.glPushAttrib(gl.GL_ENABLE_BIT | gl.GL_TEXTURE_BIT)
        except BaseException:
            # The GL attribute stack has not been pushed yet.
            self.ctx.pop()
            raise

        try:
            gl.glEnable(gl.GL_TEXTURE_2D)
            tex_id = self.texture.context_get(self.ctx)
            gl.glBindTexture(gl.GL_TEXTURE_2D, tex_id)
        except BaseException:
            self._unbind()
            raise

        return tex_id

    def _unbind(self):
        """Pop the GL attributes and context state pushed by _try_bind."""
        # Re-select our unit so the pop happens with the same unit active as
        # the push did (a nested binding may have selected another one).
        gl.glActiveTexture(gl.GL_TEXTURE0 + self.unit)
        gl.glPopAttrib()
        self.ctx.pop()

    def __enter__(self):
        """Bind the texture; return its GL id, or None if nothing was bound."""
        tex_id = self._try_bind()
        self.bound = tex_id is not None
        return tex_id

    def __exit__(self, type, exception, traceback):
        """Undo the bind (if any). Exceptions are never suppressed."""
        if self.bound:
            self._unbind()

        self.bound = False

class Texture(MultiContext):
    """An OpenGL 2D texture.

    Each rendering context gets its own GL texture object, managed through
    the C{_context_*} methods below (presumably invoked by the L{MultiContext}
    base class, e.g. from its C{context_get}). Subclasses can override the
    class attributes to change the pixel formats or wrap modes.
    """

    _context_key = 'texture'
    # Format the GL stores the texture in (glTexImage2D internalformat).
    internal_format = gl.GL_RGBA
    # Layout of the client-side pixel data in C{data}.
    format = gl.GL_RGBA
    # Wrap modes for the s and t texture coordinates.
    wrap_s = gl.GL_CLAMP_TO_EDGE
    wrap_t = gl.GL_CLAMP_TO_EDGE

    def __init__(self, width, height, data):
        """
        @param width: Width in texels. While it is C{None}, binding the
            texture binds nothing.
        @param height: Height in texels.
        @param data: Pixel data, uploaded as unsigned bytes laid out
            according to C{format} (RGBA by default).
        """

        MultiContext.__init__(self)
        self.width = width
        self.height = height
        self.data = data

    def _context_create(self, ctx):
        """Create, configure and upload a new GL texture; return its id."""
        tex_id = gl.glGenTextures(1)
        # set_parameters() and upload() act on the currently bound texture.
        gl.glBindTexture(gl.GL_TEXTURE_2D, tex_id)
        self.set_parameters(ctx)
        self.upload(ctx)
        return tex_id

    def _context_update(self, ctx, id):
        """Re-upload the pixel data into the existing texture C{id}."""
        gl.glBindTexture(gl.GL_TEXTURE_2D, id)
        self.upload(ctx)
        return id

    def _context_delete(self, ctx, id):
        """Delete the GL texture C{id}."""
        gl.glDeleteTextures(id)

    def bind(self, ctx, unit):
        """Bind this texture to a texture unit.

        The return value is an object supporting the C{with} protocol. The
        value bound by C{as} is the GL texture id, or C{None} when the
        texture has no data yet (C{width is None}). E.g.::

            with texture.bind(ctx, 0):
                # Render some other nodes that should be textured.

        @param ctx: Rendering context.
        @param unit: Texture unit index (0 for GL_TEXTURE0, ...).
        """

        return TextureBinding(ctx, unit, self)

    def upload(self, ctx):
        """Upload the texture data to GPU memory.

        This function assumes that the texture is currently bound to the
        active OpenGL texture unit.
        """

        # Target, level, internal format, width, height, border, format, type,
        # data.
        gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, self.internal_format, self.width,
            self.height, 0, self.format, gl.GL_UNSIGNED_BYTE, self.data)

    def set_parameters(self, ctx):
        """Set texture parameters such as wrapping and filtering.

        This function assumes that the texture is currently bound to the
        active OpenGL texture unit.
        """

        # All of these are enum-valued, so use the integer variant
        # throughout rather than passing enums through glTexParameterf.
        gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, self.wrap_s)
        gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, self.wrap_t)
        gl.glTexParameteri(
            gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
        gl.glTexParameteri(
            gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)

class ApplyTexture(glitch.Node):
    """Bind a texture to a texture unit while rendering this node's children.

    @param texture: Texture whose C{bind(ctx, unit)} is used during render.
    @param unit: Texture unit index to bind to (default 0).
    """

    def __init__(self, texture, unit=0, **kw):
        glitch.Node.__init__(self, **kw)
        self.texture = texture
        self.unit = unit

    def __repr__(self):
        parts = (self.__class__.__name__, self.texture, self.unit,
                 self.children)
        return '%s(%r, %r, children=%r)' % parts

    def render(self, ctx):
        # Children render inside the binding so they see the texture.
        binding = self.texture.bind(ctx, self.unit)
        with binding:
            glitch.Node.render(self, ctx)

class TexturedSquare(glitch.Node):
    """Draw the unit square in the z = 0 plane, textured by every bound unit.

    Vertices span (0, 0) to (1, 1); texture coordinates span (0, 0) to
    (x_max, y_max), flipped vertically for textures whose C{y_flip}
    attribute is true.
    """

    def __init__(self, x_max=1, y_max=1, **kw):
        """
        @param x_max: Texture s coordinate at the rectangle's right edge.
        @param y_max: Texture t coordinate at the rectangle's top edge
            (bottom edge for y-flipped textures).
        """
        glitch.Node.__init__(self, **kw)
        self.x_max = x_max
        self.y_max = y_max

    def draw_rect(self, ctx, x1, y1, x2, y2):
        """Draw the rectangle (x1, y1)-(x2, y2) as a textured quad at z = 0.

        For every texture in C{ctx['textures']} (a mapping of texture unit
        index to texture), a texture coordinate is emitted for each vertex
        on that texture's unit, taking the texture's vertical orientation
        (C{y_flip}) into account.
        """

        rect = [(x1, y1), (x1, y2), (x2, y2), (x2, y1)]
        tex_rect = [
            (0.0, 0.0),
            (0.0, self.y_max),
            (self.x_max, self.y_max),
            (self.x_max, 0.0)]
        tex_rect_flipped = [(s, self.y_max - t) for (s, t) in tex_rect]
        textures = ctx.get('textures', {})

        # glMultiTexCoord takes a GL_TEXTUREi enum as its target, while the
        # keys of ctx['textures'] are bare unit indices (TextureBinding
        # activates them as GL_TEXTURE0 + unit), so convert here.
        tex_coords = [
            (gl.GL_TEXTURE0 + unit,
             tex_rect_flipped if getattr(texture, 'y_flip', False)
             else tex_rect)
            for (unit, texture) in textures.items()]

        gl.glNormal3f(0, 0, 1)
        gl.glBegin(gl.GL_QUADS)

        for (i, vertex) in enumerate(rect):
            for (target, tc) in tex_coords:
                gl.glMultiTexCoord2f(target, *tc[i])

            gl.glVertex2f(*vertex)

        gl.glEnd()

    def draw(self, ctx):
        """Draw the unit square (0, 0)-(1, 1)."""
        self.draw_rect(ctx, 0, 0, 1, 1)

class TexturedRectangle(TexturedSquare):
    """Like TexturedSquare, but maintains the texture's aspect ratio.

    The rectangle is fitted inside the unit square, spanning it fully along
    the texture's longer side and centred along the shorter side. The
    dimensions come from the texture bound to the lowest-numbered unit; with
    no bound texture (or one with no positive dimension) a square is drawn.
    """

    def draw(self, ctx):
        textures = ctx.get('textures', {})

        if textures:
            texture = textures[min(textures)]
            (tw, th) = (texture.width, texture.height)
        else:
            (tw, th) = (1, 1)

        if max(tw, th) <= 0:
            # The divisor below is max(tw, th); fall back to a square
            # rather than dividing by zero for a degenerate texture.
            (tw, th) = (1, 1)

        if tw > th:
            h = float(th) / tw / 2
            (x1, x2) = (0, 1)
            (y1, y2) = (0.5 - h, 0.5 + h)
        else:
            w = float(tw) / th / 2
            (x1, x2) = (0.5 - w, 0.5 + w)
            (y1, y2) = (0, 1)

        self.draw_rect(ctx, x1, y1, x2, y2)
