//! Aggregate query functions (includes PromQL aggregation operators)
use super::{basic_fn, qry_fn, QryFunc};
use crate::{
    query::{
        ops::{
            modifiers::{AggrType, Mod},
            Operable,
        },
        IntoQuery,
    },
    seal::Sealed,
};
use core::fmt;
use std::fmt::Display;

/// An aggregation expression — what PromQL calls an "aggregation operator" —
/// built on top of a query function call.
///
/// Besides the wrapped call, it carries a grouping modifier that can be set
/// with [`by`](Self::by) or [`without`](Self::without). Only one of the two
/// is kept: each call replaces whatever modifier was set before it.
#[derive(Debug, Clone)]
pub struct AggrFunc<'a, F>
where
    F: Fn(&mut fmt::Formatter) -> fmt::Result,
{
    /// The wrapped function call, e.g. `sum(test_metric)`.
    inner: QryFunc<F>,
    /// The `by (...)` / `without (...)` clause written right after `inner`.
    mod_type: Mod<'a, AggrType>,
}

impl<'a, F: Fn(&mut fmt::Formatter) -> fmt::Result> AggrFunc<'a, F> {
    /// Groups the aggregation by the given `labels`.
    ///
    /// The result renders as e.g. `func(x) by (label,label2)`. This replaces any
    /// modifier set earlier through `by` or `without`; only the last call is kept.
    ///
    /// The modified aggregation is returned by value, so discarding the return
    /// value discards the modifier — hence `#[must_use]`.
    #[must_use]
    pub fn by<I: IntoIterator<Item = &'a str>>(mut self, labels: I) -> Self {
        self.mod_type = Mod::from((AggrType::By, labels));
        self
    }

    /// Aggregates away the given `labels`, keeping all others.
    ///
    /// The result renders as e.g. `func(x) without (label,label2)`. This replaces
    /// any modifier set earlier through `by` or `without`; only the last call is
    /// kept.
    ///
    /// The modified aggregation is returned by value, so discarding the return
    /// value discards the modifier — hence `#[must_use]`.
    #[must_use]
    pub fn without<I: IntoIterator<Item = &'a str>>(mut self, labels: I) -> Self {
        self.mod_type = Mod::from((AggrType::Without, labels));
        self
    }
}

impl<F: Fn(&mut fmt::Formatter) -> fmt::Result> From<QryFunc<F>> for AggrFunc<'_, F> {
    fn from(value: QryFunc<F>) -> Self {
        AggrFunc {
            inner: value,
            mod_type: Default::default(),
        }
    }
}

// Marker impls: they let an `AggrFunc` be passed anywhere an `Operable`
// expression is accepted (e.g. as `vec_expr` of the functions below).
impl<F> Sealed for AggrFunc<'_, F>
where
    F: Fn(&mut fmt::Formatter) -> fmt::Result,
{
}

impl<F> Operable for AggrFunc<'_, F>
where
    F: Fn(&mut fmt::Formatter) -> fmt::Result,
{
}

impl<F> Display for AggrFunc<'_, F>
where
    F: Fn(&mut fmt::Formatter) -> fmt::Result,
{
    /// Writes the wrapped call immediately followed by its modifier.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { inner, mod_type } = self;
        write!(f, "{inner}{mod_type}")
    }
}

impl<F> IntoQuery for AggrFunc<'_, F>
where
    F: Fn(&mut fmt::Formatter) -> fmt::Result,
{
    type Target = String;

    /// Renders the aggregation to its query text (same output as `Display`).
    fn into_query(self) -> Self::Target {
        format!("{self}")
    }
}

/// The sum aggregate operator/function
#[inline]
pub fn sum<'a>(
    vec_expr: impl Operable,
) -> AggrFunc<'a, impl Fn(&mut fmt::Formatter) -> fmt::Result> {
    basic_fn("sum", vec_expr).into()
}

/// The min aggregate operator/function
#[inline]
pub fn min<'a>(
    vec_expr: impl Operable,
) -> AggrFunc<'a, impl Fn(&mut fmt::Formatter) -> fmt::Result> {
    basic_fn("min", vec_expr).into()
}

/// The max aggregate operator/function
#[inline]
pub fn max<'a>(
    vec_expr: impl Operable,
) -> AggrFunc<'a, impl Fn(&mut fmt::Formatter) -> fmt::Result> {
    basic_fn("max", vec_expr).into()
}

/// The avg aggregate operator/function
#[inline]
pub fn avg<'a>(
    vec_expr: impl Operable,
) -> AggrFunc<'a, impl Fn(&mut fmt::Formatter) -> fmt::Result> {
    basic_fn("avg", vec_expr).into()
}

/// The group aggregate operator/function
#[inline]
pub fn group<'a>(
    vec_expr: impl Operable,
) -> AggrFunc<'a, impl Fn(&mut fmt::Formatter) -> fmt::Result> {
    basic_fn("group", vec_expr).into()
}

/// The stddev aggregate operator/function
#[inline]
pub fn stddev<'a>(
    vec_expr: impl Operable,
) -> AggrFunc<'a, impl Fn(&mut fmt::Formatter) -> fmt::Result> {
    basic_fn("stddev", vec_expr).into()
}

/// The stdvar aggregate operator/function
#[inline]
pub fn stdvar<'a>(
    vec_expr: impl Operable,
) -> AggrFunc<'a, impl Fn(&mut fmt::Formatter) -> fmt::Result> {
    basic_fn("stdvar", vec_expr).into()
}

