'''
    SPDX-FileCopyrightText: 2024 Agata Cacko <cacko.azh@gmail.com>

    This file is part of Fast Sketch Cleanup Plugin

    SPDX-License-Identifier: GPL-3.0-or-later
'''


import numpy as np
import sys
import openvino as ov


def interferenceInDifferentThread(model, inputData):
    """Run one inference of ``model`` on ``inputData`` and return the output data.

    A fresh infer request is created on every call. The request is started
    asynchronously and then waited on, so this function blocks until the
    result is available.

    Args:
        model: object providing ``create_infer_request()``; presumably an
            OpenVINO compiled model with exactly one input and one output
            (TODO confirm against the caller).
        inputData: array that is wrapped in ``ov.Tensor`` and set as the
            request's single input. NOTE(review): ``shared_memory`` is not
            passed, so with OpenVINO's default the array is presumably copied
            rather than shared — confirm for the OpenVINO version in use.

    Returns:
        ``.data`` of the request's output tensor. This is presumably a numpy
        view onto OpenVINO-owned memory; copy it if it must outlive further
        inference calls (TODO confirm).
    """
    request = model.create_infer_request()
    request.set_input_tensor(ov.Tensor(array=inputData))

    # Launch the request and block until it has completed.
    request.start_async()
    request.wait()

    return request.get_output_tensor().data




def interference(model, data, partSize, margin=16, inferPart=None):
    """Run ``model`` over ``data`` in overlapping tiles and stitch the results.

    The model is assumed to be static: it only accepts inputs whose two
    spatial axes (axes 2 and 3 of ``data``) are exactly ``partSize`` long.
    The data is therefore zero-padded and cut into ``partSize`` x
    ``partSize`` tiles that overlap by ``2*margin``. Only the centre square
    of side ``partSize - 2*margin`` is kept from every tile result. As a
    result, each kept output pixel was computed with ``margin`` pixels of
    context (real data or zero padding) on every side.

    Args:
        model: passed unchanged as the first argument of ``inferPart``.
        data: 4-D numpy array. Axes 0 and 1 are passed through and axes 2
            and 3 are tiled. Both tiled sizes must be divisible by ``margin``.
        partSize: spatial size of a single model input tile. It must be
            divisible by ``margin`` and larger than ``2*margin``.
        margin: number of border pixels discarded on each side of every tile.
        inferPart: callable ``(model, tile) -> result`` that runs one tile.
            ``result`` must have a ``partSize`` x ``partSize`` spatial size
            and leading axes assignable into ``data``'s. Defaults to
            ``interferenceInDifferentThread``.

    Returns:
        numpy array with the shape and dtype of ``data``, holding the
        stitched model output.

    Raises:
        AssertionError: if ``partSize`` or a tiled size of ``data`` is not
            divisible by ``margin``.
        ValueError: if ``partSize <= 2*margin``. In that case tiles would
            have no usable centre and the tiling could never advance.
    """
    if inferPart is None:
        inferPart = interferenceInDifferentThread

    # we assume static model, therefore no smaller parts!
    assert partSize % margin == 0, "Max size of the part must be divisable by the margin"

    print(f"Shape of the data = {data.shape}")
    rows, cols = data.shape[2], data.shape[3]

    assert (rows % margin == 0 and cols % margin == 0), "Size of data must be already divisable by the margin"

    # Side length of the tile centre that survives into the output.
    keptSize = partSize - 2 * margin
    if keptSize <= 0:
        raise ValueError(
            f"partSize ({partSize}) must be larger than 2*margin ({2 * margin})"
        )

    # Tiles per spatial axis: integer ceil division, at least one tile.
    rowParts = max(1, -(-rows // keptSize))
    colParts = max(1, -(-cols // keptSize))

    # Leading `margin` zeros give the first tiles their context. The trailing
    # padding is sized so the last tile ends exactly at the array boundary:
    # total = parts * keptSize + 2 * margin along each tiled axis.
    padded = np.pad(
        data,
        (
            (0, 0),
            (0, 0),
            (margin, margin + rowParts * keptSize - rows),
            (margin, margin + colParts * keptSize - cols),
        ),
        'constant',
        constant_values=(0, 0),
    )
    print(f"padded shape = {padded.shape}")

    # Buffer covering the whole grid of kept centres; trimmed to `data`'s size
    # at the end. Every element is overwritten below, so its fill does not matter.
    outputDataPadded = np.zeros(
        data.shape[:2] + (rowParts * keptSize, colParts * keptSize),
        dtype=data.dtype,
    )
    centre = slice(margin, partSize - margin)

    for rowIndex in range(rowParts):
        top = rowIndex * keptSize
        for colIndex in range(colParts):
            left = colIndex * keptSize
            # `padded` is offset by `margin`, so the tile starting at (top, left)
            # in padded coordinates has its centre at (top, left) in data coordinates.
            part = padded[:, :, top:top + partSize, left:left + partSize]
            result = inferPart(model, part)
            outputDataPadded[:, :, top:top + keptSize, left:left + keptSize] = result[:, :, centre, centre]
            print("Finished processing one part.")

    print("Finished processing all parts.")
    # shed the additional padding
    return outputDataPadded[:, :, :rows, :cols]
