#!/usr/bin/env python3

from sympy import *


def output(*args, **kwargs):
    """Simplify an expression, pretty-print it, and follow it with a blank line.

    All positional and keyword arguments are forwarded unchanged to
    ``simplify``; the simplified result is what gets printed.
    """
    simplified = simplify(*args, **kwargs)
    pprint(simplified)
    print('')

def mat(*args):
    """Build a matrix whose columns are the given vectors, in order."""
    # Each argument becomes one row first; the transpose turns rows into columns.
    rows = [list(column) for column in args]
    return Matrix(rows).T

def vec(*args):
    """Pack the given components into a single ``Matrix``."""
    components = list(args)
    return Matrix(components)

def normalize(v):
    """Return ``v`` divided by the square root of its dot product with itself."""
    length = sqrt(v.dot(v))
    return v / length

def elementwise(a, b, f):
    """Combine two matrices entry by entry with the binary function ``f``.

    The result has the shape of ``a``; ``b`` is read at the same
    (row, column) positions, visited in row-major order.
    """
    rows = []
    for r in range(a.rows):
        row = []
        for c in range(a.cols):
            row.append(f(a[r, c], b[r, c]))
        rows.append(row)
    return Matrix(rows)

def mul(a, b):
    """Entry-by-entry product of ``a`` and ``b`` (not a matrix product)."""
    def product(x, y):
        return x * y

    return elementwise(a, b, product)

def div(a, b):
    """Entry-by-entry quotient ``a / b``."""
    def quotient(x, y):
        return x / y

    return elementwise(a, b, quotient)


def window(viewport, depth_range):
    """Build the 4x4 window matrix for a viewport and a depth range.

    ``viewport`` is indexed as (x offset, y offset, width, height) and
    ``depth_range`` as (near value, far value). Per axis the matrix scales
    by half the extent and translates by half the extent plus the offset,
    so -1 maps to the offset and +1 maps to offset + extent.
    """
    half = S(1) / 2
    extent = vec(*viewport[2:4], depth_range[1] - depth_range[0])
    origin = vec(*viewport[0:2], depth_range[0])
    scale = half * extent
    shift = scale + origin
    return Matrix([
        [scale[0], 0, 0, shift[0]],
        [0, scale[1], 0, shift[1]],
        [0, 0, scale[2], shift[2]],
        [0, 0, 0, 1],
    ])

def ortho(left, right, bottom, top, near, far):
    """Build the 4x4 orthographic projection matrix for the given bounds.

    x in [left, right] and y in [bottom, top] map to [-1, +1]; z is scaled
    by a negative factor, so z = -near maps to -1 and z = -far maps to +1.
    """
    extent = vec(right - left, top - bottom, far - near)
    origin = vec(-left, -bottom, +near)
    # Per-axis scale first, then the translation that lands the lower
    # bound on -1.
    scale = div(vec(2, 2, -2), extent)
    shift = mul(scale, origin) - vec(1, 1, 1)
    return Matrix([
        [scale[0], 0, 0, shift[0]],
        [0, scale[1], 0, shift[1]],
        [0, 0, scale[2], shift[2]],
        [0, 0, 0, 1],
    ])

def frustum(left, right, bottom, top, near, far):
    """Build the 4x4 perspective projection matrix for the given frustum.

    The result is the orthographic projection for the same bounds applied
    after a matrix that scales x and y by ``near``, remaps z, and writes
    -z into the w component.
    """
    perspective_part = Matrix([
        [near, 0, 0, 0],
        [0, near, 0, 0],
        [0, 0, near + far, near * far],
        [0, 0, -1, 0],
    ])
    return ortho(left, right, bottom, top, near, far) * perspective_part

def perspective(fovy, aspect, near, far):
    """Build a symmetric perspective matrix from a vertical field of view.

    The half extents at the near plane are ``near * tan(0.5 * fovy)``
    vertically and that value times ``aspect`` horizontally; they are
    passed to ``frustum`` as symmetric bounds.
    """
    half_height = near * tan(0.5 * fovy)
    half_width = half_height * aspect
    return frustum(-half_width, +half_width, -half_height, +half_height, near, far)

