import h5py
import scipy.sparse as ss
import os
import sys
import numpy as np
import pandas as pd
from hicmatrix import HiCMatrix as hm
from hicmatrix.lib import MatrixFileHandler
from itertools import combinations
from scipy.sparse import csr_matrix, dia_matrix, triu, tril, coo_matrix
from sklearn.metrics.pairwise import pairwise_distances
import pyranges as pr


def jac_sim_arr(arr, thresh):
    """Return the pairwise Jaccard similarity between the rows of ``arr``.

    Every entry is binarised first: values below ``thresh`` become 0 and
    everything else (including NaN) becomes 1. Columns that are zero in every
    row are dropped. Rows that are entirely zero after binarisation have
    their row of the result set to 0.

    Parameters
    ----------
    arr : array-like or scipy sparse matrix
        2-D matrix whose rows are compared (e.g. genes x bins contacts).
    thresh : float
        Binarisation threshold.

    Returns
    -------
    numpy.ndarray
        ``(n_rows, n_rows)`` float similarity matrix.
    """
    # Same binarisation as np.where(x < thresh, 0, 1), kept as booleans because
    # sklearn's jaccard metric works on boolean data.
    binary = ~(coo_matrix(arr).toarray() < thresh)
    n_rows = binary.shape[0]

    # Columns that are zero in every row carry no information for Jaccard.
    informative_cols = binary.any(axis=0)
    if not informative_cols.any():
        # Nothing passed the threshold, so every row is empty and every
        # similarity is 0 (sklearn rejects a matrix with zero features).
        return np.zeros((n_rows, n_rows))
    binary = binary[:, informative_cols]

    empty_rows = ~binary.any(axis=1)
    jac_sim = 1 - pairwise_distances(binary, metric='jaccard')
    # A row with no neighbours is defined to have similarity 0.
    jac_sim[empty_rows, :] = 0
    return jac_sim


def pearson_corr(arr):
    """Return the Pearson correlation matrix between the rows of ``arr``.

    ``arr`` may be a scipy sparse matrix (it is densified first) or anything
    ``np.asarray`` accepts. As with ``np.corrcoef``, rows with zero variance
    produce NaN entries.
    """
    dense = arr.toarray() if ss.issparse(arr) else np.asarray(arr)
    return np.corrcoef(dense)


def _load_gene_table(genome_info_file, species, chr_list):
    """Read the GTF and return one 'gene' row per unique gene_id on chr_list."""
    df = pr.read_gtf(genome_info_file).df
    # .copy() so the column assignments below modify an owned frame, not a view.
    gene_data = df[df['Feature'] == 'gene'].copy()
    # Drop any ".version" suffix from the IDs (e.g. "ID.5" -> "ID").
    gene_data['gene_id'] = [x.split('.')[0] for x in gene_data['gene_id']]
    if species != 'human':
        # Non-human annotations are assumed to lack the "chr" prefix used
        # by the Hi-C matrix and by chr_list.
        gene_data['Chromosome'] = ['chr' + x for x in gene_data['Chromosome']]
    gene_data = gene_data[gene_data['Chromosome'].isin(chr_list)]
    # Dedupe after the chromosome filter so a gene listed on an unrequested
    # chromosome first does not hide its copy on a requested one.
    gene_data = gene_data.drop_duplicates(subset='gene_id')
    return gene_data.sort_values(by=['Chromosome', 'Start'])


def _gene_bin_ranges(hic, gene_data):
    """Map each gene to the first and last Hi-C bin its interval overlaps.

    Returns a list of ``(gene_id, chrom, strand, bin_start, bin_end)``.
    """
    records = []
    for start, end, gene_id, strand, chrom in zip(
            gene_data['Start'], gene_data['End'], gene_data['gene_id'],
            gene_data['Strand'], gene_data['Chromosome']):
        bin_range = hic.getRegionBinRange(chrom, min(start, end), max(start, end))
        # NOTE(review): treats a None result as "region not found in the
        # matrix"; confirm against the hicmatrix version in use.
        if bin_range is None:
            print(f'skipping {gene_id}: {chrom}:{start}-{end} is not in the Hi-C matrix',
                  file=sys.stderr)
            continue
        bin_start, bin_end = bin_range
        records.append((gene_id, chrom, strand, bin_start, bin_end))
    return records


def _bins_not_in(n_bins, gene_bins):
    """Return, in ascending order, the indices in [0, n_bins) not in gene_bins."""
    gene_bins = set(gene_bins)  # O(1) membership instead of a list scan per bin
    return [b for b in range(n_bins) if b not in gene_bins]


def _tss_matrices(entire_matrix, records):
    """Build gene matrices using one bin per gene.

    The bin is the gene's end bin on the '-' strand and its start bin
    otherwise. Returns ``(gene_by_gene, gene_by_non_gene_bins, non_gene_bins)``.
    """
    tss_bins = [bin_end if strand == '-' else bin_start
                for _, _, strand, bin_start, bin_end in records]
    gene_rows = entire_matrix[tss_bins, :]
    non_gene_bins = _bins_not_in(entire_matrix.shape[0], tss_bins)
    return gene_rows[:, tss_bins], gene_rows[:, non_gene_bins], non_gene_bins


