import { exhaustiveTypeException } from './type-utils';

/** Anything `for await` can consume: a plain (synchronous) iterable or an async iterable. */
export type ForAwaitableIterable<T> = Iterable<T> | AsyncIterable<T>;

/**
 * `never` for promise types, `T` otherwise. Placed on callback return types to reject async callbacks.
 * Because `T` is a naked type parameter, this distributes over unions, dropping promise members.
 */
export type NotAPromise<T> = T extends Promise<unknown> ? never : T;

/**
 * Adapt a synchronous iterable to the async-iterable protocol, yielding its elements in order.
 * As with any `yield` in an async generator, an element that is a promise is awaited before being yielded.
 */
export async function* from<T>(iterable: Iterable<T>): AsyncIterable<T> {
  for (const element of iterable) {
    yield element;
  }
}

/**
 * Wrap a promise as a one-element async iterable. The promise is awaited when the first element is requested.
 * If it rejects, that first `next()` call rejects with the same error.
 */
export async function* fromPromise<T>(promise: Promise<T>): AsyncIterable<T> {
  const settled = await promise;
  yield settled;
}

/**
 * Lazily apply a synchronous `callback` to every element, yielding the results in source order.
 * The callback receives the element and its zero-based position in the source.
 */
export async function* map<T, U>(
  asyncIterable: ForAwaitableIterable<T>,
  callback: (val: T, index: number) => NotAPromise<U>
): AsyncGenerator<NotAPromise<U>> {
  let position = 0;
  for await (const item of asyncIterable) {
    const index = position;
    position += 1;
    yield callback(item, index);
  }
}

/**
 * Like `map`, but results that are exactly `null` are dropped instead of yielded.
 * Other falsy results, including `undefined`, are still yielded.
 */
export async function* mapAndDiscardNulls<T, U>(
  asyncIterable: ForAwaitableIterable<T>,
  callback: (val: T, index: number) => NotAPromise<U> | null
): AsyncGenerator<NotAPromise<U>> {
  let position = 0;
  for await (const item of asyncIterable) {
    const mapped = callback(item, position++);
    if (mapped === null) {
      continue;
    }
    yield mapped;
  }
}

/**
 * Map the async `callback` over `asyncIterable` as concurrently as possible, yielding results in source order.
 *
 * Each element is passed to `callback` as soon as it is pulled from the source. The source keeps being pulled while
 * earlier callback promises are still pending. Results are yielded in the order of the elements they came from, so a
 * slow early element holds back later results that are already finished. There is no concurrency limit: the number
 * of in-flight callbacks is bounded only by how fast the source produces elements.
 *
 * If the source or a callback rejects, iteration throws that error. A callback's rejection is thrown once that
 * callback reaches the front of the result queue.
 *
 * NOTE(review): if the consumer stops early, the source iterator is not closed and a `next()` call on it may still
 * be in flight.
 */
export async function* mapInParallel<T, U>(
  asyncIterable: AsyncIterable<T>,
  callback: (val: T, index: number) => Promise<U>
): AsyncIterable<U> {
  // Tag each race contender with its origin. Telling them apart by shape would misread a callback result that
  // happens to look like `{ done, value }` as an iterator result.
  type RaceOutcome = { source: 'iterator'; result: IteratorResult<T> } | { source: 'callback'; value: U };

  const pending: Promise<U>[] = [];
  let index = 0;
  const iterator = asyncIterable[Symbol.asyncIterator]();
  let nextItem: Promise<IteratorResult<T>> | null = iterator.next();

  while (nextItem !== null || pending.length > 0) {
    const contenders: Promise<RaceOutcome>[] = [];
    if (nextItem !== null) {
      contenders.push(nextItem.then((result): RaceOutcome => ({ source: 'iterator', result })));
    }
    const head = pending[0];
    if (head !== undefined) {
      contenders.push(head.then((value): RaceOutcome => ({ source: 'callback', value })));
    }

    const outcome = await Promise.race(contenders);
    if (outcome.source === 'iterator') {
      const { result } = outcome;
      if (result.done) {
        nextItem = null;
      } else {
        const resultPromise = callback(result.value, index++);
        // Only the head of `pending` is raced. Mark queued promises as handled so that a rejection while they wait
        // is not reported as an unhandled rejection. The error still surfaces once this promise reaches the head.
        void resultPromise.catch(() => undefined);
        pending.push(resultPromise);
        nextItem = iterator.next();
      }
      continue;
    }

    pending.shift();
    yield outcome.value;
  }
}