def lookat(position, target, up):
    """Build the 4x4 view matrix for a camera at ``position`` facing ``target``.

    The camera's z axis points from the target towards the camera, x is
    ``up`` cross z (normalized), and y completes the basis. The result
    rotates by the transposed basis and translates by the rotated, negated
    camera position.
    """
    backward = normalize(position - target)
    rightward = normalize(up.cross(backward))
    upward = backward.cross(rightward)
    # mat() stores the axes as columns; the transpose puts them in rows.
    rotation = mat(rightward, upward, backward).T
    translation = rotation * -position
    columns = [vec(*rotation.col(i), 0) for i in range(3)]
    columns.append(vec(*translation, 1))
    return mat(*columns)


# https://registry.khronos.org/OpenGL-Refpages/gl4/html/glViewport.xhtml
# https://registry.khronos.org/OpenGL-Refpages/gl4/html/glDepthRange.xhtml
# https://www.khronos.org/opengl/wiki/Vertex_Post-Processing#Viewport_transform
# https://www.songho.ca/opengl/gl_viewport.html
# Symbolic viewport (x offset, y offset, width, height) and depth range;
# both names are reused by the depth section at the end of the file.
viewport    = symbols('xoffs, yoffs, width, height')
depth_range = symbols('znear, zfar')
# Print the simplified window matrix; the expected output is recorded below.
output(window(viewport, depth_range))
# ⎡width                        width         ⎤
# ⎢─────    0          0        ───── + xoffs ⎥
# ⎢  2                            2           ⎥
# ⎢                                           ⎥
# ⎢       height                height        ⎥
# ⎢  0    ──────       0        ────── + yoffs⎥
# ⎢         2                     2           ⎥
# ⎢                                           ⎥
# ⎢               zfar   znear   zfar   znear ⎥
# ⎢  0      0     ──── - ─────   ──── + ───── ⎥
# ⎢                2       2      2       2   ⎥
# ⎢                                           ⎥
# ⎣  0      0          0              1       ⎦

# https://registry.khronos.org/OpenGL-Refpages/gl2.1/xhtml/glOrtho.xml
# https://www.songho.ca/opengl/gl_projectionmatrix.html#ortho
# Symbolic bounds passed to ortho().
left, right, bottom, top, near, far = symbols('left, right, bottom, top, near, far')
# Print the simplified orthographic matrix; the expected output is recorded below.
output(ortho(left, right, bottom, top, near, far))
# ⎡    -2                                  left + right⎤
# ⎢────────────       0            0       ────────────⎥
# ⎢left - right                            left - right⎥
# ⎢                                                    ⎥
# ⎢                  -2                    bottom + top⎥
# ⎢     0        ────────────      0       ────────────⎥
# ⎢              bottom - top              bottom - top⎥
# ⎢                                                    ⎥
# ⎢                               -2       -far - near ⎥
# ⎢     0             0        ──────────  ─────────── ⎥
# ⎢                            far - near   far - near ⎥
# ⎢                                                    ⎥
# ⎣     0             0            0            1      ⎦

# https://registry.khronos.org/OpenGL-Refpages/gl2.1/xhtml/glFrustum.xml
# https://www.songho.ca/opengl/gl_projectionmatrix.html#perspective
# Same symbol names as in the ortho section, declared again here.
left, right, bottom, top, near, far = symbols('left, right, bottom, top, near, far')
# Print the simplified frustum matrix; the expected output is recorded below.
output(frustum(left, right, bottom, top, near, far))
# ⎡  -2⋅near                   -left - right              ⎤
# ⎢────────────       0        ─────────────       0      ⎥
# ⎢left - right                 left - right              ⎥
# ⎢                                                       ⎥
# ⎢                -2⋅near     -bottom - top              ⎥
# ⎢     0        ────────────  ─────────────       0      ⎥
# ⎢              bottom - top   bottom - top              ⎥
# ⎢                                                       ⎥
# ⎢                             -far - near   -2⋅far⋅near ⎥
# ⎢     0             0         ───────────   ────────────⎥
# ⎢                              far - near    far - near ⎥
# ⎢                                                       ⎥
# ⎣     0             0             -1             0      ⎦

