// = CODE DESCRIPTION =
// Detect cells on (up to) 3 different channels (A,B and C) and finally on DAPI.
// Check if the centroid of the cell found in the channels (A,B and C)
// is contained in the region of interest (roi) found on the DAPI channel.
// Assign each cell to one of the possible categories (A, B, C, AB, AC, BC, ABC or N)
// Export a downsampled image of the annotation with the final classification (and the intermediate detections)
// Set annotationName to restrict the analysis to the annotations whose classification matches it (on all images of a project),
// or set annotationName="" to process all the annotations of an image.
// 
// == INPUTS ==
// QuPath project and one open image with annotation(s)
// Specify the channel names and the cell detection parameters
// 
// == OUTPUTS ==
// Detected cells (as detection object) are classified (A, B, C, AB, AC, BC, ABC or N)
// Output Image contains :
//     - the DAPI channel with the classified detections
//     - the other channels with their intermediate detection
// 
// = DEPENDENCIES =
// QuPath BIOP Extension for Exporting Results: https://github.com/BIOP/qupath-extension-biop
// 
// = INSTALLATION = 
// No installation. Drag & Drop the groovy script on QuPath and Run!
// 
// = AUTHOR INFORMATION =
// Code written by Romain Guiet, Olivier Burri, Nicolas Chiaruttini, EPFL - SV -PTECH - BIOP 
// DATE 2018.09.27
// Updated for QuPath 0.3.2: 2022.02.04
// 
// = COPYRIGHT =
// © All rights reserved. ECOLE POLYTECHNIQUE FEDERALE DE LAUSANNE, Switzerland, BioImaging and Optics Platform (BIOP), 2022
// 
// Licensed under GNU General Public License (GPL) version 3
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
// 
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
///

//////////////////////////////////////////////////////////////////////////////////////////////
//
//
// Initializations
//
//
//////////////////////////////////////////////////////////////////////////////////////////////
// QuPath context: project, current image, its object hierarchy and server path
def project     = getProject()
def imageData   = getCurrentImageData()
def hierarchy   = imageData.hierarchy
def server      = imageData.serverPath
def annotations = getAnnotationObjects()

// Make sure that the output directory for the exported images and results exists
def outputDir = buildFilePath(PROJECT_BASE_DIR, 'Output Images')
mkdirs(outputDir)

// Register the A/B/C/AB/AC/BC/ABC/N classifications, each with its own colour,
// used later to highlight simple, double or triple positive detections
makeCustomPathClasses()

// Ensure the ImageJ user interface is available and close images left open by a previous run
IJExtension.getImageJInstance()
IJ.run("Close All", "")

// Helper that exports channels with the annotations and detections drawn as overlays
def tip = new ToImagePlus()

// One ChDetector per channel: it holds the detection parameters and, later, the detections
def (chDAPI, chA, chB, chC) = [new ChDetector(), new ChDetector(), new ChDetector(), new ChDetector()]

//////////////////////////////////////////////////////////////////////////////////////////////
//
//
// PARAMETERS THAT YOU HAVE TO SET
//
//
//////////////////////////////////////////////////////////////////////////////////////////////

 
// Downsample factor of the exported overlay images
def output_downsample = 2

// Only annotations whose classification contains this text (Groovy regex find, `=~`) are analysed,
// e.g. "ORB"; set to "" so it processes all the annotations.
def annotationName = ""

// Channel name used by each detector
chDAPI.channelName = "DAPI"
chA.channelName    = "FITC"   // channel A
chB.channelName    = "CY3"    // channel B OR set to null to only do a channel A analysis
chC.channelName    = "CY5"    // channel C OR set to null to only do a channels A&B analysis

// Cell detection parameters of each channel.
// `paramsFor` returns a builder already pointed at the given channel; every parameter
// that is not set below keeps the default value defined in DetectorParamsBuilder.
def paramsFor = { channel -> DetectorParams.builder().detectionImageFluorescence(channel) }

