import React from 'react';

/**
 * Decides whether a dependency list differs from the one recorded previously.
 *
 * A missing previous list (first call) always counts as a change. Lists of
 * different length are treated as changed, so shrinking or growing the list
 * (including growth by `undefined` entries) is detected. Entries are compared
 * with `Object.is`, the same comparison React uses for hook dependencies. This
 * means a `NaN` entry is stable rather than "changed" on every render.
 *
 * @param last - dependencies seen on the previous render, or nullish if none yet.
 * @param next - dependencies for the current render.
 * @returns `true` when the dependent computation/effect must be re-run.
 */
function isChanged(last: readonly unknown[] | undefined | null, next: readonly unknown[]): boolean {
  if (!last || last.length !== next.length) {
    return true;
  }
  for (let i = 0; i < next.length; i++) {
    if (!Object.is(next[i], last[i])) {
      return true;
    }
  }
  return false;
}

/**
 * The component instance as seen from a `withState` constructor: the current
 * props and state plus the helpers that `withState` installs on the instance.
 */
export type TWithStateSelf<P, S> = {
  /** Current props of the component instance. */
  props: P;
  /** Current state of the component instance. */
  state: S;
  /** Merges a partial state update into the component state. */
  setState: (state: Partial<S>) => void;
  /** Returns the result of `cb`, re-evaluated only when `values` change. */
  localMemo: <A>(cb: (...args: any[]) => A, values: any[]) => A;
  /** Returns `cb`, keeping the previous function identity while `values` are unchanged. */
  localCallback: <A, B>(cb: (arg: A) => B, values: any[]) => ((arg: A) => B);
  /** Registers a side effect keyed on `values`; a returned function acts as its cleanup. */
  localEffect: (cb: (...args: any[]) => any, values: any[]) => void;
};

/**
 * Render function produced by a `withState` constructor. It is invoked on every
 * render with the component's current props and state.
 */
export type TWithStateRender<P, S> = (
  currentProps: P,
  currentState: S
) => any;

/**
 * Factory accepted by `withState`. It is called once per component instance
 * with the instance's `setState` and the instance itself, and must return the
 * render function used for every subsequent render.
 */
export type TWithStateConstructor<P, S> = (
  setState: TWithStateSelf<P, S>['setState'],
  self: TWithStateSelf<P, S>
) => TWithStateRender<P, S>;


/**
 * Builds a React class component from a closure-style "constructor".
 *
 * `constructor` is called once per component instance with a mount-guarded
 * `setState` and the instance itself, which also exposes `localMemo`,
 * `localCallback` and `localEffect`. It must return the render function that is
 * used for every render.
 *
 * `localMemo` / `localEffect` calls are identified by their call order within a
 * render, so they must be invoked in the same order on every render.
 *
 * Effects registered with `localEffect` never run during `render`. All of them
 * run after the first commit (`componentDidMount`). After later commits
 * (`componentDidUpdate`), only the ones whose dependency list changed run. A
 * function returned by an effect is called before that effect runs again and
 * when the component unmounts.
 *
 * @param constructor - factory invoked once per instance; returns the render function.
 * @returns a component class rendering with props `P` and state `S`.
 */
export function withState<P, S>(constructor: TWithStateConstructor<P, S>) {
  class WithStateComponent extends React.Component<P, S> {
    constructor(props: P) {
      super(props);

      const self: any = this;
      const _setState = self.setState;

      let _mounted = false;

      let _memoIndex = 0;
      const _memoArgs: any[] = [];
      const _memoResults: any[] = [];

      // State updates are dropped while the component is not mounted
      // (before componentDidMount and after componentWillUnmount).
      function setState(state: Partial<S>) {
        _mounted && _setState.call(self, state);
      }

      // Memo slots are addressed by call order within the current render.
      function localMemo<A>(cb: (...args: any[]) => A, values: any[]): A {
        const index = _memoIndex;
        _memoIndex++;
        if (isChanged(_memoArgs[index], values)) {
          _memoArgs[index] = values;
          _memoResults[index] = cb();
        }
        return _memoResults[index];
      }

      // Keeps the same function identity while `values` are unchanged.
      function localCallback<A, B>(cb: (arg: A) => B, values: any[]): ((arg: A) => B) {
        return localMemo(() => cb, values);
      }

      let _effectIndex = 0;
      const _effectArgs: any[] = [];
      let _effectCancels: any[] = [];
      const _callbacks: ((...args: any[]) => any)[] = [];
      // Effect slots whose dependencies changed during render; they are run
      // after React commits, never during render itself.
      let _pendingEffects: number[] = [];

      // Runs the previous cleanup for this slot (if any), then the effect,
      // remembering its return value as the new cleanup when it is a function.
      function setCallback(cb: (...args: any[]) => any, index: number) {
        const previousCancel = _effectCancels[index];
        previousCancel && previousCancel();
        const nextCancel = cb();
        _effectCancels[index] = typeof nextCancel == 'function'
          ? nextCancel
          : null;
      }

      // Records the effect for this slot and queues it if its deps changed.
      function localEffect(cb: (...args: any[]) => any, values: any[]) {
        const index = _effectIndex;
        _effectIndex++;
        _callbacks[index] = cb;
        if (isChanged(_effectArgs[index], values)) {
          _effectArgs[index] = values;
          // Avoid running the same slot twice if several renders happen
          // before a commit.
          if (_pendingEffects.indexOf(index) < 0) {
            _pendingEffects.push(index);
          }
        }
      }

      self.setState = setState;
      self.localMemo = localMemo;
      self.localCallback = localCallback;
      self.localEffect = localEffect;

      self.componentDidMount = () => {
        if (_mounted) {
          return;
        }
        _mounted = true;
        // Every effect runs on mount, which covers anything queued by the
        // first render; also handles remounts that happen without a render.
        _pendingEffects = [];
        _callbacks.forEach(setCallback);
      };

      self.componentDidUpdate = () => {
        if (!_mounted) {
          return;
        }
        const pending = _pendingEffects;
        _pendingEffects = [];
        pending.forEach((index) => {
          const cb = _callbacks[index];
          if (cb) {
            setCallback(cb, index);
          }
        });
      };

      self.componentWillUnmount = () => {
        if (!_mounted) {
          return;
        }
        _mounted = false;
        _pendingEffects = [];
        const effectCancels = _effectCancels;
        _effectCancels = [];
        effectCancels.forEach((cb) => {
          cb && cb();
        });
      };

      const render = constructor(setState, self as TWithStateSelf<P, S>);

      self.render = () => {
        // Slots are positional: restart numbering at the top of every render.
        _effectIndex = _memoIndex = 0;
        return render(self.props, self.state);
      };
    }
  }

  return WithStateComponent;
}

