import math
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.animation as animation
    import sqlite3
    import os
    from particle import Particle
    
    # ---- Model parameters -------------------------------------------------
    n  = 0.020      # Manning's roughness coefficient
    Q  = 500        # discharge [cu.m/s]
    Hd = 12         # water level at the downstream end [m]
    zd = 0          # bed elevation at the downstream end [m]
    dx = 200        # spatial step [m]
    dt = 10         # time step [s]
    So = 1 / 700    # bed slope
    B  = 100        # channel width [m]
    X  = 10000      # channel length [m]
    d  = 0.01       # grain diameter [cm]
    void = 0.4      # bed porosity
    D = 0.001       # diffusion coefficient; presumably [sq.cm/s], given that
                    # D_dxdx below converts dx to cm -- TODO confirm

    # Grid: sz nodes at xs = 0, dx, ..., X [m].  Node 0 is the downstream end
    # (its bed elevation starts at zd and its water level is fixed at Hd).
    sz = int(X / dx) + 1
    xs = np.linspace(0, X, sz)

    # Frequently used derived quantities
    p10_3 = 10 / 3                  # exponent of h in Manning's friction slope n^2 q^2 / h^(10/3)
    q = Q / B                       # discharge per unit width [sq.m/s]
    qq = q * q
    qq_2g = qq / 19.6               # q^2 / 2g with g = 9.8 m/s^2
    nnqq = n * n * qq               # numerator of the friction slope, n^2 q^2
    dx_2 = dx / 2
    K = nnqq * dx_2                 # friction loss over half a step is K / h^(10/3)
    dt_ve = dt / (1 - void)         # time step divided by (1 - porosity)
    q_dx = 100 * q / dx # q / dx converted to cm/s
    D_dxdx = D / (10000 * dx * dx)  # D / dx^2 with dx converted to cm (1 sq.m = 10000 sq.cm)
    cnv = dt_ve / 100               # turns a net pickup rate [cm/s] into a bed change per step [m]
    dt_plt = 100 * dt               # output interval [s]: every 100 time steps
    # Output times.  24 * 3600 is not a multiple of dt_plt, so the last entry
    # is 86000 s (not 86400 s); solve() runs until tms[-1].
    tms = np.arange(0, 24 * 3600 + dt, dt_plt)

    # Sediment particle properties from the project-local particle module.
    # NOTE(review): presumably sgd = (s - 1) g d and wf = settling velocity
    # [cm/s]; confirm against particle.Particle.
    p = Particle(d)
    sgd = p.sgd
    wf = p.wf

    # Result store: one row per (time t [s], bed elevation z [m]) and node.
    # The table is dropped and recreated on every run.
    con = sqlite3.connect("result.dB")
    cur = con.cursor()
    cur.executescript("""
    DROP TABLE IF EXISTS result;
    CREATE TABLE result(
        id INTEGER PRIMARY KEY AUTOINCREMENT, t INTEGER, z REAL);
    """)
    
    def calc_qss(isFirst=False, max_iter=100, tol=1e-4):
        """Compute the water-surface profile and sediment pickup along the channel.

        Marches upstream from the downstream end (node 0, where the water level
        is fixed at ``Hd``). At each section it solves the energy equation with
        the average friction slope of the two adjacent sections (standard step
        method), finding the depth by Newton's method. Results are written in
        place into the module-level arrays ``Hs`` (water level [m]), ``qss``
        (from ``p.Itakura``) and ``cb_cs`` (from ``p.cb_c``), which ``solve``
        allocates. ``zs`` (bed elevation [m]) is only read.

        Args:
            isFirst: True on the first call, when ``Hs`` does not yet hold a
                solution. Newton's method is then started from the depth of
                the section just downstream. Otherwise it starts from the
                previous call's depth ``Hs[i] - zs[i]``.
            max_iter: maximum number of Newton iterations per section.
            tol: convergence tolerance on the energy-equation residual [m].

        Raises:
            RuntimeError: if Newton's method does not converge within
                ``max_iter`` iterations at some section.
        """
        # Downstream boundary condition: fixed water level Hd.
        z = zs[0]
        h = Hd - z
        E = Hd + qq_2g / h**2             # total head
        Sf = nnqq / math.pow(h, p10_3)    # Manning friction slope
        # u*^2 = g h Sf in cgs units (g = 980 cm/s^2, h converted to cm);
        # divided by sgd this is presumably the dimensionless tractive force.
        ts = 98000 * h * Sf / sgd

        Hs[0] = Hd
        qss[0] = p.Itakura(ts)
        cb_cs[0] = p.cb_c(h, Sf)

        for i, H, z in zip(range(1, sz), Hs[1:], zs[1:]):
            # Energy balance between sections i-1 and i:
            #   h + z + v^2/2g = E_{i-1} + (Sf_{i-1} + Sf_i) * dx / 2
            cnst = E + Sf * dx_2 - z
            if not isFirst:
                h = H - z
            for _ in range(max_iter):
                vv_2g = qq_2g / (h * h)          # velocity head
                Sfdx_2 = K / math.pow(h, p10_3)  # friction loss over half a step
                er = h + vv_2g - Sfdx_2 - cnst   # residual
                if abs(er) < tol:
                    break
                # Newton step; the denominator is d(residual)/dh.
                h -= er / (1 + (p10_3 * Sfdx_2 - 2 * vv_2g) / h)
            else:
                # Without this cap a non-converging iteration would loop forever.
                raise RuntimeError(
                    f"calc_qss: Newton iteration did not converge at section {i} "
                    f"after {max_iter} iterations (depth {h!r} m, residual {er!r})")
            H = h + z
            E = H + vv_2g
            Sf = Sfdx_2 / dx_2
            ts = 98000 * h * Sf / sgd

            Hs[i] = H
            qss[i] = p.Itakura(ts)  # cu.cm/s/sq.cm
            cb_cs[i] = p.cb_c(h, Sf)
    
    def solve():
        """Run the coupled bed-evolution / suspended-sediment simulation.

        Initialises the module-level arrays ``zs`` (bed elevation [m]), ``Hs``,
        ``qss`` and ``cb_cs``; the last three are filled by ``calc_qss`` each
        step. It then advances the suspended-sediment concentration ``cs``
        (presumably depth-averaged) and the bed with an explicit scheme, every
        ``dt`` seconds until ``tms[-1]``. The bed profile is written to the
        ``result`` table at t = 0 and at every step time contained in ``tms``.
        """
        global zs, Hs, qss, cb_cs
        zs = xs * So + zd
        Hs = np.empty(sz)
        qss = np.empty(sz)
        cb_cs = np.empty(sz)
        cs = np.zeros(sz)
        dc_dts = np.zeros(sz)

        # Initial bed profile as the t = 0 snapshot.
        for z in zs:
            cur.execute("""
                INSERT INTO result(t, z)
                VALUES(?, ?)
            """, (0, z))
        con.commit()

        t = dt
        isFirst = True
        while t <= tms[-1]:

            calc_qss(isFirst)

            # Upstream boundary (last node): equilibrium concentration, i.e.
            # pickup balances deposition there.
            cs[-1] = qss[-1] / (wf * cb_cs[-1])

            # March from node sz-2 down to node 0. Flow runs toward x = 0, so
            # the upwind neighbour of node i is i+1 (c_up) and its downstream
            # neighbour is i-1 (c_dn).
            c_up = cs[-1]
            c_dn = cs[-3]
            for i, qs, c, cb_c, H, z in zip(range(sz - 2, -1, -1),
                        qss[-2::-1], cs[-2::-1], cb_cs[-2::-1],
                        Hs[-2::-1], zs[-2::-1]):
                qse = qs - wf * cb_c * c   # net pickup: pickup - deposition [cm/s]
                # Source term plus first-order upwind advection; depth in cm.
                dc_dt = (qse - q_dx * (c - c_up)) / (100 * (H - z))
                dc_dt += D_dxdx * (c_dn - 2 * c + c_up)   # diffusion
                dc_dts[i] = dc_dt          # rate of change of concentration
                zs[i] -= cnv * qse         # bed elevation update [m]
                # Shift the stencil to node i-1, whose downstream neighbour is
                # node i-2 (not cs[i-1], which is node i-1 itself). Below node 0
                # use the linear extrapolation 2*cs[0] - cs[1] (zero curvature
                # at the outlet); the value computed at i == 0 is unused.
                c_dn = cs[i - 2] if i >= 2 else 2 * c_dn - c
                c_up = c

            cs += dt * dc_dts  # concentration update

            if t in tms:
                for z in zs:
                    cur.execute("""
                        INSERT INTO result(t, z)
                        VALUES(?, ?)
                    """, (t, z))
                con.commit()

            t += dt
            isFirst = False
    
    def update(data):
        """Animation callback: redraw the bed profile for one frame.

        ``data`` is one ``(t, xs, zs)`` tuple produced by ``set_data``: the
        output time [s], the node positions [m] and the bed elevations [m].
        """
        t, positions, bed = data

        plt.cla()  # start each frame from empty axes
        plt.plot(positions, bed)
        plt.grid()
        plt.title(f"question 1{t:10d} sec.")
        plt.xlabel("$x$ [m]")
        plt.ylabel("$z$ [m]")
    
    def set_data(xs, *, times=None, cursor=None):
        """Yield one animation frame per output time from the ``result`` table.

        Args:
            xs: node positions [m]; yielded unchanged with every frame.
            times: output times to read; defaults to the module-level ``tms``.
            cursor: sqlite3 cursor to query; defaults to the module-level ``cur``.

        Yields:
            ``(t, xs, zs)``, where ``zs`` is the list of bed elevations stored
            for time ``t``, in insertion (``id``) order. A time with no rows
            yields an empty list.
        """
        if times is None:
            times = tms
        if cursor is None:
            cursor = cur

        # Parameterised query. ORDER BY id guarantees node order instead of
        # relying on SQLite's unspecified default row order.
        sql = "SELECT z FROM result WHERE t = ? ORDER BY id"
        for t in times:
            # numpy integers (e.g. elements of tms) are not bound as SQL
            # integers by sqlite3, so convert explicitly.
            zs = [row[0] for row in cursor.execute(sql, (int(t),))]
            yield t, xs, zs
    
    # Script entry: run the simulation, then animate the stored bed profiles.
    # try/finally ensures the temporary result database is closed and deleted
    # even if the simulation or the plotting raises.
    try:
        solve()

        fig = plt.figure()
        ani = animation.FuncAnimation(fig, update, set_data(xs),
                interval=10, repeat=False, cache_frame_data=False)
        plt.show()
        # To save the animation as a GIF instead, uncomment:
        # ani.save("excercise_20.4.gif", writer = 'pillow')
    finally:
        con.close()
        os.remove('result.dB')