/**
 * Group consecutive elements into arrays of `chunkSize` elements. The last array holds any leftover elements and may
 * be shorter. An empty array is never yielded, so an empty source yields nothing.
 */
export async function* chunk<T>(asyncIterable: ForAwaitableIterable<T>, chunkSize: number): AsyncGenerator<T[]> {
  let batch: T[] = [];
  for await (const item of asyncIterable) {
    batch.push(item);
    if (batch.length >= chunkSize) {
      const full = batch;
      batch = [];
      yield full;
    }
  }
  if (batch.length !== 0) {
    yield batch;
  }
}

/**
 * Concatenate an iterable of iterables (each may be sync or async) into a single async stream.
 * Each inner iterable is fully consumed before the next one is read from the outer source.
 */
export async function* flatten<T>(asyncIterable: ForAwaitableIterable<ForAwaitableIterable<T>>): AsyncGenerator<T> {
  for await (const inner of asyncIterable) {
    for await (const element of inner) {
      yield element;
    }
  }
}

/**
 * Map each element to an iterable with `callback`, which also receives the element's zero-based index. Then yield
 * the elements of those iterables one after another, in order.
 */
export function flatMap<T, U>(
  asyncIterable: ForAwaitableIterable<T>,
  callback: (val: T, index: number) => ForAwaitableIterable<U>
): AsyncGenerator<U> {
  return flatten(map(asyncIterable, callback));
}

/**
 * Lazily keep only the elements for which `filter` returns true; `filter` also receives the element's
 * zero-based position in the source. With a type-guard predicate, the result is narrowed to `U`.
 */
export function filter<T, U extends T>(
  iterable: ForAwaitableIterable<T>,
  filter: (val: T, index: number) => val is U
): AsyncGenerator<U>;
export function filter<T>(
  iterable: ForAwaitableIterable<T>,
  filter: (val: T, index: number) => boolean
): AsyncGenerator<T>;
export async function* filter<T>(
  asyncIterable: ForAwaitableIterable<T>,
  filter: (val: T, index: number) => boolean
): AsyncGenerator<T> {
  let position = 0;
  for await (const candidate of asyncIterable) {
    const keep = filter(candidate, position);
    position += 1;
    if (keep) {
      yield candidate;
    }
  }
}

/**
 * Filter with an async predicate, running the predicate on elements concurrently via `mapInParallel`.
 * Kept elements are yielded in source order.
 */
export const filterInParallel = <T>(
  asyncIterable: AsyncIterable<T>,
  asyncFilter: (val: T, index: number) => Promise<boolean>
) => {
  // Each element becomes a one-element array (kept) or an empty array (dropped); flatten() then splices them together.
  const keptOrEmpty = mapInParallel(asyncIterable, async (val, index) =>
    (await asyncFilter(val, index)) ? [val] : []
  );
  return flatten(keptOrEmpty);
};

/**
 * Resolve to true as soon as `condition` holds for some element, which also stops consuming the source.
 * Resolve to false if the source is exhausted without a match. `condition` receives the element's zero-based position.
 */
export async function some<T>(
  asyncIterable: ForAwaitableIterable<T>,
  condition: (val: T, index: number) => boolean
): Promise<boolean> {
  let position = 0;
  for await (const item of asyncIterable) {
    const index = position++;
    if (condition(item, index)) {
      return true;
    }
  }
  return false;
}

/** Drain the iterable and resolve to all of its elements as an array, in order. */
export async function array<T>(asyncIterable: ForAwaitableIterable<T>): Promise<T[]> {
  const collected: T[] = [];
  for await (const item of asyncIterable) {
    collected.push(item);
  }
  return collected;
}

/**
 * Resolve to the first element, or `undefined` if the iterable is empty.
 * Leaving the loop early closes the source, so a generator source's `finally` blocks run.
 */
export async function first<T>(asyncIterable: ForAwaitableIterable<T>): Promise<T | undefined> {
  let found: T | undefined = undefined;
  for await (const item of asyncIterable) {
    found = item;
    break;
  }
  return found;
}

/** The element type of an async iterable type. */
type ElementType<T> = T extends AsyncIterable<infer U> ? U : never;

/**
 * Merge several async iterables into one, yielding each value as soon as any source produces it. Values from
 * different sources therefore interleave in arrival order. All sources are consumed concurrently, and their values
 * are buffered in a `Sink` until read.
 *
 * The result ends once every source has finished. If a source throws, iteration throws that error when the consumer
 * reaches it.
 *
 * When the consumer stops reading (break/return, or because an error was thrown to it), each still-running source
 * is closed the next time it produces a value, and nothing further is buffered. A source that never produces another
 * value is never closed.
 *
 * @param asyncIterables the sources to merge; fully iterated (once) when iteration of the result begins.
 */