def _max_matrices(entire_matrix, records):
    """Build gene matrices by max-pooling contacts over every bin a gene spans.

    Returns ``(gene_by_gene, gene_by_non_gene_bins, non_gene_bins)``.
    """
    # Column g = element-wise max over the rows of gene g's bins.
    # Shape: (n_bins, n_genes).
    gene_by_all_bins = csr_matrix(ss.vstack([
        entire_matrix[bin_start:bin_end + 1].max(axis=0)
        for _, _, _, bin_start, bin_end in records]).T)
    # Also max-pool over the partner gene's bins. Shape: (n_genes, n_genes).
    gene_by_gene = ss.vstack([
        gene_by_all_bins[bin_start:bin_end + 1].max(axis=0)
        for _, _, _, bin_start, bin_end in records])
    covered = [b for _, _, _, bin_start, bin_end in records
               for b in range(bin_start, bin_end + 1)]
    non_gene_bins = _bins_not_in(entire_matrix.shape[0], covered)
    return gene_by_gene, gene_by_all_bins[non_gene_bins, :].T, non_gene_bins


def _positions(labels, value):
    """Return the indices i where labels[i] == value."""
    return [i for i, label in enumerate(labels) if label == value]


def _save_block_diag(blocks, cut_intervals, path):
    """Stack per-chromosome blocks block-diagonally and save as a hiCMatrix."""
    out = hm.hiCMatrix()
    out.nan_bins = []
    out.setMatrix(ss.block_diag(blocks), cut_intervals)
    out.save(path)


