/*
 * Decompiled with CFR 0.152.
 */
package org.nuxeo.ai.bulk;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.nuxeo.ai.bulk.ExportBulkProcessed;
import org.nuxeo.ai.bulk.RecordWriter;
import org.nuxeo.ai.model.export.DatasetExportService;
import org.nuxeo.ai.services.AIComponent;
import org.nuxeo.ecm.core.api.Blob;
import org.nuxeo.ecm.core.api.CloseableCoreSession;
import org.nuxeo.ecm.core.api.CoreInstance;
import org.nuxeo.ecm.core.api.CoreSession;
import org.nuxeo.ecm.core.api.DocumentModel;
import org.nuxeo.ecm.core.api.NuxeoException;
import org.nuxeo.ecm.core.bulk.BulkService;
import org.nuxeo.ecm.core.bulk.action.computation.AbstractBulkComputation;
import org.nuxeo.ecm.core.bulk.message.BulkCommand;
import org.nuxeo.ecm.core.bulk.message.BulkStatus;
import org.nuxeo.lib.stream.codec.Codec;
import org.nuxeo.lib.stream.computation.AbstractComputation;
import org.nuxeo.lib.stream.computation.ComputationContext;
import org.nuxeo.lib.stream.computation.Record;
import org.nuxeo.runtime.api.Framework;
import org.nuxeo.runtime.codec.CodecService;
import org.nuxeo.runtime.transaction.TransactionHelper;

public class DataSetExportStatusComputation
extends AbstractComputation {
    private static final Log log = LogFactory.getLog(DataSetExportStatusComputation.class);
    protected final Set<String> writerNames;
    protected Map<String, Long> counters = new HashMap<String, Long>();

    public DataSetExportStatusComputation(String name, Set<String> writerNames) {
        super(name, 1, 1);
        this.writerNames = writerNames;
    }

    public static Codec<ExportBulkProcessed> getExportStatusCodec() {
        return ((CodecService)Framework.getService(CodecService.class)).getCodec("avro", ExportBulkProcessed.class);
    }

    public static void updateExportStatusProcessed(ComputationContext context, String commandId, long processed) {
        ExportBulkProcessed exportStatus = new ExportBulkProcessed(commandId, processed);
        context.produceRecord("o1", commandId, DataSetExportStatusComputation.getExportStatusCodec().encode((Object)exportStatus));
    }

    public static boolean isTraining(String name) {
        return "training".equals(name);
    }

    public void processRecord(ComputationContext context, String inputStreamName, Record record) {
        ExportBulkProcessed exportStatus = (ExportBulkProcessed)DataSetExportStatusComputation.getExportStatusCodec().decode(record.getData());
        BulkService service = (BulkService)Framework.getService(BulkService.class);
        if (this.isEndOfBatch(exportStatus)) {
            BulkCommand command = service.getCommand(exportStatus.getCommandId());
            for (String name : this.writerNames) {
                RecordWriter writer = ((AIComponent)Framework.getService(AIComponent.class)).getRecordWriter(name);
                if (writer == null) {
                    throw new NuxeoException(String.format("Unable to find record writer: %s", name));
                }
                if (!writer.exists(exportStatus.getCommandId())) continue;
                try {
                    Optional blob = writer.complete(exportStatus.getCommandId());
                    blob.ifPresent(theBlob -> {
                        if (command != null) {
                            this.updateCorpusDocument(exportStatus, command, (Blob)theBlob, DataSetExportStatusComputation.isTraining(name));
                        } else {
                            log.warn((Object)String.format("The bulk command with id %s is missing.  Unable to save blob info for %s %s.", exportStatus.getCommandId(), name, theBlob.getDigest()));
                        }
                    });
                }
                catch (IOException e) {
                    throw new NuxeoException(String.format("Unable to complete action %s", exportStatus.getCommandId()), (Throwable)e);
                }
            }
            this.counters.remove(exportStatus.getCommandId());
        }
        this.updateDelta(exportStatus.getCommandId(), exportStatus.getProcessed());
        AbstractBulkComputation.updateStatusProcessed((ComputationContext)context, (String)exportStatus.getCommandId(), (long)exportStatus.getProcessed());
        context.askForCheckpoint();
    }

    protected void updateCorpusDocument(ExportBulkProcessed exportStatus, BulkCommand command, Blob theBlob, boolean isTraining) {
        TransactionHelper.runInTransaction(() -> {
            try (CloseableCoreSession session = CoreInstance.openCoreSession((String)command.getRepository(), (String)command.getUsername());){
                DocumentModel document = ((DatasetExportService)Framework.getService(DatasetExportService.class)).getCorpusDocument((CoreSession)session, command.getId());
                if (document != null) {
                    document.setPropertyValue("ai_corpus:documents_count", (Serializable)Long.valueOf(exportStatus.getProcessed() + this.getCount(exportStatus.getCommandId())));
                    document.setPropertyValue(isTraining ? "ai_corpus:training_data" : "ai_corpus:evaluation_data", (Serializable)theBlob);
                    session.saveDocument(document);
                } else {
                    log.warn((Object)String.format("Unable to save blob %s for command id %s.", theBlob.getDigest(), exportStatus.getCommandId()));
                }
            }
        });
    }

    protected Long getCount(String commandId) {
        return this.counters.get(commandId);
    }

    protected void updateDelta(String commandId, long processed) {
        this.counters.computeIfPresent(commandId, (s, aLong) -> processed + aLong);
    }

    protected boolean isEndOfBatch(ExportBulkProcessed exportStatus) {
        BulkStatus status = ((BulkService)Framework.getService(BulkService.class)).getStatus(exportStatus.getCommandId());
        Long processed = this.getCount(exportStatus.getCommandId());
        if (processed == null) {
            processed = 0L;
            this.counters.put(exportStatus.getCommandId(), processed);
        }
        return processed + exportStatus.getProcessed() >= status.getTotal();
    }
}

