import { useMutation, useQueryClient } from "react-query";
import API from "services";
import { IDataTableError } from "shared/types/errors";
import {
  IGenFeatureSessionState,
  IImageGenParams,
  ISessionResponse,
} from "shared/types/genAI";

/**
 * Mutation context captured in `onMutate`: a snapshot of the cached session
 * taken before the optimistic write, so a failed request can be rolled back.
 */
interface IGenerateImagesContext {
  previousSession?: ISessionResponse<IGenFeatureSessionState>;
}

/**
 * Hook that triggers image generation for a gen-AI session and keeps the
 * cached `["getSession", sessionId]` query in sync with it.
 *
 * Each call to the returned `generateImages()`:
 * 1. cancels in-flight fetches of the session query;
 * 2. optimistically marks the cached session as `"pending"`, or seeds a
 *    placeholder `"pending"` session built from `args` when nothing is cached;
 * 3. calls `API.services.genAI.generateImages(prompt, nSamples, sessionId)`;
 * 4. on failure, restores the cache snapshot taken in step 2;
 * 5. once settled (success or failure), invalidates the session query.
 *
 * @param args - Prompt, sample count, output height/width and target session id.
 * @returns `generateImages`, a fire-and-forget trigger for the mutation.
 */
const useGenerateImages = (
  args: IImageGenParams,
): {
  generateImages: () => void;
} => {
  const queryClient = useQueryClient();
  // Single cache key shared by every callback below so they cannot drift apart.
  const sessionKey = ["getSession", args.sessionId];

  const { mutate } = useMutation<
    ISessionResponse<IGenFeatureSessionState> | IDataTableError | null,
    Error,
    void,
    IGenerateImagesContext
  >(
    () =>
      API.services.genAI.generateImages(
        args.prompt,
        args.nSamples,
        args.sessionId,
      ),
    {
      onMutate: async () => {
        // Cancel any outgoing refetches (so they don't overwrite our optimistic
        // update). Awaited so cancellation finishes before the cache is written.
        await queryClient.cancelQueries(sessionKey);

        const previousSession = queryClient.getQueryData<
          ISessionResponse<IGenFeatureSessionState>
        >(sessionKey);

        // Optimistically update to the new value
        if (previousSession) {
          queryClient.setQueryData<ISessionResponse<IGenFeatureSessionState>>(
            sessionKey,
            {
              ...previousSession,
              status: "pending",
            },
          );
        } else {
          // Nothing cached yet: seed a placeholder session from the request
          // arguments so consumers of the session query see it as pending.
          queryClient.setQueryData<ISessionResponse<IGenFeatureSessionState> | null>(
            sessionKey,
            {
              pk: `GENAI-SESSION#${args.sessionId}`,
              sk: `GENAI-SESSION#${args.sessionId}`,
              feature: "gen",
              state: {
                prompt: args.prompt,
                nSamples: args.nSamples,
                height: args.height,
                width: args.width,
                images: [],
              },
              status: "pending",
            },
          );
        }

        // Handed to onError as `context` so the optimistic write can be undone.
        return { previousSession };
      },

      onError: (_error, _variables, context) => {
        // Roll back the optimistic write. When no session was cached before the
        // mutation, this resets the cached data to `undefined`, discarding the
        // placeholder seeded in onMutate.
        queryClient.setQueryData<
          ISessionResponse<IGenFeatureSessionState> | undefined
        >(sessionKey, context?.previousSession);
      },

      // Refetch from the server regardless of outcome.
      onSettled: () => queryClient.invalidateQueries(sessionKey),
    },
  );

  return {
    generateImages: mutate,
  };
};

export default useGenerateImages;