export async function* combine<T extends AsyncIterable<unknown>>(
  asyncIterables: Iterable<T>
): AsyncIterable<ElementType<T>> {
  const sink = new Sink<ElementType<T>>();
  let active = 0;
  // Set once the consumer has stopped; producers then quit instead of buffering values nobody will read.
  let stopped = false;
  for (const asyncIterable of asyncIterables) {
    active++;
    void (async () => {
      for await (const val of asyncIterable) {
        if (stopped) {
          // Leaving a for-await loop early calls the source's return(), releasing it.
          break;
        }
        sink.send(val as ElementType<T>);
      }
    })().then(
      () => {
        active--;
        if (active === 0 && !stopped) {
          sink.done();
        }
      },
      (e) => {
        // Once the consumer is gone, a rejection pushed into the sink would never be read.
        if (!stopped) {
          sink.error(e);
        }
      }
    );
  }
  if (active === 0) {
    return;
  }
  try {
    yield* sink.iterate();
  } finally {
    stopped = true;
  }
}

/** For each key `K` of `T`, the `{ funcName, result }` record that `awaitPromises` yields for that key. */
type CallMultipleReturnType<T> = {
  [K in keyof T]: { funcName: K; result: Awaited<T[K]> };
};

/**
 * Await every promise in `promises` concurrently. Yield `{ funcName, result }` for each one as it resolves, where
 * `funcName` is the promise's key in `promises`. The iterable ends once all promises have resolved. If one rejects,
 * iteration throws that error.
 */
export async function* awaitPromises<T extends Record<string, Promise<unknown>>>(
  promises: T
): AsyncGenerator<CallMultipleReturnType<T>[keyof T]> {
  // One single-element async iterable per entry; combine() interleaves them in the order the promises resolve.
  const singleShots = Object.entries(promises).map(([funcName, promise]) =>
    (async function* () {
      yield { funcName, result: await promise };
    })()
  );
  for await (const settled of combine(singleShots)) {
    const { funcName, result } = settled;
    yield { funcName, result: result as CallMultipleReturnType<T>[keyof T]['result'] };
  }
}

/**
 * Concatenate async iterables: yield every value of the first source, then every value of the second, and so on.
 * Iteration of each source begins only after the previous source has finished.
 */
export async function* chain<T extends Array<AsyncIterable<unknown>>>(
  ...asyncIterables: [...T]
): AsyncGenerator<ElementType<T[number]>> {
  for (const source of asyncIterables) {
    for await (const element of source) {
      const typed = element as ElementType<T[number]>;
      yield typed;
    }
  }
}

/**
 * Fold the iterable into a single value, starting from `initial`.
 *
 * @param iterable the elements to fold, consumed in order.
 * @param initial the starting accumulator; returned unchanged when `iterable` is empty.
 * @param func combines the current accumulator with an element and its zero-based index, returning the next
 *   accumulator. It must be synchronous (`NotAPromise`).
 * @returns the final accumulator.
 */
export async function reduce<T, U>(
  iterable: ForAwaitableIterable<T>,
  initial: NotAPromise<U>,
  func: (accumulator: U, val: T, index: number) => NotAPromise<U>
): Promise<NotAPromise<U>> {
  let index = 0;
  let accumulator = initial;
  for await (const val of iterable) {
    accumulator = func(accumulator, val, index++);
  }
  return accumulator;
}

/** Count how many elements map to each key produced by `keySelector`. */
export async function countBy<T, U>(
  values: ForAwaitableIterable<T>,
  keySelector: (t: T) => NotAPromise<U>
): Promise<Map<NotAPromise<U>, number>> {
  // One key per element is the single-key special case of countByMultiple.
  const singleKey = (item: T) => [keySelector(item)];
  return countByMultiple(values, singleKey);
}

/**
 * Count occurrences of every key produced by `multiKeySelector`. An element adds one to the count of each key it
 * produces, so an element that produces the same key twice counts twice. Keys keep first-seen order in the Map.
 */
export async function countByMultiple<T, U>(
  values: ForAwaitableIterable<T>,
  multiKeySelector: (t: T) => Iterable<U>
): Promise<Map<U, number>> {
  const tally = new Map<U, number>();
  for await (const value of values) {
    for (const key of multiKeySelector(value)) {
      const previous = tally.get(key) ?? 0;
      tally.set(key, previous + 1);
    }
  }
  return tally;
}

