diff --git a/README.md b/README.md
index 087071a..810deec 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,12 @@ Install dependencies:
pip install -r requirements.txt
```
+Install pyphylon package:
+
+```bash
+pip install -e .
+```
+
## Usage
Provide a quick start example of how to use the tool:
diff --git a/examples/6a_phylon_location_plotting.ipynb b/examples/6a_phylon_location_plotting.ipynb
new file mode 100644
index 0000000..42abb66
--- /dev/null
+++ b/examples/6a_phylon_location_plotting.ipynb
@@ -0,0 +1,63 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pyphylon.plotting_util\n",
+ "import pyphylon.plotting"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%matplotlib inline\n",
+ "%config InlineBackend.figure_format = 'svg'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "PATH_TO_DATA = './data/'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "REFERENCE_STRAIN = ''"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pyphylon",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/pyphylon/plotting.py b/pyphylon/plotting.py
index e69de29..aa0b035 100644
--- a/pyphylon/plotting.py
+++ b/pyphylon/plotting.py
@@ -0,0 +1,541 @@
+import logging
+import re
+import urllib
+from io import StringIO
+import pandas as pd
+import numpy as np
+import matplotlib
+import matplotlib.pyplot as plt
+import gzip
+import pickle
+from tqdm.notebook import tqdm, trange
+import multiprocessing
+from IPython.display import display, HTML
+import itertools
+
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+
+
+anchor_gene_color = '#6495ED'
+background_gene_color = '#F0F0F0'
+phylon_gene_location = '#FFC96F'
+Normal_Order_color = '#C6EBC5'
+Inversion_Color = '#FA7070'
+unique_gene_color = '#FF0000'
+
+
def plot_gene_sets_with_common_subset_proportional(gene_lists, common_subset, strain_names, X, Y):
    """
    Draw one bar-strip subplot per gene list, highlighting genes in a common subset.

    Every gene is drawn as a unit-height bar. Bars are spread evenly so that
    each list spans the same horizontal range as the longest list. Genes
    found in ``common_subset`` are black, all others light grey. Two red
    vertical lines labelled with ``X`` and ``Y`` mark the left and right
    boundaries. The figure is displayed with ``fig.show()``.

    Parameters:
    gene_lists (list): A list of gene lists, one per subplot row.
    common_subset (collection): Genes to highlight.
    strain_names (list): Subplot titles, one per gene list.
    X: Label of the anchor gene shown at the left boundary.
    Y: Label of the anchor gene shown at the right boundary.

    Raises:
    ValueError: If ``gene_lists`` is empty.
    """
    num_plots = len(gene_lists)
    if num_plots == 0:
        raise ValueError("gene_lists must contain at least one gene list")

    # One subplot row per gene list
    fig = make_subplots(rows=num_plots, cols=1, subplot_titles=strain_names)

    # Longest list defines the shared x-range and the right boundary line
    max_length = max(len(gene_list) for gene_list in gene_lists)

    # Membership is tested once per gene, so use a set when that is safe
    highlighted = common_subset
    if isinstance(common_subset, (list, tuple)):
        try:
            highlighted = set(common_subset)
        except TypeError:
            # Unhashable entries: fall back to the original linear lookup
            highlighted = common_subset

    # Zero-height bars whose only purpose is to create the legend entries
    fig.add_trace(go.Bar(
        x=[0], y=[0], marker_color='black', name='Phylon', showlegend=True
    ))
    fig.add_trace(go.Bar(
        x=[0], y=[0], marker_color='lightgrey', name='Other Genes', showlegend=True
    ))

    for row, gene_list in enumerate(gene_lists, start=1):
        # Stretch shorter lists so that every row spans 0 .. max_length - 1
        denominator = max(len(gene_list) - 1, 1)
        x_values = [i * (max_length - 1) / denominator for i in range(len(gene_list))]
        colors = ['black' if gene in highlighted else 'lightgrey' for gene in gene_list]

        fig.add_trace(go.Bar(
            x=x_values,
            y=[1] * len(gene_list),
            marker_color=colors,
            width=0.9,  # Slightly narrower than 1 so bars do not touch the red lines
            showlegend=False
        ), row=row, col=1)

    # Red boundary lines span the full paper height (y from 0 to 1) and sit
    # just outside the range of x-values used for the gene bars
    for x_position in (-0.5, max_length - 0.5):
        fig.add_shape(type="line",
                      x0=x_position, y0=0, x1=x_position, y1=1,
                      line=dict(color="red", width=6),
                      xref="x", yref="paper")

    # Anchor labels just above the top of the plotting area
    for x_paper, anchor in ((0, X), (1, Y)):
        fig.add_annotation(x=x_paper, y=1.05, xref="paper", yref="paper",
                           font=dict(size=16, color="red"),
                           text=f"Anchor Gene {anchor}", showarrow=False)

    fig.update_layout(
        title='',
        height=300 * num_plots,  # Scale height with the number of rows
        bargap=0  # Remove any gap between bars
    )
    # Hide grid, zero line and tick labels on every subplot, not only the first
    fig.update_xaxes(showgrid=False, zeroline=False, showticklabels=False)
    fig.update_yaxes(showgrid=False, zeroline=False, showticklabels=False)

    fig.show()
+
+
+import plotly.graph_objects as go
+
def _as_lookup(genes):
    """Return a set for list/tuple input so membership tests are O(1); other inputs pass through."""
    if isinstance(genes, (list, tuple)):
        try:
            return set(genes)
        except TypeError:
            # Unhashable entries: keep the original container
            return genes
    return genes


def _gene_hover_text(gene, df_eggnog):
    """Build the hover label for one gene from its eggNOG annotation row, or 'N/A' fields if absent."""
    fields = ('COG_category', 'Preferred_name', 'PFAMs', 'BiGG_Reaction')
    if gene in df_eggnog.index:
        gene_info = df_eggnog.loc[gene]
        values = [gene_info[field] for field in fields]
    else:
        values = ['N/A'] * len(fields)
    # NOTE(review): the separator was lost in the visible source (the f-string
    # was split across raw lines); '<br>' is Plotly's hover line break and is
    # assumed to be the original markup - confirm.
    lines = [f"{gene}"] + [f"{field}: {value}" for field, value in zip(fields, values)]
    return "<br>".join(lines)


def _gene_ring_trace(mask, angles, color, name, hover_text, show_legend, inner_radius, bar_height):
    """Create one Barpolar layer of the gene ring; genes where ``mask`` is false get zero height."""
    num_genes = len(angles)
    bar_width = 360 / num_genes if num_genes else 0
    return go.Barpolar(
        r=[bar_height if flagged else 0 for flagged in mask],
        theta=angles,
        width=[bar_width] * num_genes,
        base=[inner_radius if flagged else 0 for flagged in mask],
        marker_color=color,
        marker_line_color=color,
        opacity=0.7,
        name=name,
        text=hover_text,
        hoverinfo='text',
        showlegend=show_legend
    )


def _anchor_segment_trace(genes, angles, show_legend, outer_base_radius, bar_height):
    """Create the outer ring: one segment per pair of consecutive anchor (integer) genes, or None."""
    num_genes = len(genes)

    # First position of every anchor, matching what list.index() returned before
    first_position = {}
    for position, gene in enumerate(genes):
        if isinstance(gene, int):
            first_position.setdefault(gene, position)

    int_genes = [gene for gene in genes if isinstance(gene, int)]
    num_int_genes = len(int_genes)
    if num_int_genes == 0:
        return None

    segment_base_angles = []
    segment_widths = []
    outer_colors = []
    outer_hover_text = []

    for i, current_gene in enumerate(int_genes):
        next_gene = int_genes[(i + 1) % num_int_genes]

        current_gene_index = first_position[current_gene]
        next_gene_index = first_position[next_gene]
        if next_gene_index < current_gene_index:
            next_gene_index += num_genes  # handle wrapping around the list

        # Angular width of the segment between the two anchors
        segment_width = (next_gene_index - current_gene_index) * (360 / num_genes)
        segment_widths.append(segment_width)

        # Barpolar's theta is the bar centre; angles decrease with list position
        segment_base_angles.append(angles[current_gene_index] - segment_width / 2)

        # Increasing anchor numbers are "normal"; the wrap-around from the last
        # anchor back to the first is allowed to decrease
        is_wrap_around = i == num_int_genes - 1
        if current_gene < next_gene or (is_wrap_around and current_gene > next_gene):
            outer_colors.append(Normal_Order_color)
        else:
            outer_colors.append(Inversion_Color)

        outer_hover_text.append(f'{current_gene}-{next_gene}')

    return go.Barpolar(
        r=[bar_height] * num_int_genes,
        theta=segment_base_angles,
        width=segment_widths,
        base=[outer_base_radius] * num_int_genes,
        marker_color=outer_colors,
        marker_line_color=outer_colors,
        opacity=0.7,
        name='Outer Ring',
        text=outer_hover_text,
        hoverinfo='text',
        showlegend=show_legend
    )


def _build_circular_genome_figure(list1, list2, list3, title, df_eggnog, show_legend):
    """Build the circular genome figure shared by both public variants (``list3=None`` omits unique genes)."""
    genes = list(list1)

    # Rotate so that anchor 1 is the first element (drawn at the top).
    # Anchors are assumed to be unique, as the index-based lookups require.
    if 1 in genes:
        start = genes.index(1)
        genes = genes[start:] + genes[:start]

    num_genes = len(genes)

    # Gene i sits i steps clockwise from the top centre (90 degrees)
    angles = [(-i * 360 / num_genes + 90) % 360 for i in range(num_genes)]

    # Radial geometry: gene ring from 0.8 to 1.0, outer ring after a 0.1 gap
    inner_radius = 0.8
    outer_radius = 1.0
    outer_gap = 0.1
    outer_base_radius = outer_radius + outer_gap
    bar_height = outer_radius - inner_radius

    hover_text = [_gene_hover_text(gene, df_eggnog) for gene in genes]
    is_anchor = [isinstance(gene, int) for gene in genes]

    def non_anchor_members(collection):
        lookup = _as_lookup(collection)
        return [gene in lookup and not anchor for gene, anchor in zip(genes, is_anchor)]

    # Layers are drawn in this order; later layers cover earlier ones
    layers = [
        ([True] * num_genes, background_gene_color, 'Genes'),
        (non_anchor_members(list2), phylon_gene_location, f'{title} Phylon genes'),
        (is_anchor, anchor_gene_color, 'Anchor genes'),
    ]
    if list3 is not None:
        layers.append((non_anchor_members(list3), unique_gene_color, f'Unique {title} Phylon genes'))

    fig = go.Figure()
    for mask, color, name in layers:
        fig.add_trace(_gene_ring_trace(mask, angles, color, name, hover_text,
                                       show_legend, inner_radius, bar_height))

    outer_ring = _anchor_segment_trace(genes, angles, show_legend, outer_base_radius, bar_height)
    if outer_ring is not None:
        fig.add_trace(outer_ring)

    # Hide the axes; the radial range starts at 0 so the centre stays white
    fig.update_layout(
        polar=dict(
            radialaxis=dict(visible=False, range=[0, outer_base_radius + bar_height]),
            angularaxis=dict(visible=False),
            bgcolor="white"
        ),
        showlegend=show_legend,
        paper_bgcolor='white',
        plot_bgcolor='white'
    )

    return fig


def _combine_circular_genome_figures(figures, titles, phylon, legend_entries, figsize, save_path, show):
    """Place circular genome figures side by side with one shared, manually built legend."""
    fig = make_subplots(rows=1, cols=len(figures), subplot_titles=titles,
                        specs=[[{'type': 'polar'}] * len(figures)])

    for column, figure in enumerate(figures, start=1):
        for trace in figure['data']:
            fig.add_trace(trace, row=1, col=column)

    # Hide the legend of the copied traces on the combined figure only, so the
    # caller's input figures are not modified
    fig.update_traces(showlegend=False)

    # Empty marker traces that exist only to produce legend entries
    for name, color, size in legend_entries:
        fig.add_trace(go.Scatterpolar(
            r=[None], theta=[None], mode='markers',
            marker=dict(color=color, symbol='square', size=size),
            name=name, showlegend=True
        ))

    fig.update_layout(
        title=f'{phylon} Phylon Location',
        showlegend=True,
        legend=dict(x=1.1, y=0.65),  # Position the legend to the right of the last subplot
        width=figsize[0],
        height=figsize[1],
        paper_bgcolor='white',
        font=dict(
            size=14
        ),
        plot_bgcolor='white'
    )

    # Without a selector this applies to every polar subplot at once
    fig.update_polars(
        radialaxis=dict(visible=False),
        angularaxis=dict(visible=False),
        bgcolor="white"
    )

    if save_path:
        fig.write_html(save_path)
    if show:
        fig.show()


def plot_circular_genome_combined_with_eggnog(list1, list2, title, strain, df_eggnog, show_legend=False):
    """
    Build a circular genome plot with phylon genes, anchor genes and an anchor-order ring.

    Parameters:
    list1 (list): Genes in genome order. Integer entries are treated as anchor
        genes; if the integer 1 is present the list is rotated to start there.
    list2 (collection): Phylon genes to highlight (non-integer entries of list1).
    title (str): Phylon name used in the legend label.
    strain: Not used by this function; kept for interface compatibility.
    df_eggnog (DataFrame): Annotation table indexed by gene, read for the
        COG_category, Preferred_name, PFAMs and BiGG_Reaction columns.
    show_legend (bool): Whether traces appear in the legend (default: False).

    Returns:
    plotly.graph_objects.Figure: The circular genome figure.
    """
    return _build_circular_genome_figure(list1, list2, None, title, df_eggnog, show_legend)


def plot_combined_circular_genomes_with_variaton(figures, titles, phylon, figsize=(1500, 500), save_path=None, dpi=300, show = False):
    """
    Combine circular genome figures into one row of polar subplots with a shared legend.

    Parameters:
    figures (list): Figures whose traces are copied, one subplot column each.
    titles (list): Subplot titles, one per figure.
    phylon (str): Phylon name used in the figure title and legend.
    figsize (tuple): Figure (width, height) in pixels (default: (1500, 500)).
    save_path (str, optional): If given, the figure is written there as HTML.
    dpi (int): Not used; kept for interface compatibility.
    show (bool): If true, display the figure (default: False).
    """
    legend_entries = [
        ('Genes', background_gene_color, 12),
        (f'{phylon} Phylon Genes', phylon_gene_location, 12),
        ('Anchor Genes', anchor_gene_color, 12),
        ('', 'black', 0),  # Invisible spacer between the two legend groups
        ('No variation', Normal_Order_color, 12),
        ('Inversion', Inversion_Color, 12),
    ]
    _combine_circular_genome_figures(figures, titles, phylon, legend_entries, figsize, save_path, show)


def plot_circular_genome_combined_with_eggnog_and_unique_genes(list1, list2, list3, title, strain, df_eggnog, show_legend=False):
    """
    Build a circular genome plot that additionally highlights unique phylon genes.

    Parameters:
    list1 (list): Genes in genome order. Integer entries are treated as anchor
        genes; if the integer 1 is present the list is rotated to start there.
    list2 (collection): Phylon genes to highlight (non-integer entries of list1).
    list3 (collection): Unique phylon genes, drawn on top of the other gene layers.
    title (str): Phylon name used in the legend labels.
    strain: Not used by this function; kept for interface compatibility.
    df_eggnog (DataFrame): Annotation table indexed by gene, read for the
        COG_category, Preferred_name, PFAMs and BiGG_Reaction columns.
    show_legend (bool): Whether traces appear in the legend (default: False).

    Returns:
    plotly.graph_objects.Figure: The circular genome figure.
    """
    return _build_circular_genome_figure(list1, list2, list3, title, df_eggnog, show_legend)


def plot_combined_circular_genomes_with_variaton_and_unique_genes(figures, titles, phylon, figsize=(1500, 500), save_path=None, dpi=300, show = False):
    """
    Combine circular genome figures (with unique genes) into one row with a shared legend.

    Parameters:
    figures (list): Figures whose traces are copied, one subplot column each.
    titles (list): Subplot titles, one per figure.
    phylon (str): Phylon name used in the figure title and legend.
    figsize (tuple): Figure (width, height) in pixels (default: (1500, 500)).
    save_path (str, optional): If given, the figure is written there as HTML.
    dpi (int): Not used; kept for interface compatibility.
    show (bool): If true, display the figure (default: False).
    """
    legend_entries = [
        (f'{phylon} Phylon Genes', phylon_gene_location, 12),
        (f'Unique {phylon} Phylon Genes', unique_gene_color, 12),
        ('Anchor Genes', anchor_gene_color, 12),
        ('Genes', background_gene_color, 12),
        ('', 'black', 0),  # Invisible spacer between the two legend groups
        ('No variation', Normal_Order_color, 12),
        ('Inversion', Inversion_Color, 12),
    ]
    _combine_circular_genome_figures(figures, titles, phylon, legend_entries, figsize, save_path, show)
diff --git a/pyphylon/plotting_util.py b/pyphylon/plotting_util.py
index e69de29..b09e836 100644
--- a/pyphylon/plotting_util.py
+++ b/pyphylon/plotting_util.py
@@ -0,0 +1,813 @@
+import logging
+import re
+from io import StringIO
+import pandas as pd
+import numpy as np
+import matplotlib
+import matplotlib.pyplot as plt
+import gzip
+import pickle
+from tqdm.notebook import tqdm, trange
+from IPython.display import display, HTML
+import plotly.graph_objects as go
+
+
+
+def _get_attr(attributes, attr_id, ignore=False):
+ """
+ Helper function for parsing GFF annotations
+
+ Parameters
+ ----------
+ attributes : str
+ Attribute string
+ attr_id : str
+ Attribute ID
+ ignore : bool
+ If true, ignore errors if ID is not in attributes (default: False)
+
+ Returns
+ -------
+ str, optional
+ Value of attribute
+ """
+
+ try:
+ return re.search(attr_id + "=(.*?)(;|$)", attributes).group(1)
+ except AttributeError:
+ if ignore:
+ return None
+ else:
+ raise ValueError("{} not in attributes: {}".format(attr_id, attributes))
+
+# TODO: update to handle plasmid and chromosome records separately.
+
def gff2pandas(gff_file, feature=("CDS",), index=None):
    """
    Converts GFF file(s) to a Pandas DataFrame

    Comment/pragma lines (starting with ``#``) are skipped and parsing stops
    at an embedded ``##FASTA`` section, so sequence data is never read as
    feature rows.

    Parameters
    ----------
    gff_file : str or list
        Path(s) to GFF file
    feature: str or list
        Name(s) of features to keep (default = "CDS")
    index : str, optional
        Column or attribute to use as index

    Returns
    -------
    df_gff: ~pandas.DataFrame
        The ``start``, ``end`` and ``locus_tag`` columns of the kept features
    region_len: int
        ``end`` of the first ``region`` feature of the last file read
    oric: int
        Always 0; origin-of-replication detection is not implemented

    Raises
    ------
    IndexError
        If a file contains no ``region`` feature
    """

    # Argument checking
    if isinstance(gff_file, str):
        gff_file = [gff_file]

    feature = [feature] if isinstance(feature, str) else list(feature)

    names = [
        "accession",
        "source",
        "feature",
        "start",
        "end",
        "score",
        "strand",
        "phase",
        "attributes",
    ]

    result = []

    for gff in gff_file:
        # Keep only feature rows: drop comment/pragma lines wherever they
        # occur and stop at the embedded FASTA section
        records = []
        with open(gff, "r") as f:
            for line in f:
                if line.startswith("##FASTA"):
                    break
                if line.startswith("#") or not line.strip():
                    continue
                records.append(line)

        DF_gff = pd.read_csv(
            StringIO("".join(records)), sep="\t", names=names, header=None, low_memory=False
        )

        region = DF_gff[DF_gff.feature == 'region']
        if region.empty:
            raise IndexError("No 'region' feature found in {}".format(gff))
        # NOTE: only the first region is used; multi-replicon files
        # (chromosome plus plasmids) are not separated here
        region_len = int(region.iloc[0].end)

        oric = 0

        # Filter for the requested features and sort by start position
        DF_cds = DF_gff[DF_gff.feature.isin(feature)].sort_values("start")

        # Extract attribute information (assign returns a new frame, which
        # avoids writing into a slice of DF_gff)
        DF_cds = DF_cds.assign(
            locus_tag=DF_cds.attributes.apply(_get_attr, attr_id="locus_tag")
        )

        result.append(DF_cds)

    DF_gff = pd.concat(result)

    if index:
        if index not in DF_gff.columns:
            # Not a column: treat it as a GFF attribute key
            DF_gff = DF_gff.assign(
                **{index: DF_gff.attributes.apply(_get_attr, attr_id=index, ignore=True)}
            )
        if DF_gff[index].duplicated().any():
            logging.warning("Duplicate {} detected. Dropping duplicates.".format(index))
            DF_gff = DF_gff.drop_duplicates(index)
        # Keep the column (drop=False) so the selection below still works, and
        # clear the index name so it does not clash with the column label
        DF_gff = DF_gff.set_index(index, drop=False).rename_axis(None)

    return DF_gff[['start', 'end', 'locus_tag']], region_len, oric
+
def h2a(x, header_to_allele):
    """
    Transforms a given locus tag using the header_to_allele dictionary.

    The mapped string is split on 'A' and the second piece is returned with
    an 'A' prefix (e.g. a mapped value 'C10A3' yields 'A3').

    Parameters:
    x (str): The locus tag to be transformed.
    header_to_allele (dict): A dictionary mapping locus tags to allele strings.

    Returns:
    str or None: Transformed locus tag prefixed with 'A', or None if the tag is
        not in the mapping or the mapped value cannot be split as described.
    """
    try:
        return 'A' + header_to_allele[x].split('A')[1]
    except (KeyError, IndexError, AttributeError, TypeError):
        # KeyError: unknown tag; IndexError: no 'A' in the value;
        # AttributeError/TypeError: value or mapping of an unexpected type
        return None
+
def generate_strain_vectors(path_to_data, metadata, header_to_allele=None):
    """
    Generates a dictionary of gene orders for each strain based on GFF3 files.

    For every genome_id the file
    ``{path_to_data}/processed/bakta/{strain}/{strain}.gff3`` is parsed with
    ``gff2pandas``. Strains whose file cannot be processed are logged and
    left out of the result.

    Parameters:
    path_to_data (str): The base directory path where the data is stored.
    metadata (DataFrame): A DataFrame containing metadata, which includes the genome_id of each strain.
    header_to_allele (dict, optional): A dictionary mapping locus tags to allele strings.
        If given, each locus tag is translated with ``h2a``; if omitted, the
        raw locus tags are used as gene identifiers.

    Returns:
    dict: A dictionary where keys are strain names and values are lists of genes in order.
    """

    strain_vectors = {}

    for strain in tqdm(metadata.genome_id):
        try:
            DF_gff, _size, _oric = gff2pandas(f'{path_to_data}/processed/bakta/{strain}/{strain}.gff3')

            ordered_tags = DF_gff.sort_values('start').locus_tag.to_list()
            if header_to_allele is None:
                gene_order = ordered_tags
            else:
                gene_order = [h2a(tag, header_to_allele) for tag in ordered_tags]

            strain_vectors[strain] = gene_order
        except Exception as e:
            # Best effort: one unreadable genome must not abort the whole run
            logging.warning("Error processing strain %s: %s", strain, e)

    return strain_vectors
+
def plot_gene_length_distribution(strain_vectors):
    """
    Shows a 10-bin histogram of the gene-list length of every strain.

    The plotted value is the number of entries in each strain's gene list,
    i.e. one data point per strain.

    Parameters:
    strain_vectors (dict): A dictionary where keys are strain identifiers and values are lists of genes.

    """
    # One list length per strain
    list_sizes = list(map(len, strain_vectors.values()))

    # Histogram with its title and axis labels
    plt.hist(list_sizes, bins=10, color='blue', edgecolor='black')
    plt.title('Distribution of Gene Lengths')
    plt.xlabel('Gene Length')
    plt.ylabel('Frequency')

    plt.show()
+
def count_common_gene_appearances(strain_vectors):
    """
    Counts the occurrences of common genes across different strains.

    Parameters:
    strain_vectors (dict): A dictionary where keys are strain identifiers and values are lists of genes.

    Returns:
    DataFrame: A DataFrame where rows are strains and columns are common genes,
               with each cell representing the count of a gene in the respective strain.
               Columns follow the order of first appearance in the first strain.
               An empty DataFrame is returned for empty input.
    """
    from collections import Counter

    if not strain_vectors:
        return pd.DataFrame()

    # Count every strain once instead of calling list.count() per gene
    counts_by_strain = {strain: Counter(genes) for strain, genes in strain_vectors.items()}

    # Genes present in every strain, in the first strain's order so that the
    # column order is deterministic
    first_strain_genes = next(iter(strain_vectors.values()))
    common_genes = [
        gene for gene in dict.fromkeys(first_strain_genes)
        if all(gene in counts for counts in counts_by_strain.values())
    ]

    strains = list(strain_vectors)
    data = {
        gene: [counts_by_strain[strain][gene] for strain in strains]
        for gene in common_genes
    }

    return pd.DataFrame(data, index=strains)
+
def find_once_genes(strain_vectors):
    """
    Finds genes that appear exactly once in each strain and returns the common and once-only genes.

    Parameters:
    strain_vectors (dict): A dictionary where keys are strain identifiers and values are lists of genes.

    Returns:
    tuple: The number of common genes, the number of genes appearing exactly once in each strain, and a set of those genes.
           ``(0, 0, set())`` is returned for empty input.
    """
    from collections import Counter

    if not strain_vectors:
        return 0, 0, set()

    # Count every strain once instead of calling list.count() per gene
    counts_per_strain = [Counter(genes) for genes in strain_vectors.values()]

    # Genes present in every strain
    common_genes = set(counts_per_strain[0])
    for counts in counts_per_strain[1:]:
        common_genes.intersection_update(counts)

    # Of those, keep the genes that occur exactly once in every strain
    once_genes = {
        gene for gene in common_genes
        if all(counts[gene] == 1 for counts in counts_per_strain)
    }

    return len(common_genes), len(once_genes), once_genes
+
def reorder_genes_by_strain(strain_vectors, genes, strain_name):
    """
    Reorders a list of genes based on their appearance in a specified strain.

    Genes are sorted by the position of their first occurrence in the strain.
    Genes that do not occur in the strain are placed at the end, keeping
    their relative input order.

    Parameters:
    strain_vectors (dict): A dictionary where keys are strain identifiers and values are lists of genes.
    genes (list): A list of genes to be reordered.
    strain_name (str): The identifier of the strain to use for ordering.

    Returns:
    list: The reordered list of genes or an error message if the strain is not found.
    """
    # Check if the specified strain exists in the strain_vectors dictionary
    if strain_name not in strain_vectors:
        return f"Strain '{strain_name}' not found."

    wanted = set(genes)

    # Single pass over the strain: remember the first position of each wanted
    # gene (avoids a list.index() scan per gene)
    first_position = {}
    for position, gene in enumerate(strain_vectors[strain_name]):
        if gene in wanted and gene not in first_position:
            first_position[gene] = position

    return sorted(genes, key=lambda gene: first_position.get(gene, float('inf')))
+
def rearrange_genes(gene_list, target_gene):
    """
    Rotates a list of genes so that the target gene becomes the first element.

    Parameters:
    gene_list (list): The list of genes to be rearranged.
    target_gene (str): The gene to position as the first element in the rearranged list.

    Returns:
    list: The rotated list (target gene and everything after it, followed by
          the genes that preceded it), or the original list if the target
          gene is not found.
    """
    try:
        pivot = gene_list.index(target_gene)
    except ValueError:
        # Target not present: hand back the input unchanged
        return gene_list

    tail_from_target = gene_list[pivot:]
    head_before_target = gene_list[:pivot]
    return tail_from_target + head_before_target
+
def standardize_strain_orders(strain_vectors, consistent_order_genes, reference_strain_name):
    """
    Standardizes the gene order across strains to match a reference strain.

    For every strain, the consistent genes are put in that strain's order and
    rotated to start at the reference's first (or last) consistent gene. A
    strain whose rotated order equals the reference order is kept as is; one
    whose rotated order equals the reversed reference order has its full gene
    list reversed. All other strains are reported as problem strains.

    Parameters:
    strain_vectors (dict): A dictionary where keys are strain identifiers and values are lists of genes.
    consistent_order_genes (list): A list of genes to use for ordering.
    reference_strain_name (str): The identifier of the reference strain.

    Returns:
    tuple: The updated strain vectors, the number of strains included in them
           (kept or reversed), the list of problem strains, and the list of
           updated strains.

    Raises:
    KeyError: If the reference strain is not in ``strain_vectors``.
    """
    if reference_strain_name not in strain_vectors:
        raise KeyError(reference_strain_name)

    # Consistent genes in the order in which the reference strain carries them
    reference_ordered_genes = reorder_genes_by_strain(strain_vectors, consistent_order_genes, reference_strain_name)
    reference_reversed = reference_ordered_genes[::-1]
    reference_first = reference_ordered_genes[0]
    reference_last = reference_ordered_genes[-1]

    count = 0
    strain_vectors_updated = {}
    problem_strains = []

    for strain_name, genes in strain_vectors.items():
        # Order the consistent genes once per strain, then try both rotations
        current_ordered_genes = reorder_genes_by_strain(strain_vectors, consistent_order_genes, strain_name)
        rotated_to_first = rearrange_genes(current_ordered_genes, reference_first)
        rotated_to_last = rearrange_genes(current_ordered_genes, reference_last)

        if rotated_to_first == reference_ordered_genes or rotated_to_last == reference_ordered_genes:
            # Same orientation as the reference
            strain_vectors_updated[strain_name] = genes
            count += 1
        elif rotated_to_first == reference_reversed or rotated_to_last == reference_reversed:
            # Opposite orientation: flip the whole gene list
            strain_vectors_updated[strain_name] = genes[::-1]
            count += 1
        else:
            problem_strains.append(strain_name)

    return strain_vectors_updated, count, problem_strains, list(strain_vectors_updated.keys())
+
def create_strain_groups(strain_vectors_filtered, once_genes, starting_strain):
    """
    Groups strains based on consistent gene orders starting from a specified strain.

    Starting with ``starting_strain`` as the reference, all strains consistent
    with it (per ``standardize_strain_orders``) form a group. The first strain
    (in dictionary order) that is not yet in any group becomes the next
    reference, until every strain has been grouped. A summary line is printed
    for each group.

    Parameters:
    strain_vectors_filtered (dict): A dictionary where keys are strain identifiers and values are lists of genes.
    once_genes (list): A list of genes that appear exactly once in each strain.
    starting_strain (str): The identifier of the strain to start the grouping process.

    Returns:
    dict: A dictionary where keys are group identifiers and values are lists of strains in each group.
    """
    groups = {}
    all_consistent_strains = set()

    current_strain = starting_strain
    group_number = 1

    while current_strain is not None:
        _, _, _problem_strains, consistent_strains = standardize_strain_orders(
            strain_vectors_filtered, once_genes, current_strain)

        group_key = f'strain_group_{group_number}'
        groups[group_key] = consistent_strains
        all_consistent_strains.update(consistent_strains)
        # The reference is always marked as handled so the loop makes progress
        all_consistent_strains.add(current_strain)

        # Print the current group and the number of strains it contains
        print(f" {group_key}: {len(consistent_strains)} strains.")

        # Next reference: first ungrouped strain in dictionary order, so the
        # grouping is reproducible (iterating a set of strings is not)
        current_strain = next(
            (strain for strain in strain_vectors_filtered if strain not in all_consistent_strains),
            None,
        )
        group_number += 1

    return groups
+
def update_strain_vector(reference_ordered_genes, strain_vectors_filtered):
    """
    Replaces reference genes in every strain vector by their 1-based position in the reference order.

    Parameters:
        reference_ordered_genes (list): A list of genes in the reference order.
        strain_vectors_filtered (dict): A dictionary where keys are strain identifiers and values are lists of genes.

    Returns:
        dict: A dictionary with one new list per strain in which every gene found in the reference
              list is replaced by its position (starting at 1); all other genes are kept unchanged.
    """
    position_of = {}
    for position, gene in enumerate(reference_ordered_genes):
        position_of[gene] = position + 1

    # Genes without a reference position fall back to themselves
    return {
        strain: [position_of.get(gene, gene) for gene in genes]
        for strain, genes in strain_vectors_filtered.items()
    }
+
def adjust_gene_order(strain_vectors):
    """
    Adjusts the gene order in strain vectors by reversing lists that are generally decreasing.

    Only the integer entries of a vector are used to decide its direction. The input
    lists are never modified: a vector that has to be reversed is returned as a new list.

    Parameters:
        strain_vectors (dict): A dictionary where keys are strain identifiers and values are lists of genes.

    Returns:
        tuple: A dictionary with adjusted gene orders and a count of how many lists were reversed.
    """
    def is_generally_decreasing(numbers):
        # Compare every number with its successor; the last one is compared with the
        # first one again, i.e. the sequence is treated as circular.
        decreasing_count = sum(x > y for x, y in zip(numbers, numbers[1:] + numbers[:1]))
        # Consider it decreasing if more than half of the comparisons are decreasing
        return decreasing_count > len(numbers) / 2

    final_strain_vectors = {}
    reversed_count = 0  # Counter for how many lists are reversed

    for strain, genes in strain_vectors.items():
        # Extract numbers and ignore non-numerical entries
        numbers = [x for x in genes if isinstance(x, int)]
        if numbers and is_generally_decreasing(numbers):
            # Reversed copy: the caller's list must stay untouched
            genes = genes[::-1]
            reversed_count += 1
        final_strain_vectors[strain] = genes

    return final_strain_vectors, reversed_count
+
def reorder_to_start_with_one(strain_vectors):
    """
    Rotates each strain vector so that the entry 1 comes first, when it is present.

    Parameters:
        strain_vectors (dict): A dictionary where keys are strain identifiers and values are lists of genes.

    Returns:
        tuple: A dictionary with the (possibly rotated) strain vectors and the number of
               vectors that actually had to be rotated.
    """
    strain_vectors_final = {}
    count_changed = 0

    for strain, genes in strain_vectors.items():
        # Position 0 means "nothing to do": either 1 already leads or it is absent
        anchor_at = genes.index(1) if 1 in genes else 0
        if anchor_at:
            # Move everything in front of 1 to the end of the list
            genes = genes[anchor_at:] + genes[:anchor_at]
            count_changed += 1
        strain_vectors_final[strain] = genes

    return strain_vectors_final, count_changed
+
def check_strict_sequence(strain_vectors):
    """
    Tests, per strain, whether its integer entries read exactly 1, 2, 3, ..., max in that order.

    Parameters:
        strain_vectors (dict): A dictionary where keys are strain identifiers and values are lists of genes.

    Returns:
        tuple: A dictionary mapping each strain to True/False, the number of strains that follow
               the strict sequence, and the number of strains that do not. Strains without any
               integer entry count as not following the sequence.
    """
    results = {}

    for strain, genes in strain_vectors.items():
        # Non-integer entries are ignored for the check
        anchors = [entry for entry in genes if isinstance(entry, int)]
        results[strain] = bool(anchors) and anchors == list(range(1, max(anchors) + 1))

    count_true = sum(results.values())
    count_false = len(results) - count_true

    return results, count_true, count_false
+
def generate_gene_names(strain_vectors):
    """
    Generates descriptive names for genes based on their positions relative to numerical markers in strain vectors.

    Every non-integer entry of a strain vector is named
    ``"<previous anchor>_<distance to it>_<distance to next anchor>_<next anchor>"``,
    where anchors are the integer entries of the vector. The vector is treated as
    circular in both directions: a gene in front of the first anchor uses the last
    anchor as its predecessor, and a gene behind the last anchor uses the first anchor
    as its successor. Strains without any anchor produce no names.

    Parameters:
        strain_vectors (dict): A dictionary where keys are strain identifiers and values are lists of genes.

    Returns:
        DataFrame: A DataFrame where rows are the (string) gene identifiers that received a name
                   and columns are strains, with cells containing the descriptive gene names
                   ('NA' where a gene has no name in a strain).
    """
    gene_names = {}

    for strain, genes in strain_vectors.items():
        total = len(genes)
        anchor_positions = [i for i, g in enumerate(genes) if isinstance(g, int)]

        # Temporary storage for gene names of the current strain
        current_names = {}

        if anchor_positions:
            # Backward sweep: position of the closest anchor after each gene.
            # Seeded with the first anchor shifted one lap forward (circular wrap).
            next_anchor = [0] * total
            upcoming = anchor_positions[0] + total
            for i in range(total - 1, -1, -1):
                if isinstance(genes[i], int):
                    upcoming = i
                else:
                    next_anchor[i] = upcoming

            # Forward sweep: track the closest anchor before each gene.
            # Seeded with the last anchor shifted one lap backward (circular wrap).
            previous = anchor_positions[-1] - total
            for i, gene in enumerate(genes):
                if isinstance(gene, int):
                    previous = i
                    continue
                following = next_anchor[i]
                # Modulo maps the wrapped (negative or >= total) positions back into the list
                previous_number = genes[previous % total]
                next_number = genes[following % total]
                current_names[gene] = (
                    f"{previous_number}_{i - previous}_{following - i}_{next_number}"
                )

        # Store names with respect to their original gene identifier
        gene_names[strain] = current_names

    # Create a DataFrame from the dictionary
    all_genes = sorted(set(g for names in gene_names.values() for g in names if isinstance(g, str)))
    df = pd.DataFrame(index=all_genes, columns=list(strain_vectors))
    indexed_genes = set(all_genes)

    for strain, names in gene_names.items():
        for gene, name in names.items():
            if gene in indexed_genes:  # Ensure the gene is part of the index
                df.at[gene, strain] = name

    return df.fillna('NA')
+
def count_genes_between_anchor_genes(df, strain):
    """
    Counts the genes between numerical anchor genes for a given strain.

    Each cell of the strain's column is expected to look like
    ``"<anchor before>_<n>_<m>_<anchor after>"``. Cells equal to 'NA', non-string cells
    (e.g. NaN) and cells that do not split into exactly four '_'-separated fields are ignored.

    Parameters:
        df (DataFrame): A DataFrame where rows are genes and columns are strains, with cells containing descriptive gene names.
        strain (str): The identifier of the strain to process.

    Returns:
        DataFrame: A DataFrame with the columns 'Anchor Genes' ("<before>-<after>") and
                   'Total Genes Between', sorted numerically by the anchor pair.
    """
    counts = {}
    anchor_pair = {}  # key string -> (before, after), kept for numeric sorting

    for entry in df[strain]:
        if not isinstance(entry, str) or entry == 'NA':
            continue
        parts = entry.split('_')
        if len(parts) != 4:
            continue

        number_before, number_after = parts[0], parts[3]
        key = f"{number_before}-{number_after}"
        anchor_pair[key] = (number_before, number_after)
        counts[key] = counts.get(key, 0) + 1

    def pair_order(key):
        # Sort by the numeric value of the anchors so that "2-3" comes before "10-11";
        # keys with non-numeric anchors are placed last, in plain string order.
        before, after = anchor_pair[key]
        try:
            return (0, int(before), int(after), key)
        except ValueError:
            return (1, 0, 0, key)

    result_data = [(key, counts[key]) for key in sorted(counts, key=pair_order)]
    result_df = pd.DataFrame(result_data, columns=['Anchor Genes', 'Total Genes Between'])

    return result_df
+
def create_gene_count_between_anchor_genes_for_all(df):
    """
    Runs ``count_genes_between_anchor_genes`` for every strain (column) of ``df``.

    Parameters:
        df (DataFrame): A DataFrame where rows are genes and columns are strains, with cells containing descriptive gene names.

    Returns:
        dict: A dictionary where keys are strain identifiers (the column labels of ``df``) and
              values are the DataFrames returned by ``count_genes_between_anchor_genes``.
    """
    return {
        strain: count_genes_between_anchor_genes(df, strain)
        for strain in df.columns
    }
+
def identify_variation(numbers, ordered_numbers):
    """
    Identifies the type of genetic variation by comparing two lists of numbers.

    Positions where the two lists disagree are followed as a chain: from a mismatching
    position, jump to the position in ``numbers`` that holds the value expected at the
    current position, until an already visited position is reached. The first chain
    decides the result: if the observed values are the reverse of the expected ones it is
    an 'inversion'; if the expected values of the chain are not in sorted order it is a
    'translocation'; otherwise the next chain is examined.

    When the chain cannot be followed because ``numbers`` is not a rearrangement of
    ``ordered_numbers`` (an expected value is missing from ``numbers``, or ``numbers``
    is longer than ``ordered_numbers``), 'others' is returned instead of raising.

    Parameters:
        numbers (list): A list of integers representing gene positions in the original order.
        ordered_numbers (list): A list of integers representing gene positions in the expected order.

    Returns:
        str: The type of variation ('no variation', 'inversion', 'translocation', 'others').
    """
    if numbers == ordered_numbers:
        return 'no variation'

    # Position of the first occurrence of every value (same result as list.index,
    # without rescanning the list for every lookup)
    first_position = {}
    for position, value in enumerate(numbers):
        first_position.setdefault(value, position)

    n = len(numbers)
    expected_length = len(ordered_numbers)
    visited = [False] * n

    for start in range(n):
        if visited[start]:
            continue
        if start >= expected_length:
            return 'others'  # More entries than expected: cannot be classified
        if numbers[start] == ordered_numbers[start]:
            visited[start] = True
            continue

        # Start of a segment
        segment_original = []
        segment_ordered = []
        pos = start
        while not visited[pos]:
            if pos >= expected_length:
                return 'others'
            expected = ordered_numbers[pos]
            segment_original.append(numbers[pos])
            segment_ordered.append(expected)
            visited[pos] = True
            if expected not in first_position:
                return 'others'  # Expected value is absent from the observed order
            pos = first_position[expected]

        if segment_original == segment_ordered[::-1]:
            return 'inversion'
        elif sorted(segment_original) != segment_ordered:
            return 'translocation'

    return 'others'
+
def identify_genetic_variation(strain_vectors):
    """
    Classifies the anchor order of every strain with ``identify_variation``.

    For each strain the integer entries of its vector are compared against the
    consecutive range from their minimum to their maximum. Strains without integer
    entries are left out. The number of strains per variation type is printed.

    Parameters:
        strain_vectors (dict): A dictionary where keys are strain identifiers and values are lists of genes.

    Returns:
        DataFrame: A DataFrame with the columns 'Strain' and 'Variation', one row per classified strain.
    """
    rows = []

    for strain, genes in strain_vectors.items():
        anchors = [entry for entry in genes if isinstance(entry, int)]
        if anchors:
            expected_order = list(range(min(anchors), max(anchors) + 1))
            rows.append([strain, identify_variation(anchors, expected_order)])

    result_df = pd.DataFrame(rows, columns=['Strain', 'Variation'])

    # Summary of how often each category occurs
    print(result_df['Variation'].value_counts())

    return result_df
+
def find_full_matches(phylon_strain_groups_counts, strain_groups):
    """
    Finds the phylon / strain-group pairs in which the count equals the full size of the group.

    Parameters:
        phylon_strain_groups_counts (DataFrame): A DataFrame whose columns are strain-group identifiers
            and whose rows (index) are phylons, with cells containing counts.
        strain_groups (dict): A dictionary where keys are strain-group identifiers and values are
            the collections of strains in each group.

    Returns:
        list: A list of (phylon, group, group size) tuples, ordered by group column and then by row.
    """
    full_matches = []

    for group in phylon_strain_groups_counts.columns:
        group_size = len(strain_groups[group])
        column = phylon_strain_groups_counts[group]

        # Rows whose count equals the number of strains in this group
        for phylon in column[column == group_size].index:
            full_matches.append((phylon, group, group_size))

    return full_matches
+
def filter_genes_and_strains(gene_mapping_to_anchor_genes, L_binarized, A_binarized, phylon):
    """
    Restricts the gene-location table to the genes and strains flagged for one phylon.

    Parameters:
        gene_mapping_to_anchor_genes (DataFrame): A DataFrame that is sliced by gene (rows) and strain (columns).
        L_binarized (DataFrame): A DataFrame with genes as index and a column per phylon; genes with value 1
            in the ``phylon`` column are kept.
        A_binarized (DataFrame): A DataFrame with a row per phylon and strains as columns; strains with value 1
            in the ``phylon`` row are kept.
        phylon: The phylon label used as column of ``L_binarized`` and row of ``A_binarized``.

    Returns:
        DataFrame: The rows and columns of ``gene_mapping_to_anchor_genes`` selected for the phylon.
    """
    gene_flags = L_binarized[phylon]
    strain_flags = A_binarized.loc[phylon]

    selected_genes = gene_flags[gene_flags == 1].index.tolist()
    selected_strains = strain_flags[strain_flags == 1].index.tolist()

    return gene_mapping_to_anchor_genes.loc[selected_genes, selected_strains]
+
def count_strain_groups(strain_groups, A_binarized):
    """
    Counts, for every row of ``A_binarized``, how many strains of each strain group have the value 1.

    Strains of a group that are not columns of ``A_binarized`` are ignored.

    Parameters:
        strain_groups (dict): A dictionary where keys are strain-group identifiers and values are
            the collections of strains in each group.
        A_binarized (DataFrame): A DataFrame with strains as columns; a cell equal to 1 marks membership.

    Returns:
        DataFrame: A DataFrame with the same index as ``A_binarized`` and one column per strain group,
                   containing the number of member strains of that group per row.
    """
    # Evaluate the membership test once for the whole table instead of row by row
    is_member = A_binarized == 1

    counts = {}
    for group, strains in strain_groups.items():
        # Boolean column mask: strains of this group that are present in A_binarized
        in_group = A_binarized.columns.isin(list(strains))
        counts[group] = is_member.loc[:, in_group].sum(axis=1)

    return pd.DataFrame(counts, index=A_binarized.index, columns=list(strain_groups))
+
def genes_between_anchors(df, anchor1, anchor2):
    """
    Lists the genes located between two given anchors in any strain, ordered by their offset.

    Cells are expected to look like ``"<start anchor>_<offset>_<n>_<end anchor>"``; missing values,
    'NA' and cells that do not split into four '_'-separated fields are skipped. When a gene
    matches in several strains, the offset of the last matching strain (column) is used.

    Parameters:
        df (DataFrame): A DataFrame where rows are genes and columns are strains, with cells containing descriptive gene names.
        anchor1 (int): The anchor expected in front of the gene.
        anchor2 (int): The anchor expected behind the gene.

    Returns:
        list: The matching gene identifiers sorted by their offset from ``anchor1``.
    """
    offsets = {}

    for _, column in df.items():
        for gene, location in column.items():
            if pd.isna(location) or location == 'NA':
                continue
            fields = location.split('_')
            if len(fields) != 4:
                continue
            # Keep only genes flanked by exactly the requested anchors
            if int(fields[0]) == anchor1 and int(fields[3]) == anchor2:
                offsets[gene] = int(fields[1])

    ranked = sorted(offsets.items(), key=lambda item: item[1])
    return [gene for gene, _ in ranked]
+
def genes_in_strain_between_anchors(df, strain, anchor1, anchor2):
    """
    Lists the genes of one strain located between two given anchors, ordered by their offset.

    Cells are expected to look like ``"<start anchor>_<offset>_<n>_<end anchor>"``; missing values,
    'NA' and cells that do not split into four '_'-separated fields are skipped.

    Parameters:
        df (DataFrame): A DataFrame where rows are genes and columns are strains, with cells containing descriptive gene names.
        strain (str): The identifier of the strain (column) to inspect.
        anchor1 (int): The anchor expected in front of the gene.
        anchor2 (int): The anchor expected behind the gene.

    Returns:
        list: The matching gene identifiers sorted by their offset from ``anchor1``;
              an empty list if ``strain`` is not a column of ``df``.
    """
    if strain not in df:
        return []

    offsets = {}

    for gene, location in df[strain].items():
        if pd.isna(location) or location == 'NA':
            continue
        fields = location.split('_')
        if len(fields) != 4:
            continue
        if int(fields[0]) == anchor1 and int(fields[3]) == anchor2:
            offsets[gene] = int(fields[1])

    ranked = sorted(offsets.items(), key=lambda item: item[1])
    return [gene for gene, _ in ranked]
+
def count_anchor_gene_pairs(phylon_location):
    """
    Counts, for every gene, the number of distinct anchor pairs it is found between across strains.

    Cells are expected to look like ``"<anchor>_<n>_<m>_<anchor>"``. Missing values, 'NA' and
    cells that do not split into exactly four '_'-separated fields are skipped. A pair is
    treated as unordered, i.e. (1, 2) and (2, 1) count as the same location.

    Parameters:
        phylon_location (DataFrame): A DataFrame where rows are genes and columns are strains,
            with cells containing descriptive gene names.

    Returns:
        DataFrame: A DataFrame indexed by gene with the single column 'Number of possible location',
                   restricted to genes with at least one location and sorted in descending order.
    """
    gene_pair_counts = {}

    # Walk the table row by row instead of doing one .loc lookup per cell
    for gene, *cells in phylon_location.itertuples(index=True, name=None):
        anchor_pairs = set()

        for value in cells:
            if pd.isna(value) or value == 'NA':
                continue  # Skip NaN and 'NA' values

            parts = value.split('_')
            if len(parts) != 4:
                continue  # Not a location string; same rule as the other location parsers

            anchor_1, anchor_2 = int(parts[0]), int(parts[3])

            # Normalise the order so that a pair is counted once regardless of direction
            anchor_pairs.add((min(anchor_1, anchor_2), max(anchor_1, anchor_2)))

        gene_pair_counts[gene] = len(anchor_pairs)

    count_df = pd.DataFrame.from_dict(gene_pair_counts, orient='index', columns=['Number of possible location'])

    # Remove rows with 0 possibilities
    count_df = count_df[count_df['Number of possible location'] > 0]

    return count_df.sort_values(by='Number of possible location', ascending=False)
+
def unique_genes_by_phylon(df: pd.DataFrame) -> dict:
    """
    Identifies the genes that are unique to each phylon of a binarized L matrix.

    A gene is unique to a phylon when its value in that phylon's column is 1 and the
    values of its row sum to 1, i.e. it is present in no other phylon.

    Parameters:
        df (pd.DataFrame, L_binarized): A dataframe where columns are phylon names,
            row indices are gene names,
            and values are 1 or 0 indicating the presence of the gene in the phylon.

    Returns:
        dict: A dictionary where keys are phylon names and values are lists of genes
              that are unique to each phylon (in row order).
    """
    # Row totals are the same for every phylon, so compute them once up front
    in_exactly_one = (df.sum(axis=1) == 1).to_numpy()

    unique_genes = {}
    for phylon in df.columns:
        present_here = (df[phylon] == 1).to_numpy()
        unique_genes[phylon] = df.index[present_here & in_exactly_one].tolist()

    return unique_genes
\ No newline at end of file