#!/usr/bin/env python3

# https://www.youtube.com/watch?v=jvPPXbo87ds&t=51m58s

import sys
from sympy import *

# Number of coefficients per basis polynomial (polynomial degree + 1), read
# from the first command-line argument; defaults to 4 when none is given.
n = int(next(iter(sys.argv[1:]), 4))
if n < 1:
    # Without this the script fails much later with an unrelated IndexError.
    sys.exit(f"n must be a positive integer, got {n}")

t = Symbol('t')

# Zero-pad both indices of the unknowns to a common width so that names stay
# unique for n > 10 (unpadded, i=1,j=11 and i=11,j=1 would both be 'a111' and
# collapse into a single Symbol). For n <= 10 the width is 1 and the names
# are unchanged ('a00', 'a01', ...).
index_width = len(str(n - 1))

# Column vector of powers of t: [1, t, ..., t**(n-1)].
T = Matrix([[t**i] for i in range(n)])
# Column vector of the symbolic control points p0 .. p(n-1).
P = Matrix([[Symbol(f'p{i}')] for i in range(n)])
# n x n matrix of unknown coefficients; row j, column i holds a<i><j>.
A = Matrix([
    [Symbol(f'a{i:0{index_width}d}{j:0{index_width}d}') for i in range(n)]
    for j in range(n)
])

# Row vector of basis polynomials: B[k] = sum_j t**j * A[j, k].
B = T.T * A

# B-spline
# Assemble (n-1) + (n-1) + (n-1)**2 + 1 = n**2 linear constraints on the
# entries of A, then solve them.
def _deriv_at(k, order, point):
    """Return the `order`-th t-derivative of B[k] evaluated at t = point."""
    return B[k].diff(t, order).subs(t, point)


derivative_orders = range(n - 1)

eqs = []
# B[0] and its first n-2 derivatives vanish at t = 1.
eqs.extend(Eq(_deriv_at(0, d, 1), 0) for d in derivative_orders)
# B[n-1] and its first n-2 derivatives vanish at t = 0.
eqs.extend(Eq(_deriv_at(n - 1, d, 0), 0) for d in derivative_orders)
# Neighbouring polynomials agree in value and derivatives where they meet:
# B[i] at t = 0 matches B[i+1] at t = 1.
eqs.extend(
    Eq(_deriv_at(i, d, 0), _deriv_at(i + 1, d, 1))
    for i in range(n - 1)
    for d in derivative_orders
)
# Normalisation: the basis polynomials sum to 1 at t = 0.
eqs.append(Eq(sum(B).subs(t, 0), 1))

# No unknowns are passed explicitly, so solve() works on the free symbols of
# the equations (t has been substituted away); keep the first solution.
solutions = solve(eqs, dict=True)
sol = solutions[0]

# Replace the symbolic coefficients by their solved values.
A, B = A.subs(sol), B.subs(sol)

# Pull the least common multiple of all entry denominators out of A, so the
# printed product shows a scalar 1/den next to the denominator-free den*A.
# Unfactored alternative:
# expr = MatMul(T.T, A, P)
entry_denominators = [fraction(entry)[1] for entry in A]
den = lcm(entry_denominators)
expr = MatMul(T.T, 1/den, den*A, P)

rendered = pretty(expr)
print(rendered)
# with open('doc/expr.txt', 'w') as f:
#     print(pretty(expr), file=f)

# TODO: Apparently there is some efficient way to evaluate a polynomial at
# equidistant points, but I don't know what it is (yet).
# NOTE(review): forward differencing is the usual candidate - to be confirmed.

# Evaluate the matrix product and take its single entry as a scalar expression.
expr = expr.doit()[0]


def _with_op_count(form):
    """Return `form` as text, followed on the next line by its operation count."""
    return f"{form}\n{count_ops(form, visual=True)}"


# Compare the operation counts of several equivalent ways to write the spline.
horner_form = horner(expr, t)
candidate_forms = (
    expr,                        # as produced by the matrix product
    collect(expand(expr), t),    # expanded, grouped by powers of t
    horner_form,                 # Horner scheme in t
    together(horner_form),       # Horner scheme over a common denominator
)
calc = "\n\n".join(map(_with_op_count, candidate_forms))

# Blank line first to separate this report from the pretty-printed matrix.
print(f"\n{calc}")
# with open('doc/calc.txt', 'w') as f:
#     print(calc, file=f)

# First figure: every basis polynomial (in reversed order) together with
# their sum, all drawn over t in [0, 1].
p = plot(*reversed(B), sum(B), (t, 0, 1))
# p.save('doc/basis.svg')

# Second figure: lay the polynomials end to end. The k-th one (in reversed
# order) is shifted right by k and drawn over [k, k+1]; each piece is plotted
# without showing and merged into one initially empty figure.
p = plot(show=False)
for offset, piece in enumerate(reversed(B)):
    shifted = piece.subs(t, t - offset)
    segment = plot(shifted, (t, offset, offset + 1), show=False)
    p.extend(segment)
p.show()
# p.save('doc/basis_stiched.svg')