/** Index elements by `keySelector`. When several elements share a key, the last one wins. */
export async function mapBy<T, U>(
  values: ForAwaitableIterable<T>,
  keySelector: (t: T) => NotAPromise<U>
): Promise<Map<U, T>> {
  const byKey = new Map<U, T>();
  for await (const value of values) {
    byKey.set(keySelector(value), value);
  }
  return byKey;
}

/** Group elements into arrays by the key from `keySelector`, preserving source order within each group. */
export async function groupBy<T, U>(
  values: ForAwaitableIterable<T>,
  keySelector: (t: T) => NotAPromise<U>
): Promise<Map<U, T[]>> {
  // One key per element is the single-key special case of groupByMultiple.
  const singleKey = (item: T) => [keySelector(item)];
  return groupByMultiple(values, singleKey);
}

/** Group elements under every key that `multiKeySelector` produces for them; an element may land in many groups. */
export async function groupByMultiple<T, U>(
  values: ForAwaitableIterable<T>,
  multiKeySelector: (t: T) => Iterable<U>
): Promise<Map<NotAPromise<U>, T[]>> {
  // Identity transform: the groups hold the elements themselves.
  const asIs = (item: T) => item as NotAPromise<T>;
  return groupByMultipleAndTransform(values, multiKeySelector, asIs);
}

/** Group elements by the key from `keySelector`, storing `transformer(element)` in each group instead of the element. */
export async function groupByAndTransform<T, T2, U>(
  values: ForAwaitableIterable<T>,
  keySelector: (t: T) => NotAPromise<U>,
  transformer: (t: T) => NotAPromise<T2>
): Promise<Map<NotAPromise<U>, NotAPromise<T2>[]>> {
  // One key per element is the single-key special case of groupByMultipleAndTransform.
  const singleKey = (item: T) => [keySelector(item)];
  return groupByMultipleAndTransform(values, singleKey, transformer);
}

/**
 * Group transformed elements under every key that `multiKeySelector` produces for them.
 * `transformer` is invoked once per (element, key) pair, after that key's group has been created if needed.
 * Groups keep first-seen key order, and entries within a group keep source order.
 */
export async function groupByMultipleAndTransform<T, T2, U>(
  values: ForAwaitableIterable<T>,
  multiKeySelector: (t: T) => Iterable<U>,
  transformer: (t: T) => NotAPromise<T2>
): Promise<Map<NotAPromise<U>, NotAPromise<T2>[]>> {
  const groups: Map<NotAPromise<U>, NotAPromise<T2>[]> = new Map();
  for await (const value of values) {
    for (const key of multiKeySelector(value)) {
      const groupKey = key as NotAPromise<U>;
      const existing = groups.get(groupKey);
      const bucket: NotAPromise<T2>[] = existing ?? [];
      if (existing === undefined) {
        groups.set(groupKey, bucket);
      }
      bucket.push(transformer(value));
    }
  }
  return groups;
}

/**
 * Yield at most the first `count` elements, then close the source.
 *
 * Exactly `count` elements are pulled from the source, never one more. For `count <= 0` the source is not read at
 * all. This matters for expensive or side-effecting sources, and for sources that block after their first `count`
 * elements.
 */
export async function* take<T>(iterable: ForAwaitableIterable<T>, count: number): AsyncGenerator<T> {
  if (count <= 0) {
    return;
  }
  let taken = 0;
  for await (const val of iterable) {
    yield val;
    taken++;
    // Stop right after the last wanted element instead of pulling another one to discover the limit.
    if (taken >= count) {
      return;
    }
  }
}

/**
 * Pair each element with its zero-based position, yielding `[element, index]` tuples in source order.
 *
 * Declared as `AsyncGenerator`, like `map`/`filter`/`take`, so callers get `.next()` and the async-iterator protocol
 * without narrowing. This is still assignable wherever `ForAwaitableIterable<[T, number]>` was expected.
 */
export async function* enumerate<T>(iterable: ForAwaitableIterable<T>): AsyncGenerator<[T, number]> {
  let i = 0;
  for await (const val of iterable) {
    yield [val, i++];
  }
}

/**
 * Concatenate the string forms of the elements, with `separator` between each pair of consecutive elements.
 *
 * A separator is emitted between every pair even when an element stringifies to `''`, matching the separator
 * placement of `Array.prototype.join`. Unlike `Array.prototype.join`, `null`/`undefined` elements are concatenated
 * as `'null'`/`'undefined'`.
 */