// DAPI: preprocessing, threshold, then accepted nucleus area range
chDAPI.parameters = paramsFor(chDAPI.channelName)
        .requestedPixelSizeMicrons(1)
        .backgroundRadiusMicrons(9).medianRadiusMicrons(1).sigmaMicrons(3)
        .threshold(100)
        .minAreaMicrons(20).maxAreaMicrons(200)
        .build()

// Channel A
chA.parameters = paramsFor(chA.channelName)
        .requestedPixelSizeMicrons(1)
        .backgroundRadiusMicrons(10).medianRadiusMicrons(4).sigmaMicrons(5)
        .threshold(50)
        .minAreaMicrons(80).maxAreaMicrons(400)
        .build()

// Channel B
chB.parameters = paramsFor(chB.channelName)
        .requestedPixelSizeMicrons(1)
        .backgroundRadiusMicrons(8).medianRadiusMicrons(5).sigmaMicrons(5)
        .threshold(30)
        .minAreaMicrons(80).maxAreaMicrons(400)
        .build()

// Channel C
chC.parameters = paramsFor(chC.channelName)
        .requestedPixelSizeMicrons(1)
        .backgroundRadiusMicrons(9).medianRadiusMicrons(1).sigmaMicrons(3)
        .threshold(100)
        .minAreaMicrons(20).maxAreaMicrons(200)
        .build()


//////////////////////////////////////////////////////////////////////////////////////////////
//
//
// Script Start
//
//
//////////////////////////////////////////////////////////////////////////////////////////////

// chDAPI must be the LAST entry: when the DAPI detections are classified, the
// detections of channels A, B and C must already be stored in their ChDetector.
def channelToProcess = [chA, chB, chC, chDAPI]

def mergedClass = getPathClass("Merged")

// Remove any merged region left by a previous run before the analysis
def oldMerged = getAnnotationObjects().findAll{ it.getPathClass() == mergedClass }
removeObjects(oldMerged, false)

// `annotations` was collected before the stale "Merged" regions were removed above,
// so drop them here: otherwise they would be merged again and exported with the results.
def sourceAnnotations = annotations.findAll{ it.getPathClass() != mergedClass }

// Annotations to analyse. `=~` is a regex *find* on the classification (i.e. "contains"),
// `==~` would require an exact match. An empty annotationName selects every annotation.
def selectedAnnotations = sourceAnnotations.findAll{ it.getPathClass() =~ annotationName }

if (selectedAnnotations.isEmpty()) {
    println getProjectEntry().getImageName() + " : no annotation matches annotationName = '" + annotationName + "', nothing to do."
    return
}

// Force the threshold columns to exist in the exported table
selectedAnnotations.each{ annotation ->
    def measurements = annotation.getMeasurementList()
    measurements.putMeasurement( "Threshold ChDAPI", chDAPI.parameters.threshold )
    measurements.putMeasurement( "Threshold ChA", chA.parameters.threshold )
    measurements.putMeasurement( "Threshold ChB", chB.parameters.threshold )
    measurements.putMeasurement( "Threshold ChC", chC.parameters.threshold )
}

// Cell detection is done on all selected annotations as a single "Merged" annotation.
// This way, we only run the cell detection N times (N being the number of channels)
// instead of NxK times (K being the number of annotations).
def current_Annotation = PathUtils.merge( selectedAnnotations )
current_Annotation.setPathClass( mergedClass )
addObject( current_Annotation )
def current_Annotation_className = current_Annotation.getPathClass()

fireHierarchyUpdate()

// The cell detection plugin works on the selected object
setSelectedObject( current_Annotation )

// ROI of the merged region: only cells whose centroid lies inside it are classified
def roi_Annotation = current_Annotation.getROI()

// True when one of the detections of `detector` has its centroid inside `roi`.
// A detector skipped because its channel is null has no detections (cellsObjects stays null).
def hasDetectionInside = { detector, roi ->
    (detector.cellsObjects ?: []).any{ roi.contains( it.getROI().getCentroidX(), it.getROI().getCentroidY() ) }
}

// Run cell detection for each channel, as per user parameters
channelToProcess.each { chDetect ->
    // Start every detection from an empty "Merged" annotation.
    // A copy is passed because the child collection may be a live view of the children.
    // NOTE(review): if the source annotations become children of "Merged" when it is
    // added to the hierarchy, this would also delete them - confirm.
    if (current_Annotation.hasChildren())
        hierarchy.removeObjects( new ArrayList( current_Annotation.getChildObjects() ), false )

    // Detectors whose channel is set to null are not analysed
    if (chDetect.channelName == null) return

    chDetect.runDetection()

    // Store the detections of this channel in its ChDetector
    chDetect.cellsObjects = current_Annotation.getChildObjects().findAll{ it instanceof PathDetectionObject }

    // The DAPI pass comes last: classify each nucleus from the marker detections stored earlier
    if (chDetect.is(chDAPI)) {
        chDetect.cellsObjects.each { cell ->
            def roi_DAPI = cell.getROI()
            if (!roi_Annotation.contains(roi_DAPI.getCentroidX(), roi_DAPI.getCentroidY())) return

            // A marker is positive for this nucleus when one of its detections
            // has its centroid inside the DAPI roi
            def class_string = ""
            if (hasDetectionInside(chA, roi_DAPI)) class_string += "A"
            if (hasDetectionInside(chB, roi_DAPI)) class_string += "B"
            if (hasDetectionInside(chC, roi_DAPI)) class_string += "C"

            // No marker found: negative cell
            cell.setPathClass( getPathClass( class_string ?: "N" ) )
        }
    }

    // Export this channel alone, with the annotation ROI and the detections as overlay
    tip.setActiveChannels( [chDetect.channelName] )
    def channelImp = tip.getImagePlus( current_Annotation, output_downsample, true )
    chDetect.imp = channelImp?.flatten()
    if ( chDetect.imp != null ) chDetect.imp.show()
}

// Reset all analysed channels as active (channels set to null are skipped)
tip.setActiveChannels( channelToProcess.collect{ it.channelName }.findAll{ it != null } )
fireHierarchyUpdate()

// All the channels have been processed: make a stack of the exported images and save it
IJ.run( "Images to Stack", "name=Image title=Flat" )
def last = IJ.getImage()
last.show()
def imageName = getProjectEntry().getImageName() + "_" + current_Annotation_className + ".tiff"
def imagePath = buildFilePath( outputDir, imageName )
IJ.saveAs("Tiff", imagePath )
print "Image exported to " +    imagePath

// Save the Measurements
// Get the micrometers name
def um = GeneralTools.micrometerSymbol()

def measureColumnName = ["Name","Class","Parent","ROI","Centroid X "+um,"Centroid Y "+um,"Threshold ChDAPI","Threshold ChA","Threshold ChB","Threshold ChC","Num Annotations","Num Detections","Num A","Num AB","Num ABC","Num AC","Num B","Num BC","Num C","Num N","Area "+um+"^2","Perimeter "+um,"Max length "+um]
def measureName = getProjectEntry().getImageName()+ '.txt'
def measurePath = new File ( buildFilePath( outputDir, measureName ) )

// The method below creates a results table
// THE TABLE IS CREATED ONCE and THEN APPENDED
// Stale "Merged" annotations from a previous run are excluded (see sourceAnnotations)
Results.sendResultsToFile( measureColumnName, sourceAnnotations, measurePath )

println 'Results exported to '+ measurePath

println getProjectEntry().getImageName()+" : DONE ! "

// SCRIPT END
     


//////////////////////////////////////////////////////////////////////////////////////////////
//
//
// CUSTOM CLASSes
//
//
//////////////////////////////////////////////////////////////////////////////////////////////

