/**
 * EpiCurrents Viewer SezNet ONNX loader.
 * @package    epicurrents-viewer
 * @copyright  2023 Sampsa Lohi
 * @license    MIT
 */
import Log from 'scoped-ts-log';
import { GenericOnnxLoader } from './GenericOnnxLoader';
import { concatFloat32Arrays } from 'LIB/util/signal';
import SETTINGS from 'CONFIG/Settings';
/** Log scope used for every message written by this module. */
const SCOPE = 'SezNetOnnxLoader';
/** Length of one analysis epoch, in seconds. */
const EPOCH_LEN = 16;
/** Length of one analysis run range (ten epochs), in seconds. */
const RUN_RANGE = 10 * EPOCH_LEN;
/**
 * ONNX loader for the SezNet EEG models.
 *
 * The EEG channels of the active montage are analyzed range by range in EPOCH_LEN-second epochs,
 * and epochs the model classifies as positive are added to the montage as highlights.
 */
export class SezNetOnnxLoader extends GenericOnnxLoader {
    /** Next range [start, end] to analyze, in seconds; an empty array when none has been set. */
    _nextRunRange = [];
    /**
     * Range [start, end] that has been analyzed so far, in seconds.
     * [0, 0] means that nothing has been analyzed yet.
     */
    _processedRange = [0, 0];
    /** Source recording; presumably assigned by GenericOnnxLoader.setSourceResource (TODO confirm). */
    _source = null;
    /**
     * Create a new loader and register the available SezNet models with their workers.
     * @param model - Model passed on to the generic ONNX loader.
     */
    constructor(model) {
        super(model);
        this._availableModels.set('seznet-1', {
            name: 'SezNet-1',
            supportedStudyTypes: ['sig:eeg'],
            worker: new Worker(new URL("LIB/workers/SezNet1Worker.ts", import.meta.url))
        });
        this._availableModels.set('seznet-2', {
            name: 'SezNet-2',
            supportedStudyTypes: ['sig:eeg'],
            worker: new Worker(new URL("LIB/workers/SezNet2Worker.ts", import.meta.url))
        });
    }
    /**
     * Determine the next range of the recording to analyze and store it as the next run range.
     *
     * While a run is in progress, the range following the processed range is preferred. Once the
     * end of the recording has been reached, analysis continues backwards from the start of the
     * processed range. A range reaching the end of the recording is extended to whole epochs; the
     * missing samples are padded with zeroes during analysis.
     * @returns The next run range [start, end]. Returns an empty array if there is no source,
     *          everything has been processed, or more cached signal data is needed first.
     */
    findNextRunRange() {
        if (!this._source) {
            Log.error(`Cannot find next run range without an active data source!`, SCOPE);
            return [];
        }
        const cache = this._source.signalCacheStatus;
        const totalDuration = this._source.totalDuration;
        Log.debug(`Finding next run range for signal cache ${cache.join('-')}.`, SCOPE);
        const [processedStart, processedEnd] = this._processedRange;
        // The final range may be padded past the end of the recording, so compare with >=.
        const reachedEnd = processedEnd >= totalDuration;
        if (processedStart === 0 && reachedEnd) {
            // All processed.
            this._nextRunRange = [];
            if (this._runInProgress) {
                this.runInProgress = false;
            }
            Log.debug(`ONNX run completed.`, SCOPE);
            return this._nextRunRange;
        }
        // A range that directly continues the processed range (forwards, or backwards once the end
        // has been reached) is already queued or being analyzed; keep it as it is.
        const queued = this._nextRunRange;
        if (queued &&
            ((queued[1] && queued[0] === processedEnd) ||
                (reachedEnd && queued[1] === processedStart))) {
            Log.debug(`Next run range ${queued.join('-')} already set.`, SCOPE);
            return queued;
        }
        if (this._runInProgress) {
            if (!this._nextRunRange) {
                this._nextRunRange = [0, 0];
            }
            if (cache[1] >= processedEnd + RUN_RANGE ||
                (cache[1] > processedEnd && cache[1] === totalDuration)) {
                // First, process forward if possible.
                this._nextRunRange[0] = processedEnd;
                if (totalDuration > processedEnd + RUN_RANGE) {
                    this._nextRunRange[1] = processedEnd + RUN_RANGE;
                }
                else {
                    const finalEpochs = (totalDuration - processedEnd) / EPOCH_LEN;
                    if (finalEpochs % 1) {
                        Log.debug(`Final epoch is ${Math.round((finalEpochs % 1) * EPOCH_LEN)} seconds long, padding end with zeroes.`, SCOPE);
                    }
                    this._nextRunRange[1] = processedEnd + Math.ceil(finalEpochs) * EPOCH_LEN;
                }
            }
            else if (cache[0] < processedStart - RUN_RANGE || reachedEnd) {
                // Then process backwards from the start of the processed range. The range must end
                // where the processed range starts; otherwise analyzed epochs would be analyzed again.
                this._nextRunRange[0] = Math.max(0, processedStart - RUN_RANGE);
                this._nextRunRange[1] = processedStart;
            }
            // NOTE(review): when neither branch applies, the previous (already analyzed) range is
            // returned unchanged and run() will analyze it again. Confirm how a run is meant to wait
            // for more cached signal data.
        }
        else if (cache[1] - cache[0] >= RUN_RANGE) {
            // Mark the first part of the cached signal, in whole run ranges, to be processed.
            this._nextRunRange[0] = cache[0];
            this._nextRunRange[1] = cache[0] + Math.floor((cache[1] - cache[0]) / RUN_RANGE) * RUN_RANGE;
        }
        else if (cache[1] < totalDuration) {
            Log.debug(`Cached signal duration is less than minimum run range, waiting for more signals...`, SCOPE);
            return [];
        }
        else {
            // The cached signal reaches the end of the recording but is shorter than a run range.
            const rangeToProcess = cache[1] - cache[0];
            if (rangeToProcess % EPOCH_LEN) {
                Log.debug(`Final epoch is ${rangeToProcess % EPOCH_LEN} seconds long, padding end with zeroes.`, SCOPE);
            }
            // Round the end up to a whole epoch, as the forward branch does. processChannelSignals
            // only analyzes whole epochs, so otherwise the final partial epoch would be dropped.
            this._nextRunRange = [cache[0], cache[0] + Math.ceil(rangeToProcess / EPOCH_LEN) * EPOCH_LEN];
        }
        Log.debug(`Next run range set to ${this._nextRunRange.join('-')}.`, SCOPE);
        return this._nextRunRange;
    }
    /**
     * Load the given model.
     *
     * Does nothing if the model is already active. Any progress from the previous model (an active
     * run or analyzed ranges) is reset first.
     * @param model - Key of the model to load.
     * @returns The result of the parent loader's loadModel, or undefined if the model is already active.
     */
    async loadModel(model) {
        if (model === this._activeModel) {
            return;
        }
        // Results (and their highlight labels) belong to the model that produced them. Without a
        // reset, a finished run would leave everything marked as processed for the new model too.
        if (this._runInProgress || this._processedRange[1] > this._processedRange[0]) {
            this.resetProgress();
        }
        return super.loadModel(model);
    }
    /**
     * Analyze a range of the active montage with the active model, channel by channel and epoch by
     * epoch, and add the positive epochs as highlights to the montage.
     *
     * Only whole epochs starting at range[0] are analyzed. Signals shorter than the requested
     * duration are padded with zeroes at the end.
     * @param range - Range [start, end] to analyze, in seconds.
     * @param channels - Channels to analyze. Presumably in the same order as the signals returned by
     *                   the montage's getAllSignals (TODO confirm).
     */
    async processChannelSignals(range, channels) {
        /** Run the active model on the samples of one epoch; resolves with the worker's predictions. */
        const runEpoch = async (samplingRate, startIndex, ...samples) => {
            this._progress.startIndex = startIndex;
            // The predictions are consumed through the returned promise, so the callbacks are no-ops.
            const callbacks = {
                resolve: (predictions) => { },
                reject: (reason) => { }
            };
            const response = this._commissionWorker('run', callbacks, new Map([
                ['samples', samples],
                ['samplingRate', samplingRate],
            ]));
            return response.promise;
        };
        if (!this._source) {
            Log.warn(`Process channel signals called without an active data source.`, SCOPE);
            return;
        }
        const montage = this._source.activeMontage;
        if (!montage) {
            // Raw signal data cannot be normalized
            Log.warn(`ONNX analysis is not available to raw recorded signal data.`, SCOPE);
            return;
        }
        // Snapshot the range before the first await. The array may be this._nextRunRange, which
        // findNextRunRange (e.g. on signal cache updates) modifies in place while this is awaiting.
        const [rangeStart, rangeEnd] = range;
        const segNum = Math.floor((rangeEnd - rangeStart) / EPOCH_LEN);
        if (segNum <= 0) {
            return;
        }
        // Number of channel epochs already analyzed; used to offset the progress index.
        const segComplete = channels.length * ((this._processedRange[1] - this._processedRange[0]) / EPOCH_LEN);
        Log.debug(`Analyzing signals for range ${rangeStart}-${rangeEnd}.`, SCOPE);
        const response = await montage.getAllSignals([rangeStart, rangeStart + segNum * EPOCH_LEN]);
        if (!response) {
            Log.warn(`Aborting analysis; invalid signal response.`, SCOPE);
            return;
        }
        if (!this._runInProgress) {
            Log.debug(`Aborting analysis; run paused. [1]`, SCOPE);
            return;
        }
        // Epoch offset (seconds from rangeStart) -> list of [channel index, positive probability].
        const positives = new Map();
        for (let i = 0; i < channels.length; i++) {
            const chan = channels[i];
            const reqLen = (segNum * EPOCH_LEN) * chan.samplingRate;
            const sig = response.signals[i].length >= reqLen
                ? response.signals[i]
                : concatFloat32Arrays(response.signals[i], new Float32Array(reqLen - response.signals[i].length).fill(0));
            for (let j = 0; j < segNum; j++) {
                const start = j * EPOCH_LEN * chan.samplingRate;
                const end = (j + 1) * EPOCH_LEN * chan.samplingRate;
                const probs = await runEpoch(chan.samplingRate, segComplete + i * segNum + j, sig.slice(start, end));
                if (!this._runInProgress) {
                    Log.debug(`Aborting analysis; run paused. [2]`, SCOPE);
                    return;
                }
                // An epoch is positive when its second class probability exceeds the first.
                if (probs && probs[0][0] < probs[0][1]) {
                    const prevs = positives.get(j * EPOCH_LEN);
                    if (prevs) {
                        prevs.push([i, probs[0][1]]);
                    }
                    else {
                        positives.set(j * EPOCH_LEN, [[i, probs[0][1]]]);
                    }
                }
            }
        }
        // Add positives as signal highlights.
        for (const [offset, detections] of positives.entries()) {
            for (const [chan, value] of detections) {
                montage.addHighlights(this._id, {
                    background: true,
                    channels: [chan],
                    collarLength: EPOCH_LEN / 2,
                    end: Math.min(offset + rangeStart + EPOCH_LEN, this._source.totalDuration),
                    hasNext: false,
                    hasPrevious: false,
                    label: `${this._activeModel}: ${(value * 100).toFixed()} % confidence.`,
                    opacity: value,
                    start: offset + rangeStart,
                    type: 'detection',
                    value: value,
                    visible: true,
                });
            }
        }
        if (this._processedRange[1] > this._processedRange[0]) {
            this._processedRange[0] = Math.min(this._processedRange[0], rangeStart);
            this._processedRange[1] = Math.max(this._processedRange[1], rangeEnd);
        }
        else {
            // [0, 0] only marks that nothing has been processed yet. Merging with it would mark
            // everything before rangeStart as processed.
            this._processedRange[0] = rangeStart;
            this._processedRange[1] = rangeEnd;
        }
        this.onPropertyUpdate('results');
    }
    /**
     * Discard all analysis results.
     *
     * Removes this loader's highlights from the active montage, resets the parent's progress and the
     * processed and next run ranges, and notifies 'results' listeners.
     */
    resetProgress() {
        this._source?.activeMontage?.removeAllHighlightsFrom(this._id);
        super.resetProgress();
        this._processedRange = [0, 0];
        this._nextRunRange = [];
        this.onPropertyUpdate('results');
    }
    /**
     * Start an analysis run over the EEG channels of the active montage.
     *
     * Ranges are analyzed one after another until findNextRunRange yields no further range or the
     * run is stopped.
     */
    async run() {
        if (this._runInProgress) {
            Log.error(`Cannot start a new run while an active run is in progress.`, SCOPE);
            return;
        }
        // Start a new run
        if (!this._source) {
            Log.error(`Run called without an active data source.`, SCOPE);
            return;
        }
        const montage = this._source.activeMontage;
        if (!montage) {
            // Raw signal data cannot be normalized
            Log.error(`ONNX analysis is not available to raw recorded signal data.`, SCOPE);
            return;
        }
        const channels = montage.channels.filter(c => c?.type === 'eeg');
        if (!montage.highlights.get(this._id)) {
            // Set up highlight context for displaying the detections
            montage.addHighlightContext(this._id, {
                highlights: [],
                plotDisplay: {
                    color: [1, 0.5, 0.5, 0.5],
                    type: 'background',
                },
                naviDisplay: {
                    algorithms: ['log'],
                    color: [1, 0.5, 0.5, 1],
                    interval: EPOCH_LEN,
                    ref: 1,
                    type: 'curve',
                },
                visible: true,
            });
        }
        SETTINGS.eeg.annotations.idColors[`ONNX-${this._activeModel}`] = [1, 0.75, 0.75, 0.75];
        const segTotal = Math.ceil(this._source.dataDuration / EPOCH_LEN);
        this.setProgressTarget(channels.length * segTotal);
        this.runInProgress = true;
        Log.debug(`Starting an ONNX analysis run for ${this._progress.target} data segments.`, SCOPE);
        let range = this._nextRunRange;
        if (!range?.length || !(range[1] - range[0])) {
            // Try if some range has not been analyzed yet.
            range = this.findNextRunRange();
            // findNextRunRange signals "nothing to do" with an empty array, which is truthy.
            if (!range?.length) {
                Log.debug(`No more ranges to process.`, SCOPE);
                return;
            }
        }
        while (range && range[1] > range[0]) {
            await this.processChannelSignals(range, channels);
            if (!this._runInProgress) {
                break;
            }
            this._linkAdjacentHighlights(this._source.activeMontage?.highlights.get(this._id)?.highlights || []);
            Log.debug(`Range ${this._processedRange[0]}-${this._processedRange[1]} analyzed, checking for next range to analyze.`, SCOPE);
            range = this.findNextRunRange();
        }
    }
    /**
     * Flag highlights that directly continue another highlight on the same channel.
     *
     * hasPrevious is set when another highlight ends where this one starts; hasNext is set when
     * another highlight starts where this one ends. The scans stop early, which presumes that the
     * highlights are ordered by start time (TODO confirm).
     * @param highlights - Highlights of this loader's highlight context.
     */
    _linkAdjacentHighlights(highlights) {
        for (const hl of highlights) {
            if (!hl.hasPrevious) {
                for (const prevHl of highlights) {
                    if (prevHl.end === hl.start && prevHl.channels[0] === hl.channels[0]) {
                        hl.hasPrevious = true;
                        break;
                    }
                    else if (prevHl.start >= hl.start) {
                        break;
                    }
                }
            }
            if (!hl.hasNext) {
                for (const nextHl of highlights) {
                    if (nextHl.start === hl.end && nextHl.channels[0] === hl.channels[0]) {
                        hl.hasNext = true;
                        break;
                    }
                    else if (nextHl.start > hl.end) {
                        break;
                    }
                }
            }
        }
    }
    /**
     * Set the range to analyze next.
     * @param range - Range [start, end] in seconds. Requires 0 <= start < end <= total duration.
     *                Invalid ranges are logged and ignored.
     */
    setNextRunRange(range) {
        if (!this._source) {
            Log.error(`Cannot set next run range until source recording has been set.`, SCOPE);
            return;
        }
        if (range?.length !== 2) {
            Log.error(`Run range must have exactly two items [start, end].`, SCOPE);
            return;
        }
        // A range may end exactly at the end of the recording.
        if (range[0] < 0 || range[0] >= range[1] || range[1] > this._source.totalDuration) {
            Log.error(`Run range ${range[0]}-${range[1]} is not valid.`, SCOPE);
            return;
        }
        // Store a copy: the next run range is modified in place by findNextRunRange, which must not
        // alter the caller's array.
        this._nextRunRange = [range[0], range[1]];
    }
    /**
     * Set the source recording.
     *
     * Also subscribes to its updates: a change of the active montage resets the progress, and
     * signal cache updates re-evaluate the next run range.
     * @param resource - Source recording resource.
     * @returns False if the parent loader rejected the resource, true otherwise.
     */
    setSourceResource(resource) {
        if (!super.setSourceResource(resource, SCOPE)) {
            return false;
        }
        this._source?.addPropertyUpdateHandler('active-montage', this.resetProgress.bind(this), SCOPE);
        this._source?.addPropertyUpdateHandler('signal-cache-status', this.findNextRunRange.bind(this), SCOPE);
        return true;
    }
}