export async function join<T>(iterable: ForAwaitableIterable<T>, separator: string): Promise<string> {
  let result = '';
  // Track position explicitly: testing `result.length` would drop separators after leading empty strings.
  let isFirst = true;
  for await (const item of iterable) {
    if (!isFirst) {
      result += separator;
    }
    isFirst = false;
    result += item;
  }
  return result;
}

/** One buffered event: a value, or the end-of-stream marker. Errors travel as rejected promises instead. */
type SinkEvent<T> = { type: 'next'; value: T } | { type: 'done' };

/**
 * A push-to-pull adapter. Producers call `send`, `done` or `error`, and a consumer reads the events back, in order,
 * through `iterate()`. Every event is buffered until it is read, so producers never wait for the consumer.
 *
 * Internally the sink keeps a queue of promises, one per event. The newest promise is always unsettled and is the
 * one the next `send`/`done`/`error` call settles.
 *
 * Intended for a single consumer: `iterate()` removes events from the shared queue as it reads them, so concurrent
 * `iterate()` calls would split the events between them.
 */
export class Sink<T> {
  private promises: Promise<SinkEvent<T>>[] = [];
  // Definitely assigned by addNewPromise(), which the constructor calls.
  private resolveNext!: (result: SinkEvent<T>) => void;
  private rejectNext!: (error: Error) => void;

  constructor() {
    this.addNewPromise();
  }

  /** Append a fresh unsettled promise to the queue and point resolveNext/rejectNext at it. */
  addNewPromise() {
    const next = new Promise<SinkEvent<T>>((resolve, reject) => {
      this.resolveNext = resolve;
      this.rejectNext = reject;
    });
    // error() may reject this promise long before iterate() awaits it. Without a handler, Node reports (and by
    // default crashes on) an unhandled rejection. iterate() still awaits `next` itself and receives the error.
    void next.catch(() => undefined);
    this.promises.push(next);
  }

  /** Buffer `message` as the next value. */
  send(message: T) {
    this.resolveNext({ type: 'next', value: message });
    this.addNewPromise();
  }

  /** Mark the end of the stream; `iterate()` finishes when it reaches this event. */
  done() {
    this.resolveNext({ type: 'done' });
    this.addNewPromise();
  }

  /** Buffer an error; `iterate()` throws `e` when it reaches this event. */
  error(e: Error) {
    this.rejectNext(e);
    this.addNewPromise();
  }

  /** Yield buffered values in order until a `done` event (return) or an error event (throw) is reached. */
  async *iterate(): AsyncIterable<T> {
    iteration: while (true) {
      const nextEventPromise = this.promises.shift();
      if (!nextEventPromise) {
        throw new Error('Unexpected missing promise');
      }
      const result = await nextEventPromise;
      switch (result.type) {
        case 'next':
          yield result.value;
          continue;
        case 'done':
          break iteration;
        default:
          exhaustiveTypeException(result);
      }
    }
  }
}

/** Drain the iterable and resolve to the number of elements it produced. */
export async function count<T>(iterable: ForAwaitableIterable<T>): Promise<number> {
  let total = 0;
  for await (const _item of iterable) {
    total += 1;
  }
  return total;
}

/** Pull every element and discard it; useful when only the side effects of producing the elements matter. */
export async function exhaust<T>(iterable: ForAwaitableIterable<T>): Promise<void> {
  for await (const _ignored of iterable) {
    continue;
  }
}

/**
 * Yield each distinct element once, at its first occurrence. Distinctness uses `Set` semantics (SameValueZero), so
 * `NaN` is deduplicated and objects compare by reference.
 */
export async function* unique<T>(iterable: ForAwaitableIterable<T>): AsyncIterable<T> {
  const seen = new Set<T>();
  for await (const item of iterable) {
    const sizeBefore = seen.size;
    seen.add(item);
    // The set only grows when `item` was not already present.
    if (seen.size > sizeBefore) {
      yield item;
    }
  }
}

/**
 * Yield the first element seen for each distinct key from `keySelector`; later elements with the same key are
 * skipped. Keys are compared with `Set` semantics (SameValueZero).
 */
export async function* uniqueBy<T, U>(
  iterable: ForAwaitableIterable<T>,
  keySelector: (t: T) => NotAPromise<U>
): AsyncIterable<T> {
  const seenKeys = new Set<U>();
  for await (const item of iterable) {
    const key = keySelector(item);
    if (!seenKeys.has(key)) {
      seenKeys.add(key);
      yield item;
    }
  }
}
