import { mapSet } from './iterator'

/**
 * Returns the elements that belong to exactly one of `a` and `b`.
 * Despite the name, this is the symmetric difference, not `a \ b`.
 *
 * Result order: elements of `a` that are missing from `b` (in `a`'s
 * iteration order), then elements of `b` that are missing from `a`
 * (in `b`'s iteration order).
 *
 * @param a - First set (not modified).
 * @param b - Second set (not modified).
 * @returns A new set containing the elements present in only one input.
 */
export function difference<T>(a: Set<T>, b: Set<T>): Set<T> {
  const onlyInOne = new Set<T>()
  for (const item of a) {
    if (!b.has(item)) onlyInOne.add(item)
  }
  for (const item of b) {
    if (!a.has(item)) onlyInOne.add(item)
  }
  return onlyInOne
}

/**
 * Returns the elements present in every set of `s`.
 *
 * The caller's array is not reordered, and the result is always a fresh
 * Set, so mutating the result never affects any input set.
 *
 * @param s - Sets to intersect (neither the array nor the sets are modified).
 * @returns A new set whose elements follow the iteration order of the
 *   smallest input set; empty when `s` is empty.
 */
export function intersection<T>(s: Set<T>[]): Set<T> {
  // Sort a copy: iterating the smallest set minimizes membership checks,
  // and copying avoids reordering the caller's array (Array#sort is in place).
  const [smallest, ...rest] = [...s].sort((a, b) => a.size - b.size)
  if (!smallest) return new Set()

  const common = new Set<T>()
  for (const x of smallest) {
    // With a single input, `rest` is empty and `every` is true, so the
    // result is a copy of that set rather than the caller's own instance.
    if (rest.every(other => other.has(x))) common.add(x)
  }
  return common
}

/**
 * Intersects `xss` by a derived key: elements are treated as equal when
 * `getY` maps them to the same key. Keys are compared with Map/Set
 * semantics (SameValueZero), so keys that are freshly allocated objects
 * never match one another.
 *
 * For each key present in every set, one representative element is
 * returned. When several distinct elements share a key, the one stored
 * last in the element->key map wins.
 *
 * @param getY - Key function, called once per element of each set.
 * @param xss - Sets to intersect (not modified).
 * @returns A new set of representative elements; empty when `xss` is empty.
 */
export function intersectionBy<X>(getY: (x: X) => unknown, xss: Set<X>[]): Set<X> {
  // element -> key. (A previous `.flat()` here was a no-op on an array of
  // Sets; flatMap over each set's elements is what builds the pairs.)
  const xToY = new Map<X, unknown>(
    xss.flatMap(xs => [...xs].map((x): [X, unknown] => [x, getY(x)])),
  )
  // key -> representative element; later entries overwrite earlier ones.
  const yToX = new Map<unknown, X>([...xToY].map(([x, y]): [unknown, X] => [y, x]))

  const yss = xss.map(xs =>
    mapSet(x => {
      if (!xToY.has(x)) {
        throw new Error(`intersectBy: cannot find x ${x}; this should never happen?`)
      }
      return xToY.get(x)
    }, xs),
  )

  return mapSet(y => {
    if (!yToX.has(y)) {
      throw new Error(`intersectBy: cannot find y ${y}; this should never happen?`)
    }
    // `has` was checked above; the assertion only covers X itself possibly
    // including `undefined`.
    return yToX.get(y) as X
  }, intersection(yss))
}

/**
 * Combines all given sets into one.
 *
 * @param sets - Sets to merge (not modified).
 * @returns A new set holding every element of every input, in first-seen
 *   order across `sets`; empty when `sets` is empty.
 */
export function union<T>(sets: Set<T>[]): Set<T> {
  const combined = new Set<T>()
  for (const set of sets) {
    for (const element of set) combined.add(element)
  }
  return combined
}
