%% Copyright (C) 2025 by sysmocom - s.f.m.c. GmbH <info@sysmocom.de>
%% Author: Vadim Yanitskiy <vyanitskiy@sysmocom.de>
%%
%% All Rights Reserved
%%
%% SPDX-License-Identifier: AGPL-3.0-or-later
%%
%% This program is free software; you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as
%% published by the Free Software Foundation; either version 3 of the
%% License, or (at your option) any later version.
%%
%% This program is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%% GNU General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with this program.  If not, see <https://www.gnu.org/licenses/>.
%%
%% Additional Permission under GNU AGPL version 3 section 7:
%%
%% If you modify this Program, or any covered work, by linking or
%% combining it with runtime libraries of Erlang/OTP as released by
%% Ericsson on https://www.erlang.org (or a modified version of these
%% libraries), containing parts covered by the terms of the Erlang Public
%% License (https://www.erlang.org/EPLICENSE), the licensors of this
%% Program grant you additional permission to convey the resulting work
%% without the need to license the runtime libraries of Erlang/OTP under
%% the GNU Affero General Public License. Corresponding Source for a
%% non-source form of such a combination shall include the source code
%% for the parts of the runtime libraries of Erlang/OTP used as well as
%% that of the covered work.

-module(gtpu_kpi).
-behaviour(gen_server).

-export([init/1,
         handle_info/2,
         handle_call/3,
         handle_cast/2,
         terminate/2]).
-export([start_link/1,
         enb_register/1,
         enb_set_addr/1,
         enb_unregister/0,
         fetch_counters/0,
         shutdown/0]).

-include_lib("kernel/include/logger.hrl").

-include("s1gw_metrics.hrl").

%% UDP destination port of GTP-U traffic (non-GTP-U traffic is accepted
%% early in both chains, see nft_init_table/1)
-define(GTPU_PORT, 2152).

%% comparison operators passed to the enftables match expressions
-define(OP_EQ,   "==").
-define(OP_NEQ,  "!=").


%% Configuration map passed to start_link/1:
%%  - enable: `true' manages an nftables table, `false' runs a no-op stub;
%%  - table_name: nft table to use (init/1 defaults it to "osmo-s1gw");
%%  - interval: counter reporting period in milliseconds, used by the
%%    heartbeat process (init/1 defaults it to 3000).
-type cfg() :: #{enable => boolean(),
                 table_name => string(),
                 interval => non_neg_integer() %% like X34 in osmo-hnbgw
                }.

%% Values of one nft counter, as produced by parse_nft_counter/1:
%%  - packets: number of matched packets;
%%  - bytes_ue: matched bytes minus an assumed 36 bytes of outer
%%    IP/UDP/GTP-U headers per packet (never negative);
%%  - bytes_total: matched bytes as reported by nft.
-record(ctr, {packets = 0 :: non_neg_integer(),
              bytes_ue = 0 :: non_neg_integer(),
              bytes_total = 0 :: non_neg_integer()
             }).

-type counter() :: #ctr{}.

%% Parsed nft counters, keyed by nft counter name.  Counters created by
%% this module are named "ul-" / "dl-" followed by the Global eNB ID.
-type counters() :: dict:dict(K :: string(),
                              V :: counter()).

%% traffic direction: uplink or downlink
-type uldl() :: ul | dl.

%% a direction together with the IP address given to enb_set_addr/1
-type uldl_addr() :: {ULDL :: uldl(),
                      Addr :: string()}.

%% Per-direction eNB state, created once the address is known:
%%  - handle: nft handle of the counting rule (needed to delete it);
%%  - ctr: counter values at the last report (baseline for increments).
-type enb_uldl_state() :: #{uldl => uldl(),
                            addr => string(),
                            handle => integer(),
                            ctr => counter()
                           }.