def main(species, genome_info_file, base_folder, input_path, outfile1, outfile2,
         outfile3, outfile4, chr_list, gene_map_type='tss', percentile=90,
         coef_type='jac_sim'):
    """Build per-chromosome gene-gene similarity matrices from a Hi-C matrix.

    Genes from the GTF ``genome_info_file`` (restricted to ``chr_list``) are
    written to ``{base_folder}/lohia/hi_c_data_processing/genomes_jlee/{species}.bed``.
    Genes are then mapped onto bins of the Hi-C matrix at ``input_path`` (whose
    main diagonal is zeroed) in one of two ways:

    * ``'tss'``: one bin per gene (end bin on '-' strand, start bin otherwise).
    * ``'max'``: max-pool over all bins the gene spans.

    For each chromosome, gene rows are compared with ``coef_type``. Any value
    other than ``'jac_sim'`` means Pearson correlation. Jaccard uses a
    threshold equal to the ``percentile``-th percentile of that chromosome's
    intra-chromosomal contacts, zeros included. Four block-diagonal matrices,
    one gene per bin, are saved:

    * ``{outfile1}_{coef_type}.h5``: coefficients from gene x gene contacts.
    * ``{outfile2}_{coef_type}.h5``: from gene x non-gene-bin contacts.
    * ``{outfile3}_{coef_type}.h5``: from gene x (gene + non-gene bin) contacts.
    * ``{outfile4}.h5``: the raw gene x gene contact blocks.

    Chromosomes in ``chr_list`` with no mapped genes are skipped with a warning.

    Raises
    ------
    ValueError
        If ``gene_map_type`` is not 'tss' or 'max', or no gene maps to a bin.
    """
    if gene_map_type not in ('tss', 'max'):
        raise ValueError(f"gene_map_type must be 'tss' or 'max', got {gene_map_type!r}")

    gene_data = _load_gene_table(genome_info_file, species, chr_list)
    gene_data[['Chromosome', 'Start', 'End', 'gene_id', 'Score', 'Strand']].to_csv(
        f'{base_folder}/lohia/hi_c_data_processing/genomes_jlee/{species}.bed',
        sep='\t', header=None, index=False)

    hic = hm.hiCMatrix(input_path)
    entire_matrix = hic.matrix
    entire_matrix.setdiag(0, k=0)

    records = _gene_bin_ranges(hic, gene_data)
    if not records:
        raise ValueError('no genes could be mapped onto bins of the Hi-C matrix')
    gene_ids = [rec[0] for rec in records]
    gene_chroms = [rec[1] for rec in records]

    build = _tss_matrices if gene_map_type == 'tss' else _max_matrices
    gene_by_gene_matrix, gene_by_non_gene_bins, non_gene_bins = build(entire_matrix, records)
    gene_by_gene_matrix = csr_matrix(gene_by_gene_matrix)
    gene_by_non_gene_bins = csr_matrix(gene_by_non_gene_bins)
    non_gene_bin_chrom = [hic.getBinPos(b)[0] for b in non_gene_bins]

    # Gene columns first, then the non-gene bin columns.
    gene_by_gene_and_non_gene_bins = csr_matrix(
        ss.hstack([gene_by_gene_matrix, gene_by_non_gene_bins]))
    gene_by_gene_and_non_gene_bins_chrom = gene_chroms + non_gene_bin_chrom

    jac_list_gene, jac_list_non_gene, jac_list_all_bins = [], [], []
    gene_by_gene_matrix_all = []
    gene_id_interval, chrom_interval = [], []
    for chrom in chr_list:
        g_index = _positions(gene_chroms, chrom)
        if not g_index:
            print(f'no genes mapped on {chrom}; skipping', file=sys.stderr)
            continue
        non_gene_index = _positions(non_gene_bin_chrom, chrom)
        all_index = _positions(gene_by_gene_and_non_gene_bins_chrom, chrom)

        gene_id_interval.extend(gene_ids[i] for i in g_index)
        chrom_interval.extend([chrom] * len(g_index))

        gene_block = gene_by_gene_matrix[g_index, :][:, g_index]
        non_gene_block = gene_by_non_gene_bins[g_index, :][:, non_gene_index]
        all_bins_block = gene_by_gene_and_non_gene_bins[g_index, :][:, all_index]
        gene_by_gene_matrix_all.append(gene_block)

        if coef_type == 'jac_sim':
            chr_start, chr_end = hic.getChrBinRange(chrom)
            # Threshold is per chromosome over the dense intra-chromosomal
            # block, so zero contacts count toward the percentile.
            thresh = np.percentile(
                entire_matrix[chr_start:chr_end, chr_start:chr_end].toarray().ravel(),
                percentile)
            jac_list_gene.append(jac_sim_arr(gene_block, thresh))
            jac_list_non_gene.append(jac_sim_arr(non_gene_block, thresh))
            jac_list_all_bins.append(jac_sim_arr(all_bins_block, thresh))
        else:
            jac_list_gene.append(pearson_corr(gene_block))
            jac_list_non_gene.append(pearson_corr(non_gene_block))
            jac_list_all_bins.append(pearson_corr(all_bins_block))

    # Each gene becomes one pseudo-bin (chrom, 0, 1, gene_id) in the outputs.
    gene_cut_intervals = [(c, 0, 1, g) for c, g in zip(chrom_interval, gene_id_interval)]
    _save_block_diag(jac_list_gene, gene_cut_intervals, f'{outfile1}_{coef_type}.h5')
    _save_block_diag(jac_list_non_gene, gene_cut_intervals, f'{outfile2}_{coef_type}.h5')
    _save_block_diag(jac_list_all_bins, gene_cut_intervals, f'{outfile3}_{coef_type}.h5')
    _save_block_diag(gene_by_gene_matrix_all, gene_cut_intervals, f'{outfile4}.h5')


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--base_folder', default='/grid/gillis/data/',
                        help='The base folder for rugen vs cluster')
    # NOTE(review): --resolution_network is parsed but never used.
    parser.add_argument('--resolution_network', default='10kbp_raw',
                        help='resolution in kb format')
    parser.add_argument('--gene_map_type', default='tss', choices=['tss', 'max'],
                        help="'tss': one bin per gene; 'max': max over all bins the gene spans")
    parser.add_argument('--chr_list', default='tss',
                        help='comma-separated chromosome names, e.g. chr1,chr2')
    parser.add_argument('--input_file', default='tss',
                        help='path to the input Hi-C matrix (loaded with hicmatrix)')
    parser.add_argument('--outfile1', default='tss',
                        help='output prefix for gene x gene coefficients (_<coef_type>.h5 appended)')
    parser.add_argument('--outfile2', default='tss',
                        help='output prefix for gene x non-gene-bin coefficients '
                             '(_<coef_type>.h5 appended)')
    parser.add_argument('--outfile3', default='tss',
                        help='output prefix for gene x all-bins coefficients '
                             '(_<coef_type>.h5 appended)')
    parser.add_argument('--outfile4', default='tss',
                        help='output prefix for the raw gene x gene contact matrix (.h5 appended)')
    parser.add_argument('--genome_info_file', default='tss',
                        help='GTF annotation file containing gene records')
    parser.add_argument('--percentile', default=90, type=float,
                        help='per-chromosome contact percentile used as the Jaccard threshold')
    # NOTE(review): the default species '90' looks like a copy-paste slip; it
    # is kept for backward compatibility (non-'human' adds a 'chr' prefix).
    parser.add_argument('--species', default='90',
                        help="species name; anything but 'human' gets 'chr' prefixed "
                             "to GTF chromosome names")
    parser.add_argument('--coef_type', default='jac_sim',
                        help="'jac_sim' for Jaccard similarity; any other value uses Pearson")
    args = parser.parse_args()
    chr_list = [c.strip() for c in args.chr_list.split(',') if c.strip()]
    main(args.species, args.genome_info_file, args.base_folder, args.input_file,
         args.outfile1, args.outfile2, args.outfile3, args.outfile4, chr_list,
         gene_map_type=args.gene_map_type, percentile=args.percentile,
         coef_type=args.coef_type)