#!/usr/bin/env python3

from sys import argv
from pathlib import Path
import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

plt.ion()

# Wave parameters.
n = 2  # number of spiral arms (the pattern has n-fold rotational symmetry)
m = 3  # radial wavenumber: wave crests per unit distance from the origin
r = 3  # inverse width of the amplitude envelope 1 / (1 + (r * d) ** 2)
N = 2 ** 6  # grid resolution: N x N samples over [-1, 1] x [-1, 1]
frames = 15  # frames per animation period (one full period of t)

x, y = np.meshgrid(*(np.linspace(-1, +1, N) for _ in range(2)))

# Everything that does not depend on time is computed once, outside the loop.
d = np.sqrt(x ** 2 + y ** 2)  # distance from the origin
theta = np.arctan2(y, x)  # polar angle in (-pi, pi]
envelope = 1 / (1 + (r * d) ** 2)  # amplitude decays with distance; |z| <= 1
# Zero-padded frame numbers keep the files sorted and match a printf-style
# pattern such as %02d (two digits for 15 frames).
frame_digits = len(str(frames - 1))
stem = Path(argv[0]).stem

# t sweeps one full period [0, 1) in `frames` equal steps, so the last frame
# flows back into the first and the saved sequence loops seamlessly. Integer
# frame indices avoid np.arange's floating-point step, which can yield an
# off-by-one number of samples.
for frame in range(frames):
    t = frame / frames
    # Phase is constant along d = (t - c) / m, so crests travel outward as t
    # grows; the -n * theta term twists them into n spiral arms.
    z = envelope * np.sin(2 * np.pi * (t - m * d) - n * theta)

    plt.clf()
    # plt.gca(projection=...) was deprecated in Matplotlib 3.4 and removed in
    # 3.6; create the 3-D axes explicitly instead.
    ax = plt.axes(projection='3d')
    ax.plot_surface(
        x, y, z,
        rstride=1, cstride=1,  # use every grid row/column
        facecolors=cm.inferno(plt.Normalize(z.min(), z.max())(z)), shade=False,
    ).set_facecolor((0, 0, 0, 0))
    # NOTE(review): the faces are made fully transparent above; with
    # `facecolors=` given, plot_surface presumably colours the edges the same
    # way, leaving a colour-mapped wireframe -- confirm on the target version.
    print(z.min(), z.max())  # per-frame value range (debug output)
    # Fixed limits (much wider than |z| <= 1) keep the vertical scale
    # identical across frames instead of autoscaling each one.
    ax.set_zlim(-10, +10)
    plt.savefig(f"{stem}_{frame:0{frame_digits}}.png")
    # Give the GUI ~10 ms to redraw; a key press (returns True) stops early,
    # while a timeout (None) or mouse click (False) continues.
    if plt.waitforbuttonpress(1/100):
        break

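# Assemble the saved frames into an animated WebP that loops forever
# (-loop 0); the file names assume the script is saved as pyplot_gravwaves.py:
#   ffmpeg -i 'pyplot_gravwaves_%02d.png' -loop 0 'pyplot_gravwaves.webp'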