# Copyright (c) 2017-2026 Juancarlo Añez (apalala@gmail.com)
# SPDX-License-Identifier: BSD-4-Clause
from __future__ import annotations

import heapq
import re
from collections import defaultdict
from collections.abc import Iterable
from graphlib import TopologicalSorter
from typing import Any

from .undefined import Undefined


class CycleError(ValueError):
    """Raised by :func:`topsort` when some edges can never be satisfied."""


def first(iterable: Iterable[Any], default: Any = Undefined) -> Any:
    """Return the first item produced by *iterable*.

    When *iterable* yields nothing, *default* is returned instead; if no
    *default* was supplied, ``ValueError`` is raised (chained from the
    underlying ``StopIteration``).

        >>> first([0, 1, 2, 3])
        0
        >>> first([], 'some default')
        'some default'

    Handy when a generator produces expensive-to-retrieve values and any
    one of them will do. It is marginally shorter than
    ``next(iter(iterable), default)``.

    Adapted from ``more_itertools.first``:
    https://more-itertools.readthedocs.io/en/stable/_modules/more_itertools/more.html#first
    """
    try:
        item = next(iter(iterable))
    except StopIteration as exhausted:
        if default is not Undefined:
            return default
        # ValueError rather than StopIteration: a caller may want to handle
        # "nothing there" differently from ordinary iteration flow control,
        # and explicitly catching StopIteration is awkward.
        raise ValueError(
            'first() was called on an empty iterable, and no '
            'default value was provided.',
        ) from exhausted
    return item


def str_from_match(match: re.Match | None) -> str | None:
    """Reduce *match* to a single string, or ``None``.

    Built on :func:`find_from_match`. The result depends on the pattern:

    - with several groups, only the first group's text is kept;
    - with one group, that group's text is returned;
    - with no groups, the whole matched text is returned.

    A group that did not participate yields an empty string of the
    subject's type. ``None`` is returned when *match* is ``None``.
    """
    found = find_from_match(match)
    if not isinstance(found, tuple):
        return found
    # find_from_match falls back to the whole match instead of returning an
    # empty tuple, so this guard is defensive only.
    return found[0] if found else None


def find_from_match(m: re.Match) -> str | tuple[str, ...] | None:
    if m is None:
        return None
    g = m.groups(default=m.string[0:0])
    if len(g) == 1:
        return g[0]
    else:
        return g or m.group()


def _findall_lazily(
    pattern: str | re.Pattern,
    string: str,
    pos: int | None,
    endpos: int | None,
    flags: int,
) -> Iterable[str]:
    """Yield, one at a time, the values that :func:`iter_findall` collects."""
    regex = pattern if isinstance(pattern, re.Pattern) else re.compile(pattern, flags=flags)
    # Pattern.finditer() has no "unset" value for pos/endpos, so map None to
    # the full-range bounds. Each bound is resolved independently, so an
    # endpos given without a pos is honoured too (finditer clamps both).
    start = 0 if pos is None else pos
    stop = len(string) if endpos is None else endpos
    for match in regex.finditer(string, start, stop):
        found = str_from_match(match)
        if found is not None:
            yield found


def iter_findall(
    pattern: str | re.Pattern,
    string: str,
    pos: int | None = None,
    endpos: int | None = None,
    flags: int = 0,
) -> Iterable[str]:
    """
    Like finditer(), but with return values like findall().

    For each match of *pattern* between *pos* and *endpos*, the result holds
    one string. For a single capturing group it is that group's text. With
    no groups it is the whole match. With several groups it is only the
    first group's text (see :func:`str_from_match`). Results are returned
    eagerly as a tuple.

    *pos* and *endpos* may each be given or omitted independently. *flags*
    applies only when *pattern* is a string; a compiled pattern is used
    as-is.

    Implementation modelled on cpython/Modules/_sre.c/findall().
    """
    return tuple(_findall_lazily(pattern, string, pos, endpos, flags))


def findfirst(
    pattern,
    string,
    pos=None,
    endpos=None,
    flags=0,
    default=Undefined,
) -> str:
    """
    Return the first value :func:`iter_findall` would produce, or *default*.

    Avoids the inefficient findall(...)[0] or first(findall(...)): scanning
    stops at the first match instead of collecting every match up front.

    Raises ``ValueError`` (via :func:`first`) when nothing matches and no
    *default* was given.
    """
    return first(
        _findall_lazily(pattern, string, pos, endpos, flags),
        default=default,
    )


def topsort[T](nodes: Iterable[T], edges: Iterable[tuple[T, T]]) -> list[T]:
    # https://en.wikipedia.org/wiki/Topological_sorting

    # NOTE:
    #   topsort uses a partial order relationship,
    #   so results for the same arguments may be
    #   different from one call to the other
    #   _
    #   use this to make results stable accross calls
    #       topsort(n, e) == topsort(n, e)

    nodes = list(nodes)
    original_key = {node: i for i, node in enumerate(nodes)}

    def order_key(node: T) -> int:
        return original_key[node]

    partial_order = set(edges)

    def with_incoming() -> set[T]:
        nonlocal partial_order
        return {m for (_, m) in partial_order}

    result: list[T] = []
    pending = sorted(set(nodes) - with_incoming(), key=order_key, reverse=True)
    while pending:
        n = pending.pop()
        result.append(n)

        outgoing = {m for (x, m) in partial_order if x == n}
        partial_order -= {(n, m) for m in outgoing}
        pending.extend(outgoing - with_incoming())
        pending.sort(key=order_key, reverse=True)

    if partial_order:
        raise CycleError(f'There are cycles in {partial_order=!r}')

    return list(result)


def graphlib_topsort[T](nodes: Iterable[T], edges: Iterable[tuple[T, T]]) -> list[T]:
    graph: dict[T, list[T]] = defaultdict(list[T], {n: [] for n in nodes})
    for n, m in edges:
        graph[m].append(n)

    sorter = TopologicalSorter(graph)
    return list(sorter.static_order())