/// The count aggregate operator/function
#[inline]
pub fn count<'a>(
    vec_expr: impl Operable,
) -> AggrFunc<'a, impl Fn(&mut fmt::Formatter) -> fmt::Result> {
    basic_fn("count", vec_expr).into()
}

/// Builds the `count_values` aggregation operator; renders as
/// `count_values("<label>", <vec_expr>)`.
///
/// `label` is written inside a double-quoted string literal. Backslashes, double
/// quotes and newlines in it are therefore escaped (`\\`, `\"`, `\n`) before
/// rendering. Without this, such a label would end the literal early and yield
/// a malformed (or injected) query. Labels without those characters are emitted
/// unchanged, without allocating.
#[inline]
pub fn count_values<'a>(
    label: &'a str,
    vec_expr: impl Operable + 'a,
) -> AggrFunc<'a, impl Fn(&mut fmt::Formatter) -> fmt::Result + 'a> {
    let needs_escape = label.contains(|c: char| matches!(c, '\\' | '"' | '\n'));
    let label: std::borrow::Cow<'a, str> = if needs_escape {
        let mut escaped = String::with_capacity(label.len() + 4);
        for c in label.chars() {
            match c {
                '\\' => escaped.push_str(r"\\"),
                '"' => escaped.push_str(r#"\""#),
                '\n' => escaped.push_str(r"\n"),
                other => escaped.push(other),
            }
        }
        std::borrow::Cow::Owned(escaped)
    } else {
        std::borrow::Cow::Borrowed(label)
    };
    qry_fn!(count_values, r#""{label}", {vec_expr}"#).into()
}

/// The topk aggregate operator/function
#[inline]
pub fn topk<'a>(
    k: usize,
    vec_expr: impl Operable,
) -> AggrFunc<'a, impl Fn(&mut fmt::Formatter) -> fmt::Result> {
    qry_fn!(topk, "{k}, {vec_expr}").into()
}

/// The bottomk aggregate operator/function
#[inline]
pub fn bottomk<'a>(
    k: usize,
    vec_expr: impl Operable,
) -> AggrFunc<'a, impl Fn(&mut fmt::Formatter) -> fmt::Result> {
    qry_fn!(bottomk, "{k}, {vec_expr}").into()
}

/// Builds the `quantile` aggregation operator; renders as
/// `quantile(<phi>, <vec_expr>)`.
///
/// `phi` should always be `>= 0` and `<= 1`. It is not validated here; its
/// `Display` form is written into the query as-is (e.g. `0.1` renders as `0.1`).
///
/// Chain [`AggrFunc::by`] or [`AggrFunc::without`] to add a grouping clause.
#[inline]
pub fn quantile<'a>(
    phi: f32,
    vec_expr: impl Operable,
) -> AggrFunc<'a, impl Fn(&mut fmt::Formatter) -> fmt::Result> {
    qry_fn!(quantile, "{phi}, {vec_expr}").into()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::Metric;

    #[test]
    fn by_mod() {
        let rendered = AggrFunc::from(qry_fn!(func, "x"))
            .by(["label", "label2"])
            .to_string();
        assert_eq!(rendered, "func(x) by (label,label2)");
    }

    #[test]
    fn without_mod() {
        let rendered = AggrFunc::from(qry_fn!(func, "x"))
            .without(["label", "label2"])
            .to_string();
        assert_eq!(rendered, "func(x) without (label,label2)");
    }

    #[test]
    fn aggr_sum() {
        let query = sum(Metric::new("test_metric")).to_string();
        assert_eq!(query, "sum(test_metric)");
    }

    #[test]
    fn aggr_min() {
        let query = min(Metric::new("test_metric")).to_string();
        assert_eq!(query, "min(test_metric)");
    }

    #[test]
    fn aggr_max() {
        let query = max(Metric::new("test_metric")).to_string();
        assert_eq!(query, "max(test_metric)");
    }

    #[test]
    fn aggr_avg() {
        let query = avg(Metric::new("test_metric")).to_string();
        assert_eq!(query, "avg(test_metric)");
    }

    #[test]
    fn aggr_group() {
        let query = group(Metric::new("test_metric")).to_string();
        assert_eq!(query, "group(test_metric)");
    }

    #[test]
    fn aggr_stddev() {
        let query = stddev(Metric::new("test_metric")).to_string();
        assert_eq!(query, "stddev(test_metric)");
    }

    #[test]
    fn aggr_stdvar() {
        let query = stdvar(Metric::new("test_metric")).to_string();
        assert_eq!(query, "stdvar(test_metric)");
    }

    #[test]
    fn aggr_count() {
        let query = count(Metric::new("test_metric")).to_string();
        assert_eq!(query, "count(test_metric)");
    }

    #[test]
    fn aggr_count_values() {
        let query = count_values("version", Metric::new("build_version")).to_string();
        assert_eq!(query, "count_values(\"version\", build_version)");
    }

    #[test]
    fn aggr_topk() {
        let query = topk(5, Metric::new("http_requests_total")).to_string();
        assert_eq!(query, "topk(5, http_requests_total)");
    }

    #[test]
    fn aggr_bottomk() {
        let query = bottomk(5, Metric::new("http_requests_total")).to_string();
        assert_eq!(query, "bottomk(5, http_requests_total)");
    }

    #[test]
    fn aggr_quantile() {
        let query = quantile(0.1, Metric::new("http_requests_total")).to_string();
        assert_eq!(query, "quantile(0.1, http_requests_total)");
    }
}