# https://registry.khronos.org/OpenGL-Refpages/gl2.1/xhtml/gluPerspective.xml
# https://www.songho.ca/opengl/gl_projectionmatrix.html#fov
# Symbolic field of view, aspect ratio, and near/far distances; these names
# are reused by the depth section at the end of the file.
fovy, aspect, near, far = symbols('fovy, aspect, near, far')
# Print the simplified perspective matrix; the expected output is recorded below.
output(perspective(fovy, aspect, near, far))
# ⎡         1                                                    ⎤
# ⎢────────────────────        0             0            0      ⎥
# ⎢aspect⋅tan(0.5⋅fovy)                                          ⎥
# ⎢                                                              ⎥
# ⎢                            1                                 ⎥
# ⎢         0            ─────────────       0            0      ⎥
# ⎢                      tan(0.5⋅fovy)                           ⎥
# ⎢                                                              ⎥
# ⎢                                     -far - near  -2⋅far⋅near ⎥
# ⎢         0                  0        ───────────  ────────────⎥
# ⎢                                      far - near   far - near ⎥
# ⎢                                                              ⎥
# ⎣         0                  0            -1            0      ⎦

# https://registry.khronos.org/OpenGL-Refpages/gl2.1/xhtml/gluLookAt.xml
# https://www.songho.ca/opengl/gl_camera.html#lookat
# position = Matrix(MatrixSymbol('position', 3, 1))
# target   = Matrix(MatrixSymbol('target',   3, 1))
# up       = Matrix(MatrixSymbol('up',       3, 1))
# output(lookat(position, target, up))
# Concrete example instead of the symbolic form commented out above:
# camera at (0, 2, 0), target (1, 2, 0), up (0, 1, 0).
output(lookat(vec(0, 2, 0), vec(1, 2, 0), vec(0, 1, 0)))
# ⎡0   0  1  0 ⎤
# ⎢            ⎥
# ⎢0   1  0  -2⎥
# ⎢            ⎥
# ⎢-1  0  0  0 ⎥
# ⎢            ⎥
# ⎣0   0  0  1 ⎦


# In my experience, people sometimes assign too much magic to "linearizing the
# depth" and imply that it is somehow a weird quirk of the depth buffer.
# The point of going to 4D is to express transformations that are not linear in
# 3D as transformations that are linear in 4D; we do it all the time.
# Conceptually, all we're doing is `vec4 view_pos = inverse(window *
# projection) * window_pos`, a linear operation, same as ever. Of course, to
# interpret `view_pos` as a Euclidean point, we need to look at `view_pos.xyz
# / view_pos.w`.
# Inverse of the combined window * perspective matrix.
lens_inv = (window(viewport, depth_range) * perspective(fovy, aspect, near, far)).inv()
# Homogeneous window-space position; x, y and z are given by name as strings.
wind_pos = vec('x', 'y', 'z', 1)
view_pos = lens_inv * wind_pos
# Divide by w to get view-space z as a function of the window-space depth z.
view_z   = view_pos[2] / view_pos[3]

# View-space z with the depth range substituted as znear = 0, zfar = 1.
output(view_z.subs({'znear': 0, 'zfar': 1}))
#      -far⋅near
# ────────────────────
# far - z⋅(far - near)

# Check: window depth znear maps back to view-space z = -near.
output(view_z.subs({'z': 'znear'}))
# -near

# Check: window depth zfar maps back to view-space z = -far.
output(view_z.subs({'z': 'zfar'}))
# -far

# Plot -view_z over window depth z in [0, 1] for near = 0.1 and far = 100.0.
plot((-view_z).subs({'znear': 0, 'zfar': 1, 'near': 0.1, 'far': 100.0}), ('z', 0, 1))