library("DESeq2")
library("EnhancedVolcano")
library("ggplot2")
library("RColorBrewer")
library("plyr")
library("dplyr")
library("splitstackshape")

# Folder (relative to the working directory) that receives every SVG figure
figureDir <- "./figures/"

# Make sure the folder exists before any svg() device tries to write into it
if (!dir.exists(figureDir)) dir.create(figureDir)

# Sample sheet ----
# Every file in the working directory whose name contains "counts" is read as
# one HTSeq count file (one sample per file). The condition label is the file
# name truncated right after its last occurrence of "counts".
directory <- getwd()
sampleFiles <- list.files(directory, pattern = "counts")
sampleCondition <- sub("(.*counts).*", "\\1", sampleFiles)
sampleTable <- data.frame(sampleName = sampleFiles,
                          fileName = sampleFiles,
                          condition = sampleCondition)

# Differential expression ----
# Import the counts with a single-factor design on `condition`, fit the DESeq2
# model and take results() with its default arguments.
ddsHTSeq <- DESeqDataSetFromHTSeqCount(sampleTable = sampleTable,
                                       directory = directory,
                                       design = ~condition)
dds <- DESeq(ddsHTSeq)
res <- results(dds)

# rlog-transformed counts, used as input for the PCA plot
rld <- rlogTransformation(dds)

# Output paths for each figure, all placed directly inside figureDir
pcaPlotFile <- sprintf("%s%s", figureDir, "pcaplot.svg")
volcanoPlotFile <- sprintf("%s%s", figureDir, "volcanoplot.svg")
upLowestCatFile <- sprintf("%s%s", figureDir, "overexpressed_lowest_category.svg")
upHighestCatFile <- sprintf("%s%s", figureDir, "overexpressed_highest_category.svg")
downLowestCatFile <- sprintf("%s%s", figureDir, "repressed_lowest_category.svg")
downHighestCatFile <- sprintf("%s%s", figureDir, "repressed_highest_category.svg")

# SVG helper ----
# Render `plot` into an SVG file at `file`. Extra arguments (e.g. width,
# height) are forwarded to svg(). The plot is print()ed explicitly because
# ggplot objects are only drawn by top-level autoprinting, which does not
# happen when this script is source()d (that left the SVG files empty). The
# device is closed via on.exit() so a failing plot cannot leave it open and
# swallow later figures. Returns `plot` invisibly.
saveSvgPlot <- function(plot, file, ...) {
  svg(file, ...)
  on.exit(dev.off(), add = TRUE)
  print(plot)
  invisible(plot)
}

# PCA plot of the rlog-transformed counts, coloured by condition
saveSvgPlot(plotPCA(rld, intgroup = c("condition")), pcaPlotFile)

# Volcano plot: log2 fold change vs adjusted p-value, without point labels
saveSvgPlot(
  EnhancedVolcano(res, lab = NA, x = "log2FoldChange", y = "padj",
                  xlim = c(-5, 8)),
  volcanoPlotFile
)

# Pie charts ----
# Shared styling for all pie charts: no axis titles, ticks, grid or border
blank_theme <- theme_minimal() +
  theme(
    axis.title.x = element_blank(),
    axis.title.y = element_blank(),
    panel.border = element_blank(),
    panel.grid = element_blank(),
    axis.ticks = element_blank(),
    plot.title = element_text(size = 14, face = "bold")
  )

# Colour generator that interpolates the 12-colour "Paired" brewer palette,
# so any number of categories gets its own colour
getPalette <- colorRampPalette(brewer.pal(12, "Paired"))

# Draw a pie chart of `df` with one slice per distinct value of the column
# named `category` (slice size = number of rows with that value) and save it
# as a 20 x 10 inch SVG at `file`. Returns the plot invisibly.
savePieChart <- function(df, category, file) {
  nColours <- length(unique(df[[category]]))
  chart <- ggplot(df, aes(x = "", fill = .data[[category]])) +
    geom_bar(width = 1) +
    coord_polar("y") +
    scale_fill_manual(values = getPalette(nColours)) +
    # keep the legend title equal to the bare column name
    labs(fill = category) +
    blank_theme +
    theme(axis.text.x = element_blank())
  saveSvgPlot(chart, file, width = 20, height = 10)
}

# NOTE(review): sigupkegg / sigdownkegg are not defined in this script;
# presumably they come from an earlier KEGG annotation step. TODO confirm.

# Upregulated (over-expressed) genes
piechartdata_up <- as.data.frame(sigupkegg)
savePieChart(piechartdata_up, "lowest_category", upLowestCatFile)
savePieChart(piechartdata_up, "highest_category", upHighestCatFile)

# Downregulated (repressed) genes
piechartdata_down <- as.data.frame(sigdownkegg)
savePieChart(piechartdata_down, "lowest_category", downLowestCatFile)
savePieChart(piechartdata_down, "highest_category", downHighestCatFile)

# Each pie chart is rendered exactly once, in the pie-chart section above.

# Save the KEGG category summaries to CSV files under ./kegg_analysis/.
# write.csv() cannot create missing directories, so make sure it exists first.
# NOTE(review): sorted_sum_* are not defined in this script; presumably they
# are built by an earlier KEGG summary step. TODO confirm.
keggDir <- "./kegg_analysis/"
if (!dir.exists(keggDir)) {
  dir.create(keggDir, recursive = TRUE)
}
write.csv(sorted_sum_up_low,
          file = paste0(keggDir, "overexpressed_kegg_summary_catlow.csv"))
write.csv(sorted_sum_up_high,
          file = paste0(keggDir, "overexpressed_kegg_summary_cathigh.csv"))
write.csv(sorted_sum_down_low,
          file = paste0(keggDir, "repressed_kegg_summary_catlow.csv"))
write.csv(sorted_sum_down_high,
          file = paste0(keggDir, "repressed_kegg_summary_cathigh.csv"))

