import math
    import numpy as np
    import matplotlib.pyplot as plt
    from particle import Particle
    
    # Model parameters (lengths in m unless noted otherwise)
    n  = 0.020      # Manning roughness coefficient
    Q  = 500        # discharge [cu.m/s]
    Hd = 12         # water level at the downstream end [m]
    zd = 0          # bed elevation at the downstream end [m]
    dx = 200        # spatial step [m]
    dt = 10         # time step [s]
    So = 1 / 700    # bed slope
    B  = 100        # channel width [m]
    X  = 10000      # channel length [m]
    d  = 0.01       # grain diameter [cm]
    void = 0.4      # bed porosity
    D = 0.001       # diffusion coefficient (presumably [sq.cm/s], given the
                    # m^2 -> cm^2 factor 10000 in D_dxdx below -- TODO confirm)
    
    sz = int(X / dx) + 1          # number of grid nodes
    xs = np.linspace(0, X, sz)    # node positions [m]; index 0 is the downstream end
    
    # Frequently used derived quantities
    p10_3 = 10 / 3                # exponent in the Manning slope Sf = n^2 q^2 / h^(10/3)
    q = Q / B                     # discharge per unit width [sq.m/s]
    qq = q * q
    qq_2g = qq / 19.6             # q^2 / (2g) with g = 9.8 m/s^2
    nnqq = n * n * qq             # n^2 q^2
    dx_2 = dx / 2
    K = nnqq * dx_2               # n^2 q^2 dx / 2 (half-step friction-loss factor)
    dt_ve = dt / (1 - void)       # dt / (1 - porosity)
    q_dx = 100 * q / dx # q / dx converted to cm/s
    D_dxdx = D / (10000 * dx * dx)  # D / dx^2 with dx converted to cm
    cnv = dt_ve / 100             # turns a flux in cm/s into a bed change in m per step
    tms_plt = [t * 3600 for t in [1, 4, 12, 24]]  # plot times [s] (1, 4, 12, 24 h)
    
    p = Particle(d)
    sgd = p.sgd   # from Particle; presumably (s - 1) g d in cgs units -- TODO confirm
    wf = p.wf     # from Particle; presumably the settling velocity [cm/s] -- TODO confirm
    
    def calc_qss(isFirst=False, max_iter=100):
        """Compute water level, pickup rate and concentration ratio at every node.

        Solves the steady 1-D energy equation by the standard step method,
        marching from the downstream end (index 0, water level ``Hd``) toward
        upstream.  At node i the depth h satisfies

            h + q^2/(2 g h^2) - (dx/2) Sf(h) = E_prev + (dx/2) Sf_prev - z_i

        with the Manning friction slope Sf(h) = n^2 q^2 / h^(10/3); h is found
        by Newton's method.

        Reads the module globals ``zs`` and ``Hs`` and fills ``Hs``, ``qss``
        and ``cb_cs`` in place.

        Parameters
        ----------
        isFirst : bool
            True on the first call, when ``Hs`` has not been filled yet; the
            Newton iteration then starts from the depth of the previous
            (downstream) node instead of from ``Hs[i] - zs[i]``.
        max_iter : int
            Maximum number of Newton iterations per node (must be >= 1).

        Raises
        ------
        RuntimeError
            If Newton's method does not converge within ``max_iter`` steps
            at some node (previously this looped forever).
        """
        # Downstream boundary condition
        h = Hd - zs[0]
        E = Hd + qq_2g / h**2               # total head
        Sf = nnqq / math.pow(h, p10_3)      # Manning friction slope
        # Presumably the dimensionless (Shields) stress: 98000 = 980 cm/s^2
        # * 100 cm/m, assuming sgd = (s - 1) g d in cgs -- TODO confirm.
        ts = 98000 * h * Sf / sgd

        Hs[0] = Hd
        qss[0] = p.Itakura(ts)
        cb_cs[0] = p.cb_c(h, Sf)

        for i, H, z in zip(range(1, sz), Hs[1:], zs[1:]):
            cnst = E + Sf * dx_2 - z
            if not isFirst:
                h = H - z                   # warm start from the previous time step
            for _ in range(max_iter):
                vv_2g = qq_2g / (h * h)         # velocity head
                Sfdx_2 = K / math.pow(h, p10_3) # half-step friction loss
                er = h + vv_2g - Sfdx_2 - cnst  # residual
                if abs(er) < 1e-4:
                    break
                # Newton correction; the denominator is d(residual)/dh
                h -= er / (1 + (p10_3 * Sfdx_2 - 2 * vv_2g) / h)
            else:
                raise RuntimeError(
                    f"calc_qss: Newton iteration did not converge at node {i}"
                    f" (last depth {h!r} m)")
            H = h + z
            E = H + vv_2g
            Sf = Sfdx_2 / dx_2
            ts = 98000 * h * Sf / sgd

            Hs[i] = H
            qss[i] = p.Itakura(ts)  # cu.cm/s/sq.cm
            cb_cs[i] = p.cb_c(h, Sf)
    
    def solve(question):
        """Run one bed-evolution experiment and plot bed-profile snapshots.

        The bed ``zs`` starts as a uniform slope and is advanced with an
        explicit scheme; snapshots are plotted at the times in ``tms_plt``.

        Parameters
        ----------
        question : int
            1 -- unsteady suspended-sediment concentration with diffusion;
            2 -- unsteady concentration without the diffusion term;
            3 -- steady concentration (no unsteady or diffusion term).

        Side effects: rebinds the module globals ``zs``, ``Hs``, ``qss`` and
        ``cb_cs`` (read/filled by ``calc_qss``) and draws and shows a
        Matplotlib figure.
        """
        global zs, Hs, qss, cb_cs
        zs = xs * So + zd       # initial bed: uniform slope So
        Hs = np.empty(sz)       # water level [m]
        qss = np.empty(sz)      # pickup rate from p.Itakura [cm/s]
        cb_cs = np.empty(sz)    # from p.cb_c; multiplies wf * c in the deposition flux

        if question < 3:
            cs = np.zeros(sz)       # suspended-sediment concentration per node
            dc_dts = np.zeros(sz)   # its time derivative
        if question > 1:
            plt.cla()
        # Plot a copy: zs is modified in place below, and a plotted line that
        # keeps a reference to it would otherwise show the final bed.
        plt.plot(xs, zs.copy(), label="initial")

        # Count integer steps so the end time and plot times are hit exactly
        # even if dt is not an integer.
        plot_steps = {round(tp / dt): tp for tp in tms_plt}
        n_steps = round(max(tms_plt) / dt)
        for step in range(1, n_steps + 1):

            calc_qss(step == 1)

            if question < 3:
                # Upstream boundary (last node): equilibrium concentration.
                cs[-1] = qss[-1] / (wf * cb_cs[-1])
                c_up = cs[-1]   # upstream neighbour of the current node
                c_dn = cs[-3]   # downstream neighbour of the current node
                # Sweep from upstream (i = sz-2) down to the downstream end (i = 0).
                for i, qs, c, cb_c, H, z in zip(range(sz - 2, -1, -1),
                            qss[-2::-1], cs[-2::-1], cb_cs[-2::-1],
                            Hs[-2::-1], zs[-2::-1]):
                    qse = qs - wf * cb_c * c    # net pickup minus deposition
                    # Upwind advection; 100 * (H - z) is the depth in cm.
                    dc_dt = (qse - q_dx * (c - c_up)) / (100 * (H - z))
                    if question == 1:
                        dc_dt += D_dxdx * (c_dn - 2 * c + c_up)
                    dc_dts[i] = dc_dt           # rate of change of concentration
                    zs[i] -= cnv * qse          # bed-elevation update
                    # Shift the stencil: the next node is i-1, whose downstream
                    # neighbour is i-2.  Node 0 has none, so use the linearly
                    # extrapolated ghost value 2*cs[0] - cs[1].
                    c_dn = cs[i - 2] if i > 1 else 2 * c_dn - c
                    c_up = c

                cs += dt * dc_dts   # concentration update
            else:
                # Steady concentration: solve 0 = qs - wf*cb_c*c + q_dx*(c_up - c)
                # node by node from the upstream boundary.
                c_up = qss[-1] / (wf * cb_cs[-1])
                for i, qs, cb_c in zip(range(sz - 2, -1, -1),
                                qss[-2::-1], cb_cs[-2::-1]):
                    wfcb_c = wf * cb_c
                    c = (qs + c_up * q_dx) / (wfcb_c + q_dx)
                    zs[i] -= cnv * (qs - wfcb_c * c)
                    c_up = c

            if step in plot_steps:
                lbl = f"{plot_steps[step] // 3600} hr"
                plt.plot(xs, zs.copy(), label=lbl)

        plt.title(f"question {question}")
        plt.xlabel("$x$ (m)")
        plt.ylabel("$z_b$ (m)")
        plt.legend(loc="lower right")
        plt.grid()
        plt.show()
    
    # Run the three variants in turn; each call shows its own figure.
    solve(1)  # unsteady concentration equation, with diffusion
    solve(2)  # without the diffusion term
    solve(3)  # without the unsteady and diffusion terms (steady concentration)