43 changes: 36 additions & 7 deletions scripts/griffin_GC_counts.py
@@ -101,6 +101,11 @@


def collect_reads(sublist):
"""
Function to create a dictionary of dictionaries that holds the frequency of each read length and GC content combination.
Additional logic is included to handle single-end sequencing. This version uses sequence length in place of a read's template length attribute
to represent read length.
"""
#create a dict for holding the frequency of each read length and GC content
GC_dict = {}
for length in range(size_range[0],size_range[1]+1):
@@ -116,7 +121,6 @@ def collect_reads(sublist):
#this might also need to be in the loop
#import the ref_seq
ref_seq=pysam.FastaFile(ref_seq_path)

for i in range(len(sublist)):
chrom = sublist.iloc[i][0]
start = sublist.iloc[i][1]
@@ -128,7 +132,7 @@ def collect_reads(sublist):
fetched = bam_file.fetch(chrom,start,end)
for read in fetched:
#use both fw (positive template length) and rv (negative template length) reads
if (read.is_reverse==False and read.template_length>=size_range[0] and read.template_length<=size_range[1]) or (read.is_reverse==True and -read.template_length>=size_range[0] and -read.template_length<=size_range[1]):
#qc filters, some longer fragments are considered 'improper pairs' but I would like to keep these
if read.is_paired==True and read.mapping_quality>=map_q and read.is_duplicate==False and read.is_qcfail==False:
if read.is_reverse==False:
@@ -146,18 +150,42 @@ def collect_reads(sublist):
rng = np.random.default_rng(fragment_start)
fragment_seq[np.isin(fragment_seq, ['N','R','Y','K','M','B','D','H','V'])] = rng.integers(2, size=len(fragment_seq[np.isin(fragment_seq, ['N','R','Y','K','M','B','D','H','V'])])) #random integer in range(2) (i.e. 0 or 1)
fragment_seq = fragment_seq.astype(float)


num_GC = int(fragment_seq.sum())
GC_dict[abs(read.template_length)][num_GC]+=1

# Additional logic to handle single-end reads.
elif read.is_paired==False and read.mapping_quality>=map_q and read.is_duplicate==False and read.is_qcfail==False:
    rl = len(read.seq)
    if read.is_reverse:
        fragment_start = read.reference_end-rl #for a rv read, the alignment end is the fragment end
        fragment_end = read.reference_end
    else:
        fragment_start = read.reference_start #for a fw read, the alignment start is the fragment start
        fragment_end = read.reference_start+rl
# check the read length is within the designated size range
if rl >=size_range[0] and rl <= size_range[1]:
fragment_seq = ref_seq.fetch(read.reference_name,fragment_start,fragment_end)
fragment_seq = np.array(list(fragment_seq.upper()))
fragment_seq[np.isin(fragment_seq, ['A','T','W'])] = 0
fragment_seq[np.isin(fragment_seq, ['C','G','S'])] = 1
rng = np.random.default_rng(fragment_start)
fragment_seq[np.isin(fragment_seq, ['N','R','Y','K','M','B','D','H','V'])] = rng.integers(2, size=len(fragment_seq[np.isin(fragment_seq, ['N','R','Y','K','M','B','D','H','V'])])) #random integer in range(2) (i.e. 0 or 1)
fragment_seq = fragment_seq.astype(float)

num_GC = int(fragment_seq.sum())
GC_dict[rl][num_GC]+=1 #rl is a length and is already positive, no abs() needed

print('done')
return(GC_dict)
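
A minimal sketch (not part of the diff) of the GC_dict layout this function builds -- {fragment_length: {num_GC: count}} -- using a toy size range; the helper name and values here are hypothetical:

def _sketch_gc_dict(size_range=(2, 3)):
    #one inner dict per fragment length, one counter per possible G/C count
    gc_dict = {length: {gc: 0 for gc in range(length + 1)}
               for length in range(size_range[0], size_range[1] + 1)}
    gc_dict[3][2] += 1  #record one 3 bp fragment containing two G/C bases
    return gc_dict  #{2: {0: 0, 1: 0, 2: 0}, 3: {0: 0, 1: 0, 2: 1, 3: 0}}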


# In[ ]:


## Parallel processing applied to the collect_reads function
start_time = time.time()
p = Pool(processes=CPU) #use the available CPU
sublists = np.array_split(mappable_intervals,CPU) #split the list into sublists, one per CPU
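
A self-contained sketch (not part of the diff) of the split-and-map pattern above, with a toy worker standing in for collect_reads; only numpy, pandas, and multiprocessing are assumed:

import numpy as np
import pandas as pd
from multiprocessing import Pool

def _toy_worker(sub_df):
    return len(sub_df)  #stand-in for collect_reads(sublist)

if __name__ == '__main__':
    intervals = pd.DataFrame({'chrom': ['chr1'] * 8, 'start': range(8)})
    chunks = np.array_split(np.arange(len(intervals)), 2)  #one index chunk per worker
    with Pool(processes=2) as pool:
        results = pool.map(_toy_worker, [intervals.iloc[c] for c in chunks])  #[4, 4]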
@@ -169,10 +197,11 @@ def collect_reads(sublist):


all_GC_df = pd.DataFrame()
for i, GC_dict in enumerate(GC_dict_list):
GC_df = pd.DataFrame()
for length in GC_dict.keys(): #each pool worker returns one dict; iterate the loop variable directly (GC_dict_list is a plain list and has no .get())
current = pd.Series(GC_dict[length]).reset_index()
current = current.rename(columns={'index':'num_GC',0:'number_of_fragments'})
current['length']=length
current = current[['length','num_GC','number_of_fragments']]
102 changes: 83 additions & 19 deletions scripts/griffin_coverage.py
@@ -314,6 +314,7 @@ def closest_key(dictionary, i):
start_time = time.time()
all_sites_bed = pybedtools.BedTool.from_dataframe(all_sites[[chrom_column,'fetch_start','fetch_end']])
all_sites_bed = all_sites_bed.sort()
# Merges any overlapping sites
all_sites_bed = all_sites_bed.merge()
print('Intervals to fetch:\t'+str(len(all_sites_bed)))
print('Total bp to fetch:\t'+str(all_sites_bed.total_coverage()))
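
A quick sketch (not part of the diff) of what the sort/merge step does, on toy coordinates; assumes pybedtools is installed:

import pybedtools
toy = pybedtools.BedTool('chr1 10 50\nchr1 40 80\nchr2 5 15', from_string=True)
merged = toy.sort().merge()  #chr1:10-80 and chr2:5-15 -- overlapping fetch intervals collapse into one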
@@ -340,6 +341,13 @@ def closest_key(dictionary, i):


def collect_fragments(input_list):
"""
Collects fragments for a single site.
Added logic to handle single end reads. The fragment length is calculated using sequence length.
'Regular' (paired-end read) data uses the template length attribute to get fragment length.
params:
input list = [index, chrom, start, end]
"""
i,chrom,start,end = input_list
#open the bam file for each pool worker (otherwise individual pool workers can close it)
bam_file = pysam.AlignmentFile(bam_path)
@@ -364,14 +372,70 @@ def collect_fragments(input_list):
########################
for read in fetched:
#filter out reads
if abs(read.template_length)>=sz_range[0] and abs(read.template_length)<=sz_range[1] and read.is_paired==True and read.mapping_quality>=map_q and read.is_duplicate==False and read.is_qcfail==False:
#only use fw reads with positive fragment lengths (negative indicates an abnormal pair)
#all paired end reads have a fw and rv read so we don't need the rv read to find the midpoint.
if read.is_reverse==False and read.template_length>0:
fragment_start = read.reference_start #for fw read, read start is fragment start
fragment_end = read.reference_start+read.template_length
midpoint = int(np.floor((fragment_start+fragment_end)/2))

if read.is_paired:
if abs(read.template_length)>=sz_range[0] and abs(read.template_length)<=sz_range[1] and read.mapping_quality>=map_q and read.is_duplicate==False and read.is_qcfail==False:
#only use fw reads with positive fragment lengths (negative indicates an abnormal pair)
#all paired end reads have a fw and rv read so we don't need the rv read to find the midpoint.
if read.is_reverse==False and read.template_length>0:
fragment_start = read.reference_start #for fw read, read start is fragment start
fragment_end = read.reference_start+read.template_length
midpoint = int(np.floor((fragment_start+fragment_end)/2))

#count the GC content
fragment_seq = ref_seq.fetch(read.reference_name,fragment_start,fragment_end)
fragment_seq = np.array(list(fragment_seq.upper()))
fragment_seq[np.isin(fragment_seq, ['A','T','W'])] = 0
fragment_seq[np.isin(fragment_seq, ['C','G','S'])] = 1
rng = np.random.default_rng(fragment_start)
fragment_seq[np.isin(fragment_seq, ['N','R','Y','K','M','B','D','H','V'])] = rng.integers(2, size=len(fragment_seq[np.isin(fragment_seq, ['N','R','Y','K','M','B','D','H','V'])])) #random integer in range(2) (i.e. 0 or 1)
fragment_seq = fragment_seq.astype(float)

if mappability_correction.lower() == 'true':
#find the two read locations for mappability correction
fw_read_map = mappability.values(chrom,read.reference_start,read.reference_start+read.reference_length)
fw_read_map = np.mean(np.nan_to_num(fw_read_map)) #replace any nan with zero and take the mean

rv_read_map = mappability.values(chrom,read.reference_start+read.template_length-read.reference_length,read.reference_start+read.template_length)
rv_read_map = np.mean(np.nan_to_num(rv_read_map)) #replace any nan with zero and take the mean

#check that the site is in the window
if midpoint>=start and midpoint<end:
#count the fragment
cov_dict[midpoint]+=1

##get the GC bias
read_GC_content = sum(fragment_seq)
read_GC_bias = GC_bias[abs(read.template_length)][read_GC_content]

#count the fragment weighted by GC bias
if not np.isnan(read_GC_bias):
GC_cov_dict[midpoint]+=(1/read_GC_bias)

if mappability_correction.lower() == 'true':
#get the mappability bias
read_map = np.int32(np.round(100*(fw_read_map+rv_read_map)/2))
read_map_bias = mappability_bias[read_map]
GC_map_cov_dict[midpoint]+=(1/read_GC_bias)*(1/read_map_bias)

#print(read_GC_bias,read_map,read_map_bias)

else: #if fragment doesn't fully overlap
continue

del(read,midpoint,fragment_seq)

else: # handle single-end reads
    if len(read.seq)>=sz_range[0] and len(read.seq)<=sz_range[1] and read.mapping_quality>=map_q and read.is_duplicate==False and read.is_qcfail==False:
        rl = len(read.seq)
        if read.is_reverse==False:
            fragment_start = read.reference_start #for fw read, read start is fragment start
            fragment_end = read.reference_start+rl
        else:
            fragment_start = read.reference_end-rl #for rv read, read end is fragment end
            fragment_end = read.reference_end
        midpoint = int(np.floor((fragment_start+fragment_end)/2))

#count the GC content
fragment_seq = ref_seq.fetch(read.reference_name,fragment_start,fragment_end)
fragment_seq = np.array(list(fragment_seq.upper()))
@@ -386,17 +450,17 @@ def collect_fragments(input_list):
fw_read_map = mappability.values(chrom,read.reference_start,read.reference_start+read.reference_length)
fw_read_map = np.mean(np.nan_to_num(fw_read_map)) #replace any nan with zero and take the mean

rv_read_map = mappability.values(chrom,read.reference_start+read.template_length-read.reference_length,read.reference_start+read.template_length)
rv_read_map = mappability.values(chrom,read.reference_start+rl-read.reference_length,read.reference_start+rl)
rv_read_map = np.mean(np.nan_to_num(rv_read_map)) #replace any nan with zero and take the mean

#check that the site is in the window
if midpoint>=start and midpoint<end:
#count the fragment
cov_dict[midpoint]+=1

##get the GC bias
read_GC_content = sum(fragment_seq)
read_GC_bias = GC_bias[abs(read.template_length)][read_GC_content]
read_GC_bias = GC_bias[rl][read_GC_content] #rl is a length and is already positive

#count the fragment weighted by GC bias
if not np.isnan(read_GC_bias):
Expand All @@ -409,15 +473,11 @@ def collect_fragments(input_list):
GC_map_cov_dict[midpoint]+=(1/read_GC_bias)*(1/read_map_bias)

#print(read_GC_bias,read_map,read_map_bias)

else: #if fragment doesn't fully overlap
continue

del(read,midpoint,fragment_seq)

else:
#print('reverse',read.is_reverse)
continue


output = pd.DataFrame(pd.Series(cov_dict, name = 'uncorrected'))
output['GC_corrected'] = pd.Series(GC_cov_dict)
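
A standalone sketch (not part of the diff) of the single-end coordinate convention used in both scripts: a forward read's fragment starts at the alignment start, a reverse read's fragment ends at the alignment end. Function and argument names are hypothetical:

def _single_end_fragment(reference_start, reference_end, read_length, is_reverse):
    if is_reverse:
        return reference_end - read_length, reference_end
    return reference_start, reference_start + read_length

#_single_end_fragment(100, 150, 50, False) -> (100, 150)
#_single_end_fragment(100, 150, 50, True)  -> (100, 150)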
@@ -452,9 +512,13 @@ def collect_fragments(input_list):
sys.stdout.flush()
start_time = time.time()

## Comment out the next two lines to run in debugging mode:
p = Pool(processes=CPU) #use the specified number of processes
results = p.map(collect_fragments, to_fetch.values, 1) #send only one interval to each worker at a time

## Uncomment to run in debugging mode on a single interval:
# results = [collect_fragments(to_fetch.values[0])] #wrapped in a list so downstream code sees the same shape as p.map output
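
A worked example (not part of the diff) of the 1/bias weighting applied in collect_fragments, with hypothetical bias values; a fragment whose (length, GC) combination is over-represented (bias > 1) contributes proportionally less corrected coverage:

def _sketch_weighting(gc_bias=2.0, map_bias=0.8):
    uncorrected = 1.0  #every passing fragment counts once
    gc_corrected = 1.0 / gc_bias  #0.5 for a 2x over-represented fragment
    gc_map_corrected = (1.0 / gc_bias) * (1.0 / map_bias)  #0.625
    return uncorrected, gc_corrected, gc_map_corrected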

elapsed_time = time.time()-overall_start_time
print('Done with fetch '+str(int(np.floor(elapsed_time/60)))+' min '+str(int(np.round(elapsed_time%60)))+' sec')
del(elapsed_time)