%% Per-eNB state; `ul'/`dl' are only present once the address for
%% that direction has been set.  `mon_ref' monitors the eNB process.
-type enb_state() :: #{pid => pid(),
                       genb_id => string(),
                       mon_ref => reference(),
                       ul => enb_uldl_state(),   %% Uplink data
                       dl => enb_uldl_state()    %% Downlink data
                      }.

%% registered eNBs, keyed by the pid of the registering process
-type registry() :: dict:dict(K :: pid(),
                              V :: enb_state()).

%% gen_server state
-record(state, {cfg :: cfg(),
                registry :: registry()
               }).

-export_type([cfg/0,
              counter/0,
              counters/0]).


%% ------------------------------------------------------------------
%% public API
%% ------------------------------------------------------------------

%% Start the KPI server, registered locally under the module name.
%% Cfg reaches init/1 wrapped in a one-element list.
-spec start_link(cfg()) -> gen_server:start_ret().
start_link(Cfg) ->
    ServerName = {local, ?MODULE},
    InitArgs = [Cfg],
    gen_server:start_link(ServerName, ?MODULE, InitArgs, []).


%% Register the calling process as the eNB with the given Global eNB ID.
%% Replies `ok', or `{error, already_registered}' on a repeated call
%% (in stub mode the reply is always `ok').
-spec enb_register(GlobalENBId) -> ok | {error, term()}
    when GlobalENBId :: string().
enb_register(GlobalENBId) ->
    Request = {enb_register, GlobalENBId},
    gen_server:call(?MODULE, Request).


%% Report the uplink (`ul') or downlink (`dl') address of the calling eNB
%% process.  Fire-and-forget: always returns `ok'.  Any other direction
%% fails with function_clause.
-spec enb_set_addr(ULDLAddr) -> ok
    when ULDLAddr :: uldl_addr().
enb_set_addr({Dir, _Addr} = ULDLAddr) when Dir =:= ul; Dir =:= dl ->
    gen_server:cast(?MODULE, {enb_set_addr, self(), ULDLAddr}).


%% Unregister the calling eNB process.  Its nft rules are deleted and its
%% nft counters are kept.  Replies `{error, enb_not_registered}' if the
%% caller is unknown (in stub mode the reply is always `ok').
-spec enb_unregister() -> ok | {error, term()}.
enb_unregister() ->
    Request = enb_unregister,
    gen_server:call(?MODULE, Request).


%% Fetch the current nft counters of the managed table.  Replies
%% `{ok, Counters}' (keyed by nft counter name) or `{error, Reason}';
%% in stub mode the reply is `ok'.
-spec fetch_counters() -> ok | {ok, counters()} | {error, term()}.
fetch_counters() ->
    Request = fetch_counters,
    gen_server:call(?MODULE, Request).


%% Stop the server synchronously with reason `normal'.  When enabled,
%% terminate/2 deletes the nft table.
-spec shutdown() -> ok.
shutdown() ->
    gen_server:stop(?MODULE, normal, infinity).


%% ------------------------------------------------------------------
%% gen_server API
%% ------------------------------------------------------------------

%% Initialise the server.
%%
%% With `enable => true' a table left over from an earlier run is deleted
%% and a fresh one is created.  On success a heartbeat process is
%% spawn_link()ed to trigger periodic counter reporting.  `table_name'
%% defaults to "osmo-s1gw" and `interval' to 3000 (ms).  If the table
%% cannot be created, `{error, Reason}' is returned.
%%
%% With `enable => false' the server runs as a stub that ignores requests.
init([#{enable := false} = Cfg]) ->
    %% stub mode
    {ok, #state{cfg = Cfg,
                registry = dict:new()}};

init([#{enable := true} = Cfg0]) ->
    process_flag(trap_exit, true),
    Cfg = Cfg0#{table_name => maps:get(table_name, Cfg0, "osmo-s1gw"),
                interval => maps:get(interval, Cfg0, 3000)},
    #{table_name := TName, interval := Interval} = Cfg,
    %% the table may remain from an earlier run, or may not exist at all;
    %% either way the outcome of deleting it does not matter
    nft_flush_table(TName),
    case nft_init_table(TName) of
        ok ->
            ?LOG_INFO("NFT table ~p has been initialized", [TName]),
            spawn_link(fun() -> heartbeat(Interval) end),
            {ok, #state{cfg = Cfg,
                        registry = dict:new()}};
        Error ->
            ?LOG_ERROR("NFT table ~p init failed: ~p", [TName, Error]),
            {error, Error}
    end.


%% Synchronous requests.
%%
%% In stub mode (`enable => false') every request is answered with `ok'.
%% Otherwise:
%%  - {enb_register, GlobalENBId}: monitor the caller, create its exometer
%%    counters and add it to the registry;
%%  - enb_unregister: drop the caller from the registry and delete its
%%    nft rules;
%%  - fetch_counters: list and parse the nft counters of the table.
handle_call(Request, From,
            #state{cfg = #{enable := false}} = S) ->
    ?LOG_DEBUG("ignore ~p() from ~p: ~p", [?FUNCTION_NAME, From, Request]),
    {reply, ok, S};

handle_call({enb_register, GlobalENBId}, {Pid, _Tag},
            #state{registry = Reg} = S) ->
    case dict:is_key(Pid, Reg) of
        true ->
            ?LOG_ERROR("eNB (pid ~p, ~p) is already registered",
                       [Pid, GlobalENBId]),
            {reply, {error, already_registered}, S};
        false ->
            %% watch the registering process, so that its state gets
            %% cleaned up (handle_info/2) should it die unregistered
            ES = #{genb_id => GlobalENBId,
                   mon_ref => erlang:monitor(process, Pid),
                   pid => Pid},
            enb_add_metrics(GlobalENBId),
            ?LOG_INFO("eNB (pid ~p, ~p) has been registered",
                      [Pid, GlobalENBId]),
            {reply, ok, S#state{registry = dict:store(Pid, ES, Reg)}}
    end;

handle_call(enb_unregister, {Pid, _Tag},
            #state{cfg = Cfg, registry = Reg} = S) ->
    case dict:find(Pid, Reg) of
        error ->
            ?LOG_ERROR("eNB (pid ~p) is *not* registered", [Pid]),
            {reply, {error, enb_not_registered}, S};
        {ok, #{genb_id := GlobalENBId,
               mon_ref := MonRef} = ES} ->
            %% [flush] also drops a 'DOWN' that may already be queued
            erlang:demonitor(MonRef, [flush]),
            enb_del_nft_rules(ES, Cfg),
            ?LOG_INFO("eNB (pid ~p, ~p) has been unregistered",
                      [Pid, GlobalENBId]),
            {reply, ok, S#state{registry = dict:erase(Pid, Reg)}}
    end;

handle_call(fetch_counters, _From,
            #state{cfg = #{table_name := TName}} = S) ->
    ?LOG_DEBUG("Fetching NFT counters"),
    Reply = case nft_exec([enftables:nft_cmd_list_counters(TName)]) of
                {ok, Res} ->
                    {ok, parse_nft_counters(Res)};
                {error, Error} ->
                    ?LOG_ERROR("Failed to fetch NFT counters: ~p", [Error]),
                    {error, Error}
            end,
    {reply, Reply, S};

handle_call(Request, From, S) ->
    ?LOG_ERROR("unknown ~p() from ~p: ~p", [?FUNCTION_NAME, From, Request]),
    {reply, {error, not_implemented}, S}.


%% Asynchronous requests.
%%
%% In stub mode (`enable => false') every message is ignored.  Otherwise:
%%  - {enb_set_addr, Pid, {ULDL, Addr}}: the first address indicated for a
%%    direction creates the nft counter and rule for it.  Repeating the
%%    same address is a no-op; a different address is only logged.
%%  - report_nft_counters (sent by the heartbeat): fetch the nft counters
%%    and feed their increments into the metrics of registered eNBs.
handle_cast(Msg,
            #state{cfg = #{enable := false}} = S) ->
    ?LOG_DEBUG("ignore ~p(): ~p", [?FUNCTION_NAME, Msg]),
    {noreply, S};

handle_cast({enb_set_addr, Pid, {ULDL, Addr}},
            #state{registry = Reg} = S0) ->
    case dict:find(Pid, Reg) of
        error ->
            ?LOG_ERROR("eNB (pid ~p) is *not* registered", [Pid]),
            {noreply, S0};
        {ok, #{genb_id := GlobalENBId} = ES} ->
            case ES of
                #{ULDL := #{addr := Addr}} ->
                    %% this very address is known already
                    ?LOG_DEBUG("eNB (pid ~p, ~p): ~p address ~p is already known",
                               [Pid, GlobalENBId, ULDL, Addr]),
                    {noreply, S0};
                #{ULDL := #{addr := OldAddr}} ->
                    %% a different address is stored for this direction
                    ?LOG_ERROR("eNB (pid ~p, ~p): ~p address ~p -> ~p change?!?",
                               [Pid, GlobalENBId, ULDL, OldAddr, Addr]),
                    {noreply, S0};
                _ ->
                    %% no state for this direction yet => create it
                    ?LOG_DEBUG("eNB (pid ~p, ~p): ~p address ~p indicated, "
                               "creating NFT counters and rules",
                               [Pid, GlobalENBId, ULDL, Addr]),
                    {_Reply, S1} = enb_set_addr({ULDL, Addr}, ES, S0),
                    {noreply, S1}
            end
    end;

handle_cast(report_nft_counters,
            #state{cfg = #{table_name := TName},
                   registry = Reg0} = S) ->
    ?LOG_DEBUG("Fetching and reporting NFT counters"),
    case nft_exec([enftables:nft_cmd_list_counters(TName)]) of
        {error, Error} ->
            ?LOG_ERROR("Failed to fetch NFT counters: ~p", [Error]),
            {noreply, S};
        {ok, Res} ->
            Reg1 = report_nft_counters(parse_nft_counters(Res), Reg0),
            {noreply, S#state{registry = Reg1}}
    end;

handle_cast(Msg, S) ->
    ?LOG_ERROR("unknown ~p(): ~p", [?FUNCTION_NAME, Msg]),
    {noreply, S}.


%% A monitored eNB process went down.  Delete its nft rules (the nft
%% counters stay) and drop it from the registry.  Any other message is
%% logged as unexpected.
handle_info({'DOWN', _MonRef, process, Pid, Reason},
            #state{cfg = Cfg, registry = Reg} = S) ->
    ?LOG_INFO("eNB process ~p terminated with reason ~p", [Pid, Reason]),
    case dict:find(Pid, Reg) of
        error ->
            ?LOG_ERROR("eNB (pid ~p) is *not* registered", [Pid]),
            {noreply, S};
        {ok, ES} ->
            enb_del_nft_rules(ES, Cfg),
            ?LOG_INFO("eNB (pid ~p) has been unregistered", [Pid]),
            {noreply, S#state{registry = dict:erase(Pid, Reg)}}
    end;

handle_info(Info, S) ->
    ?LOG_ERROR("unknown ~p(): ~p", [?FUNCTION_NAME, Info]),
    {noreply, S}.


%% Delete the nft table on termination, unless running in stub mode.
terminate(Reason,
          #state{cfg = #{enable := true, table_name := TName}}) ->
    ?LOG_NOTICE("Terminating, reason ~p", [Reason]),
    nft_flush_table(TName),
    ok;

terminate(Reason, #state{}) ->
    %% stub mode: there is no table to clean up
    ?LOG_NOTICE("Terminating, reason ~p", [Reason]),
    ok.


%% ------------------------------------------------------------------
%% private API
%% ------------------------------------------------------------------

%% Create the per-eNB GTP-U exometer counters (packets, UE bytes, total
%% bytes) for both the uplink and the downlink direction.  The return
%% value carries no meaning; it is ignored by the caller in handle_call/3.
enb_add_metrics(GlobalENBId) ->
    enb_add_metrics(GlobalENBId, ul),
    enb_add_metrics(GlobalENBId, dl).

%% Create the three GTP-U exometer counters of one direction (ul | dl)
%% for the given eNB.  Nothing in this module deletes them, so they may
%% already exist when the same eNB registers again.
%% NOTE(review): the old-style `catch' swallows *any* exception raised
%% here (including one from evaluating the metric-name macros), not only
%% the "already exists" case it is meant for.  It could presumably be
%% narrowed to that specific error; confirm what exometer:new/2 raises.
enb_add_metrics(GlobalENBId, ULDL) ->
    %% counters may already exist, so catch exceptions here
    catch exometer:new(?S1GW_CTR_GTPU_PACKETS(GlobalENBId, ULDL), counter),
    catch exometer:new(?S1GW_CTR_GTPU_BYTES_UE(GlobalENBId, ULDL), counter),
    catch exometer:new(?S1GW_CTR_GTPU_BYTES_TOTAL(GlobalENBId, ULDL), counter).


%% Create the nft counter and rule for one direction of a registered eNB
%% and store the resulting per-direction state in the registry.
%% Returns `{ok, S1}' on success.  On failure the error is logged and
%% `{{error, Reason}, S}' is returned with the state untouched.
%%
%% The per-direction baseline `ctr' starts as #ctr{} (all zeros).
%% NOTE(review): enb_del_nft_rule/2 keeps the nft counter when an eNB goes
%% away.  If nft_cmd_add_counter() leaves an existing counter's values
%% intact, the first report after the same eNB re-registers presumably
%% adds the whole accumulated value to the metrics again.  Confirm; if so,
%% seed `ctr' with the counter's current values.
-spec enb_set_addr(ULDLAddr, ES, S0) -> {Reply, S1}
    when ULDLAddr :: uldl_addr(),
         ES :: enb_state(),
         S0 :: #state{},
         S1 :: #state{},
         Reply :: ok | {error, term()}.
enb_set_addr({ULDL, Addr} = ULDLAddr,
             #{genb_id := GlobalENBId,
               pid := Pid} = ES0,
             #state{cfg = Cfg, registry = R0} = S) ->
    case enb_add_nft_counter(ULDLAddr, GlobalENBId, Cfg) of
        {error, Error} ->
            ?LOG_ERROR("eNB (pid ~p, ~p): creating NFT rules/counters failed: ~p",
                       [Pid, GlobalENBId, Error]),
            {{error, Error}, S};
        {ok, Handle} ->
            ?LOG_INFO("eNB (pid ~p, ~p): NFT rules/counters created for ~p",
                      [Pid, GlobalENBId, ULDL]),
            ULDLState = #{uldl => ULDL,
                          addr => Addr,
                          handle => Handle,
                          ctr => #ctr{}},
            R1 = dict:store(Pid, ES0#{ULDL => ULDLState}, R0),
            {ok, S#state{registry = R1}}
    end.


%% Name of the nft counter for one direction of an eNB:
%% "ul-<GlobalENBId>" or "dl-<GlobalENBId>".
enb_nft_counter_name(ULDL, GlobalENBId) when ULDL =:= ul; ULDL =:= dl ->
    atom_to_list(ULDL) ++ [$- | GlobalENBId].

%% Name of the base chain doing the accounting for one direction:
%% "gtpu-ul" or "gtpu-dl".
enb_nft_chain_name(ULDL) when ULDL =:= ul; ULDL =:= dl ->
    "gtpu-" ++ atom_to_list(ULDL).


%% Add the nft counter for the given eNB/direction and append a rule to
%% the direction's chain.  The rule matches Addr (source address for UL,
%% destination address for DL) and updates that counter.
%%
%% Returns `{ok, Handle}' with the nft handle of the new rule, which is
%% needed to delete the rule later, or `{error, Reason}'.
%%
%% nft_chain_last_handle/2 returns the atom `error' when it cannot list
%% the chain.  That value must never be handed out as a rule handle: it
%% would be stored in the eNB state and later used in a rule-delete
%% command.  It is therefore turned into `{error, {no_rule_handle, error}}'.
%% In that case the rule itself stays in the chain.
-spec enb_add_nft_counter(ULDLAddr, GlobalENBId, Cfg) -> {ok, Handle} | {error, term()}
    when ULDLAddr :: uldl_addr(),
         GlobalENBId :: string(),
         Cfg :: cfg(),
         Handle :: integer().
enb_add_nft_counter({ULDL, Addr}, GlobalENBId,
                    #{table_name := TName}) ->
    CtrName = enb_nft_counter_name(ULDL, GlobalENBId),
    CName = enb_nft_chain_name(ULDL),
    Counter = enftables:nft_counter(TName, CtrName),
    Rule = [nft_expr_match_ip_addr({ULDL, Addr}),
            enftables:nft_expr_counter(CtrName)],
    Cmds = [enftables:nft_cmd_add_counter(Counter),
            enftables:nft_cmd_add_rule(TName, CName, Rule)
           ],
    case nft_exec(Cmds) of
        ok ->
            %% the rule was just appended, so it is the chain's last one
            case nft_chain_last_handle(TName, CName) of
                error ->
                    ?LOG_ERROR("~p() failed: no handle for the new rule in chain ~p",
                               [?FUNCTION_NAME, CName]),
                    {error, {no_rule_handle, error}};
                Handle ->
                    {ok, Handle}
            end;
        {error, Error} ->
            ?LOG_ERROR("~p() failed: ~p", [?FUNCTION_NAME, Error]),
            {error, Error}
    end.


%% Delete the nft rules of both directions of an eNB (UL first, then DL).
%% Returns the state without its `ul'/`dl' entries.  A rule that fails to
%% be deleted is logged by enb_del_nft_rule/2 and otherwise ignored here.
-spec enb_del_nft_rules(ES0, Cfg) -> ES1
    when Cfg :: cfg(),
         ES0 :: enb_state(),
         ES1 :: enb_state().
enb_del_nft_rules(ES, Cfg) when is_map(ES) ->
    lists:foldl(fun(ULDL, Acc0) ->
                        case maps:take(ULDL, Acc0) of
                            {ULDLState, Acc1} ->
                                enb_del_nft_rule(ULDLState, Cfg),
                                Acc1;
                            error ->
                                %% no state for this direction => nothing to delete
                                Acc0
                        end
                end, ES, [ul, dl]);

enb_del_nft_rules(ES, _Cfg) ->
    ES.


%% Delete the counting rule of one eNB direction, identified by its nft
%% handle, from that direction's chain.  The nft counter itself is left
%% alone, so that counters are not reset when an eNB reconnects.
%% Returns `ok', or logs and returns `{error, Reason}'.
-spec enb_del_nft_rule(S, Cfg) -> ok | {error, term()}
    when S :: enb_uldl_state(),
         Cfg :: cfg().
enb_del_nft_rule(#{uldl := ULDL, handle := Handle},
                 #{table_name := TName}) ->
    DelCmd = enftables:nft_cmd_del_rule(TName, enb_nft_chain_name(ULDL), Handle),
    case nft_exec([DelCmd]) of
        {error, Error} = Err ->
            ?LOG_ERROR("~p() failed: ~p", [?FUNCTION_NAME, Error]),
            Err;
        ok ->
            ok
    end.


%% Parse the result of nft_cmd_list_counters() into a dict keyed by the
%% nft counter name.
%%
%% Only `counter' objects are considered.  nft's JSON output can contain
%% other objects too (e.g. the `metainfo' object libnftables emits first).
%% Feeding such an object to parse_nft_counter/1 would raise
%% function_clause and crash the server, losing the whole eNB registry.
%% A malformed `counter' object still crashes.
-spec parse_nft_counters([map()]) -> counters().
parse_nft_counters(L0) ->
    dict:from_list([parse_nft_counter(Obj)
                    || #{<< "counter" >> := _} = Obj <- L0]).


%% Convert one `counter' object of the nft JSON listing into
%% {CounterName, #ctr{}}.
%%
%% nft accounts whole outer packets.  With an outer IP header of 20 bytes
%% (IPv4 without options), each packet carries 20 + 8 + 8 = 36 bytes of
%% tunnel headers on top of the UE payload:
%%
%%   [ UE payload, e.g. IP/TCP/UDP ... ]   -> bytes_ue
%%   [ GTP-U header,  8 bytes ]  \
%%   [ UDP header,    8 bytes ]   > subtracted once per packet
%%   [ IP header,    20 bytes ]  /
%%
%% bytes_total is everything nft reports.  bytes_ue is clamped at zero.
-spec parse_nft_counter(map()) -> {string(), counter()}.
parse_nft_counter(#{<< "counter" >> := #{<< "name" >> := NameBin,
                                         << "bytes" >> := TotalBytes,
                                         << "packets" >> := NumPkts}}) ->
    Overhead = NumPkts * (20 + 8 + 8),
    UEBytes = TotalBytes - erlang:min(TotalBytes, Overhead),
    {binary_to_list(NameBin),
     #ctr{packets = NumPkts,
          bytes_ue = UEBytes,
          bytes_total = TotalBytes}}.


%% Feed freshly fetched nft counters into the metrics of every registered
%% eNB.  Returns the registry with the updated per-direction baselines.
-spec report_nft_counters(Ctrs, R0) -> R1
    when Ctrs :: counters(),    %% result of parse_nft_counters()
         R0 :: registry(),      %% (current) registry of eNBs
         R1 :: registry().      %% (new) registry of eNBs
report_nft_counters(Ctrs, R0) ->
    ReportOne = fun(_Pid, ES) -> enb_report_nft_counters(Ctrs, ES) end,
    dict:map(ReportOne, R0).


%% Report the counters of every direction (UL, then DL) for which the eNB
%% has state.  Directions without state are skipped.
-spec enb_report_nft_counters(counters(), enb_state()) -> enb_state().
enb_report_nft_counters(Ctrs, ES) when is_map(ES) ->
    lists:foldl(fun(ULDL, Acc) ->
                        case Acc of
                            #{ULDL := ULDLState} ->
                                Acc#{ULDL := enb_report_nft_counters(Ctrs, ES, ULDLState)};
                            #{} ->
                                Acc
                        end
                end, ES, [ul, dl]);

enb_report_nft_counters(_Ctrs, ES) ->
    ES.


%% Report one direction of an eNB: look up its nft counter and increment
%% the metrics by the difference to the cached baseline.  Returns the
%% per-direction state with the new values as the baseline.  If the
%% counter is unchanged or absent, nothing is reported.
%% XXX: assumes the fresh values are >= the cached ones.
-spec enb_report_nft_counters(Ctrs, ES, S0) -> S1
    when Ctrs :: counters(),
         ES :: enb_state(),
         S0 :: enb_uldl_state(),
         S1 :: enb_uldl_state().
enb_report_nft_counters(Ctrs,
                        #{genb_id := GlobalENBId, pid := Pid},
                        #{uldl := ULDL, ctr := Cached} = S) ->
    case dict:find(enb_nft_counter_name(ULDL, GlobalENBId), Ctrs) of
        error ->
            %% no counters for this eNB (yet?)
            ?LOG_DEBUG("NFT counters (~p) for eNB (pid ~p, ~p): nope",
                       [ULDL, Pid, GlobalENBId]),
            S;
        {ok, Cached} ->
            %% same values as last time => nothing to report
            ?LOG_DEBUG("NFT counters (~p) for eNB (pid ~p, ~p): ~p",
                       [ULDL, Pid, GlobalENBId, Cached]),
            S;
        {ok, Fresh} ->
            ?LOG_DEBUG("NFT counters (~p) for eNB (pid ~p, ~p): ~p -> ~p",
                       [ULDL, Pid, GlobalENBId, Cached, Fresh]),
            Deltas = [{?S1GW_CTR_GTPU_PACKETS(GlobalENBId, ULDL),
                       Fresh#ctr.packets - Cached#ctr.packets},
                      {?S1GW_CTR_GTPU_BYTES_UE(GlobalENBId, ULDL),
                       Fresh#ctr.bytes_ue - Cached#ctr.bytes_ue},
                      {?S1GW_CTR_GTPU_BYTES_TOTAL(GlobalENBId, ULDL),
                       Fresh#ctr.bytes_total - Cached#ctr.bytes_total}],
            lists:foreach(fun({Metric, Delta}) ->
                                  s1gw_metrics:ctr_inc(Metric, Delta)
                          end, Deltas),
            S#{ctr => Fresh}
    end.


%% Delete the nft table TName together with everything in it.  The result
%% is returned unchanged; callers may ignore it when the table may not
%% exist.
-spec nft_flush_table(string()) -> enftables:result().
nft_flush_table(TName) ->
    nft_exec([enftables:nft_cmd_del_table(TName)]).


%% Create table TName with the base chains "gtpu-ul" (prerouting hook)
%% and "gtpu-dl" (postrouting hook).  Each chain starts with two rules
%% that accept, and so skip accounting for, anything that is not GTP-U:
%% non-UDP packets and UDP packets not destined to port ?GTPU_PORT.
%% Per-eNB counting rules are added after these later on.
-spec nft_init_table(string()) -> enftables:result().
nft_init_table(TName) ->
    NotUdp = [enftables:nft_expr_match_ip_proto("udp", ?OP_NEQ),
              enftables:nft_expr_accept()],
    NotGtpu = [enftables:nft_expr_match_udp_dport(?GTPU_PORT, ?OP_NEQ),
               enftables:nft_expr_accept()],
    %% rule-major order: NotUdp to both chains, then NotGtpu to both
    AddRules = [enftables:nft_cmd_add_rule(TName, Chain, Rule)
                || Rule <- [NotUdp, NotGtpu],
                   Chain <- ["gtpu-ul", "gtpu-dl"]],
    nft_exec([enftables:nft_cmd_add_table(TName, [<< "owner" >>]),
              nft_cmd_add_chain(TName, "gtpu-ul", "prerouting"),
              nft_cmd_add_chain(TName, "gtpu-dl", "postrouting")
              | AddRules]).


%% Run a batch of nftables commands through enftables in JSON mode,
%% logging the batch at debug level first.
-spec nft_exec(Cmds) -> enftables:result()
    when Cmds :: [enftables:nft_cmd()].
nft_exec(Cmds) ->
    ?LOG_DEBUG("Executing nftables commands: ~p", [Cmds]),
    Result = enftables:run_cmd(Cmds, json),
    Result.


%% Return the handle of the last object listed for chain CName of table
%% TName, which must be a rule.  Returns the atom `error' if the chain
%% cannot be listed.  Crashes if the last listed object is not a rule.
nft_chain_last_handle(TName, CName) ->
    case nft_exec([enftables:nft_cmd_list_chain(TName, CName)]) of
        {ok, Objs} ->
            #{<< "rule" >> := LastRule} = lists:last(Objs),
            maps:get(<< "handle" >>, LastRule);
        Other ->
            ?LOG_ERROR("~p() failed: ~p", [?FUNCTION_NAME, Other]),
            error
    end.


%% Build the JSON-style `add chain' command for an inet-family filter base
%% chain attached to Hook, with priority 0 and an `accept' policy.  The
%% string arguments are converted to binaries.
nft_cmd_add_chain(TName, CName, Hook) ->
    [TableBin, ChainBin, HookBin] = [list_to_binary(Str) || Str <- [TName, CName, Hook]],
    Chain = #{family => << "inet" >>,
              table => TableBin,
              name => ChainBin,
              type => << "filter" >>,
              hook => HookBin,
              prio => 0,
              policy => << "accept" >>},
    #{add => #{chain => Chain}}.


%% Build the nft match expression that selects one direction of an eNB's
%% traffic.  Uplink packets are matched by IP source address, downlink
%% packets by IP destination address.
-spec nft_expr_match_ip_addr(uldl_addr()) -> map().
nft_expr_match_ip_addr({ul, SrcAddr}) ->
    enftables:nft_expr_match_ip_saddr(SrcAddr, ?OP_EQ);
nft_expr_match_ip_addr({dl, DstAddr}) ->
    enftables:nft_expr_match_ip_daddr(DstAddr, ?OP_EQ).


%% Periodically (every Tval ms) ask the server to fetch and report the nft
%% counters.  Runs in its own process, spawn_link()ed from init/1.  The
%% server's name is registered before init/1 runs, so whereis/1 finds it.
%%
%% The loop is bound to the server instance found at start-up and ends as
%% soon as that instance goes away.  The link alone is not enough:
%% shutdown/0 stops the server with reason `normal', and a linked process
%% that does not trap exits ignores a `normal' exit signal.  A plain
%% sleep/cast loop would therefore keep running forever after shutdown/0.
%% After a restart it would also double the rate of report_nft_counters.
-spec heartbeat(timeout()) -> ok.
heartbeat(Tval) ->
    case whereis(?MODULE) of
        undefined ->
            ok; %% the server is gone already
        Server ->
            MonRef = erlang:monitor(process, Server),
            heartbeat_loop(Server, MonRef, Tval)
    end.

%% Wait Tval ms, then poke the server; stop once the server is down.
-spec heartbeat_loop(pid(), reference(), timeout()) -> ok.
heartbeat_loop(Server, MonRef, Tval) ->
    receive
        {'DOWN', MonRef, process, Server, _Reason} ->
            ok
    after Tval ->
        gen_server:cast(Server, report_nft_counters),
        heartbeat_loop(Server, MonRef, Tval) %% keep going
    end.


%% vim:set ts=4 sw=4 et:
