import cv2
import matplotlib.pyplot as plt
from PIL import Image
import os

from dataclasses import dataclass
@dataclass
class Pixel:
    """One colour sample made of three integer channels."""

    r: int  # red channel
    g: int  # green channel
    b: int  # blue channel

@dataclass
class ImagePPM:
    """An in-memory image: its two dimensions plus a flat list of pixels."""

    width: int  # first dimension stored in the PPM header
    height: int  # second dimension stored in the PPM header
    # Every pixel of the image in one flat list, in file order
    # (presumably row by row -- confirm against callers).
    pixels: list[Pixel]


def showImageFromPath(path:str) -> None:
    """Read the image file at ``path`` with OpenCV and hand it to ``plt.imshow``.

    The image is only drawn; this function does not call ``plt.show()``.

    Args:
        path: Location of the image file to display.

    Raises:
        AssertionError: If ``cv2.imread`` returns ``None`` for ``path``.
    """
    img = cv2.imread(path)
    if img is None:
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is unchanged.
        raise AssertionError(f"OpenCV could not read an image from {path!r}")
    # OpenCV returns channels in BGR order while matplotlib expects RGB.
    # Converting in memory replaces the former round trip through a
    # '.temp.png' file in the current working directory.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) # type: ignore

def showImage(imgppm: ImagePPM) -> None:
    """Draw an in-memory ``ImagePPM`` with ``plt.imshow``.

    The image is written to the scratch file '.temp.ppm' with ``saveImage``,
    decoded with OpenCV, and the scratch file is removed again. The image is
    only drawn; this function does not call ``plt.show()``.

    Args:
        imgppm: The image to display.

    Raises:
        AssertionError: If OpenCV cannot decode the scratch file.
    """
    ppmPath = '.temp.ppm'
    try:
        saveImage(imgppm, ppmPath)
        img = cv2.imread(ppmPath)
    finally:
        # Clean up the scratch file even when writing or decoding failed,
        # so a bad image does not leave '.temp.ppm' behind.
        if os.path.exists(ppmPath):
            os.remove(ppmPath)
    if img is None:
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is unchanged.
        raise AssertionError("OpenCV could not decode the temporary PPM file")
    # OpenCV returns channels in BGR order while matplotlib expects RGB.
    # Converting in memory replaces the former second scratch file
    # ('.temp.png').
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) # type: ignore

def loadImage(path: str) -> ImagePPM:
    """Parse a plain-text PPM file into an ``ImagePPM``.

    The file is read as a stream of whitespace-separated tokens: a magic
    word, the width, the height, a maximum colour value, and then the
    red/green/blue values of each pixel. Text from a '#' to the end of its
    line is treated as a comment and skipped. The magic word is not checked
    and the maximum colour value is not applied to the pixel data. Leftover
    values that do not form a complete r/g/b triple are ignored.

    Args:
        path: Location of the PPM file to read.

    Returns:
        An ``ImagePPM`` with the header's width and height and every complete
        pixel found in the file, in file order.

    Raises:
        IndexError: If the file has no width/height tokens.
        ValueError: If a dimension or colour token is not an integer.
    """
    tokens: list[str] = []
    # The context manager closes the file even when parsing fails.
    with open(path, "r") as f:
        for line in f:
            # Keep only the part before any '#'; this covers full-line,
            # indented and trailing comments alike.
            tokens.extend(line.split('#', 1)[0].split())

    # Drop the magic word as a token (not as a whole line), so a header
    # written on a single line such as "P3 2 2 255" is still understood.
    dimensionsAndPixels: list[str] = tokens[1:]

    width: int = int(dimensionsAndPixels[0])
    height: int = int(dimensionsAndPixels[1])
    # Index 2 holds the maximum colour value; pixel data starts at index 3.
    # Stopping at len - 2 guarantees that i + 2 is always a valid index.
    allPixels: list[Pixel] = [
        Pixel(
            int(dimensionsAndPixels[i]),
            int(dimensionsAndPixels[i + 1]),
            int(dimensionsAndPixels[i + 2]),
        )
        for i in range(3, len(dimensionsAndPixels) - 2, 3)
    ]
    return ImagePPM(width, height, allPixels)



def saveImage(img: ImagePPM,path: str):
    """Write ``img`` to ``path`` as a plain-text PPM file.

    The output starts with the magic word "P3" and a comment line, followed
    by "<width> <height> 255", and then one colour value per line in
    r, g, b order for each pixel of ``img.pixels``.

    Args:
        img: The image to write.
        path: Destination file; it is created or overwritten.
    """
    # The context manager flushes and closes the file on every exit path,
    # instead of relying on garbage collection to do it.
    with open(path, "w") as f:
        f.write("P3\n")
        f.write("#created by my wonderfull app !\n")
        f.write(f'{img.width} {img.height} 255\n')
        for pixel in img.pixels:
            f.write(f'{pixel.r}\n{pixel.g}\n{pixel.b}\n')




def display2DPoints(*lstPoints: list)->None:#type:ignore
    '''Plot one or more sequences of 2D points with matplotlib.

    Each positional argument is an iterable of objects exposing ``x`` and
    ``y`` attributes. Every argument is drawn as its own line with a '.'
    marker on each point, and ``plt.show()`` is called once after all of
    them have been plotted.
    '''
    for lstPoint in lstPoints: # type: ignore
        xSet: list[int] = []
        ySet: list[int] = []
        # Appending in place avoids copying the whole list for every point,
        # and a single pass keeps one-shot iterables working.
        for p in lstPoint: # type: ignore
            xSet.append(p.x) # type: ignore
            ySet.append(p.y) # type: ignore
        plt.plot(xSet, ySet, marker='.')# type: ignore
    plt.show()# type: ignore