// Create the classifications used for the final cell classes (single, double and triple
// positives, plus N for negatives), add them to QuPath's available classes when missing,
// and give each one a fixed colour.
def makeCustomPathClasses() {

    def available = getQuPath().getAvailablePathClasses()

    // class name -> [red, green, blue]; insertion order is the order of lookup and registration
    def palette = [
        'A'  : [  0, 255, 127],
        'B'  : [255, 127,   0],
        'C'  : [127,   0, 255],
        'AB' : [127, 255,   0],
        'AC' : [  0, 127, 255],
        'BC' : [255,   0, 127],
        'ABC': [255, 255, 255],
        'N'  : [ 50,  50,  50]
    ]

    // Get the PathClass of every name first
    def pathClasses = palette.collectEntries { name, rgb -> [(name): getPathClass(name)] }

    // Register the ones that are not available yet
    pathClasses.each { name, pathClass ->
        if ( !(pathClass in available) ) available.add(pathClass)
    }

    // Finally set the colours
    for (entry in palette) {
        def rgb = entry.value
        pathClasses[entry.key].setColor( getColorRGB(rgb[0], rgb[1], rgb[2]) )
    }
}

// Holds, for one channel, the cell-detection parameters and what the detection produced
class ChDetector {
    // name of the channel of interest (the main script skips detectors whose name is null)
    def channelName
    // parameters passed to 'qupath.imagej.detect.cells.WatershedCellDetection'
    DetectorParams parameters
    // detection objects found on this channel (filled in by the main script)
    def cellsObjects
    // ImagePlus of this channel with its detections as overlay (filled in by the main script)
    def imp

    // Runs QuPath's watershed cell detection with `parameters`; presumably on the currently
    // selected object(s) - the main script selects the "Merged" annotation before calling it.
    def runDetection( ) {
        def p = this.parameters

        // Plugin argument names -> values, in the order they are written to the JSON string.
        // The channel name is the only string value, hence the only quoted one.
        def settings = [
            detectionImage            : '"' + p.detectionImageFluorescence + '"',
            requestedPixelSizeMicrons : p.requestedPixelSizeMicrons,
            backgroundRadiusMicrons   : p.backgroundRadiusMicrons,
            medianRadiusMicrons       : p.medianRadiusMicrons,
            sigmaMicrons              : p.sigmaMicrons,
            minAreaMicrons            : p.minAreaMicrons,
            maxAreaMicrons            : p.maxAreaMicrons,
            threshold                 : p.threshold,
            watershedPostProcess      : p.watershedPostProcess,
            cellExpansionMicrons      : p.cellExpansionMicrons,
            includeNuclei             : p.includeNuclei,
            smoothBoundaries          : p.smoothBoundaries,
            makeMeasurements          : p.makeMeasurements
        ]

        def pairs = settings.collect { key, value -> '"' + key + '": ' + value }
        def detection_str = '{' + pairs.join(', ') + '}'

        // detect cells using defined parameters
        runPlugin('qupath.imagej.detect.cells.WatershedCellDetection', detection_str )
    }
}

// Plain holder for the watershed cell-detection settings of one channel.
// Instances are created with DetectorParams.builder(); the @Builder transform applied to
// DetectorParamsBuilder provides one chainable setter per property below, plus build().
class DetectorParams{
    // channel used as the detection image
    def detectionImageFluorescence

    // preprocessing
    def requestedPixelSizeMicrons
    def backgroundRadiusMicrons
    def medianRadiusMicrons
    def sigmaMicrons

    // size filter and threshold
    def minAreaMicrons
    def maxAreaMicrons
    def threshold

    // post-processing and outputs
    def watershedPostProcess
    def cellExpansionMicrons
    def includeNuclei
    def smoothBoundaries
    def makeMeasurements

    // Returns a fresh builder, pre-filled with the defaults set in the DetectorParamsBuilder constructor
    static DetectorParamsBuilder builder() {
        return new DetectorParamsBuilder()
    }
}


// Builder for DetectorParams. Groovy's @Builder transform (ExternalStrategy, forClass=DetectorParams)
// generates the chainable setters and build() used in the parameter section of the script.
// The constructor pre-fills the builder with default values, so any setter the user does not
// call keeps the value below.
// NOTE(review): the assigned names are not declared in this class; they are expected to be the
// fields generated by the @Builder transform - confirm before restructuring this class.
@Builder(builderStrategy=groovy.transform.builder.ExternalStrategy, forClass=DetectorParams)
class DetectorParamsBuilder {
    DetectorParamsBuilder(){
         // Default detection settings (values ending in "Microns" are expressed in microns)
         detectionImageFluorescence = ""
         requestedPixelSizeMicrons  = 1
         backgroundRadiusMicrons    = 10.0
         medianRadiusMicrons        = 2.5
         sigmaMicrons               = 3
         minAreaMicrons             = 20.0
         maxAreaMicrons             = 400.0
         threshold                  = 100
         watershedPostProcess       = true
         cellExpansionMicrons       = 1
         includeNuclei              = true
         smoothBoundaries           = true
         makeMeasurements           = true
    }
}

// This class handles the conversion from a region to an ImagePlus with the chosen channels.
// Only the display range of each channel is transferred to ImageJ (no LUT info).
class ToImagePlus {
    def imageData
    def server
    def annot       // NOTE(review): not used inside this class
    def hierarchy
    def viewer
    def channels
    def display

    // Gets some constants that we will need from the current image and viewer
    public ToImagePlus() {
        imageData = getCurrentImageData()
        server    = imageData.getServer()

        viewer    = getCurrentViewer()
        display   = viewer.getImageDisplay()
        channels  = display.availableChannels()
    }

    // Sets the channels active in QuPath, which dictates the ones being exported to IJ.
    // A channel is selected when its name contains one of the given names; null entries
    // (detectors whose channel is disabled) are ignored.
    public void setActiveChannels(activeChannels) {
        def wanted = activeChannels.findAll { it != null }.collect { it.toString() }

        this.channels.each { ch ->
            // Literal substring test rather than a regex built from the name, so names with
            // regex metacharacters are matched literally. `any` replaces `.size > 0`, which
            // is property access (not size()) and is resolved ambiguously on Groovy lists.
            def selected = wanted.any { name -> ch.getName().contains(name) }
            this.display.setChannelSelected(ch, selected)
        }
    }

    // Exports the given annotation and adds the detections as the overlay.
    // annot: object whose ROI defines the exported region; downsample: export downsample factor;
    // getAnnotationRoi: when true, the annotation ROI is set on the image and added to its overlay.
    public ImagePlus getImagePlus(annot, downsample, getAnnotationRoi) {
        def selectedChannels = viewer.getImageDisplay().selectedChannels() as List

        // Server restricted to the channels currently selected in the viewer
        def server_local = ChannelDisplayTransformServer.createColorTransformServer(server, selectedChannels)

        hierarchy = getCurrentHierarchy()
        def request = RegionRequest.createInstance(imageData.getServerPath(), downsample, annot.getROI())
        def pathImage = IJExtension.extractROIWithOverlay(server_local, annot, hierarchy, request, getAnnotationRoi, viewer.getOverlayOptions())
        def imp = pathImage.getImage()

        if ( getAnnotationRoi ) {
            def roi = imp.getRoi()
            // Guard against adding a null ROI to the overlay
            if ( roi != null ) {
                if ( imp.getOverlay() == null ) {
                    imp.setOverlay( new Overlay(roi) )
                } else {
                    imp.getOverlay().add( roi )
                }
            }
        }

        // Set to the visible range in QuPath for each channel. This does not depend on the ROI
        // flag, so it is applied in every case.
        selectedChannels.eachWithIndex { ch, idx ->
            imp.setC( idx + 1 )
            imp.setDisplayRange( ch.getMinDisplay(), ch.getMaxDisplay() )
        }

        return imp
    }
}


// Imports
import ij.*
import qupath.imagej.gui.IJExtension
import qupath.lib.regions.RegionRequest
import qupath.lib.objects.*

import ij.gui.Overlay

// require for builder in ChDetector
import groovy.transform.builder.Builder
import qupath.ext.biop.utils.*

// Now requires BIOP QuPath 0.1.4 (feature forces column of output)
import ch.epfl.biop.qupath.utils.*
import qupath.lib.gui.images.servers.ChannelDisplayTransformServer;

import java.time.*
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;