You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

240 lines
12 KiB
Python

#a class for the Kaplan-Meier estimator
import numpy as np
from math import sqrt
import matplotlib.pyplot as plt
class KAPLAN_MEIER(object):
    """Kaplan-Meier survival estimator for grouped, right-censored data.

    The data is a 2-D array (it is indexed as ``data[:, column]``): one
    column holds the observed time, one the group label and one the status.
    A status equal to 1 marks an event; any other value marks a censored
    observation.

    Construction is deliberately disabled: ``__init__`` always raises
    ``RuntimeError`` pointing at the newer implementation in survival2.py.
    """

    def __init__(self, data, timesIn, groupIn, censoringIn):
        """Refuse construction; this class has been superseded.

        :param data: 2-D array of observations, one row per observation.
        :param timesIn: column index of the observed time.
        :param groupIn: column index of the group label.
        :param censoringIn: column index of the status (1 = event).
        :raises RuntimeError: always.
        """
        raise RuntimeError('Newer version of Kaplan-Meier class available in survival2.py')
        # Unreachable while the raise above is in place; kept to record the
        # attributes that fit() reads.
        self.data = data
        self.timesIn = timesIn
        self.groupIn = groupIn
        self.censoringIn = censoringIn

    def fit(self):
        """Estimate the survival function separately for every group.

        Sets three attributes:

        * ``results`` -- one array per group with one row per distinct time
          and columns: time, number at risk, number of events, estimated
          survival, and its standard error (Greenwood's formula).
        * ``points`` -- one ``[x, y]`` pair per group tracing the step curve
          from (0, 1); ``x`` and ``y`` always have the same length.
        * ``censoredPoints`` -- a flat list, over all groups in order, of
          ``[time, number censored, survival, flag]`` entries, one per
          distinct time; ``flag`` is 1 for a group's first and last time and
          0 otherwise.

        Groups are processed in the iteration order of a ``set`` of their
        labels.
        """
        groupColumn = self.data[:, self.groupIn]
        groups = list(set(groupColumn))
        timeCounts = []
        points = []
        censoredPoints = []
        for group in groups:
            # rows keep their original order; the boolean mask selects rows
            rows = self.data[groupColumn == group]
            counts, x, y, censored = self._fit_group(rows)
            timeCounts.append(np.array(counts))
            points.append([x, y])
            censoredPoints.extend(censored)
        self.results = timeCounts
        self.points = points
        self.censoredPoints = censoredPoints

    def _fit_group(self, rows):
        """Return ``(counts, x, y, censored)`` for the rows of one group."""
        times = rows[:, self.timesIn]
        status = rows[:, self.censoringIn]
        uniqueTimes = sorted(set(times))
        lastIndex = len(uniqueTimes) - 1
        survival = 1
        varSum = 0
        atRisk = len(rows)
        counts = []
        censored = []
        # the curve starts at (0, 1); every time adds a horizontal step at the
        # previous survival followed by a vertical drop to the new survival
        x = [0]
        y = [1]
        for index, t in enumerate(uniqueTimes):
            atTime = times == t
            numAtTime = int(np.count_nonzero(atTime))
            events = int(np.count_nonzero(status[atTime] == 1))
            numCensored = numAtTime - events
            x += [t, t]
            y.append(survival)
            if events != 0:
                survival *= (float(atRisk) - events) / atRisk
                try:
                    # Greenwood's running sum d / (n * (n - d))
                    varSum += events / (atRisk * (float(atRisk) - events))
                except ZeroDivisionError:
                    # everyone at risk had the event: survival is now 0, so
                    # the variance sum is reset and the std. err reported is 0
                    varSum = 0
            y.append(survival)
            counts.append([t, atRisk, events, survival,
                           sqrt((survival ** 2) * varSum)])
            flag = 1 if index in (0, lastIndex) else 0
            censored.append([t, numCensored, survival, flag])
            # everyone observed at t leaves the risk set after t
            atRisk -= numAtTime
        return counts, x, y, censored

    def plot(self):
        """Plot every group's survival curve plus censoring tick marks.

        Requires fit() to have been called. Shows the figure with
        ``plt.show()``.
        """
        allX = []
        for xs, ys in self.points:
            plt.plot(np.array(xs), np.array(ys))
            allX += xs
        # censoredPoints is flat; group i contributes len(self.results[i])
        # consecutive entries (one per distinct time), in group order
        start = 0
        for groupTable in self.results:
            stop = start + len(groupTable)
            self._plot_censoring(self.censoredPoints[start:stop])
            start = stop
        # the plot extends to the largest time and slightly above 1
        plt.xlim((0, np.max(allX)))
        plt.ylim((0, 1.05))
        plt.xlabel('time')
        plt.ylabel('survival')
        plt.show()

    @staticmethod
    def _plot_censoring(entries):
        """Draw short vertical ticks for the censorings of one group.

        The censorings at one time are spread evenly over a horizontal span
        at that time's survival level.
        """
        lastIndex = len(entries) - 1
        for index, (t, numCensored, survival, _) in enumerate(entries):
            if numCensored == 0:
                continue
            spacing = 1. / (numCensored + 1.)
            if index == 0:
                # NOTE(review): the span starts at 1, not at t; this may have
                # been meant to be the first time -- confirm before changing
                origin = 1
                dx = spacing * float(t)
            elif index == lastIndex:
                # NOTE(review): spans t times the spacing from the previous
                # time, so ticks can land past t -- confirm intent
                origin = entries[index - 1][0]
                dx = spacing * float(t)
            else:
                # spread over the step between this time and the next one
                origin = t
                dx = spacing * (float(entries[index + 1][0]) - t)
            for k in range(numCensored):
                plt.vlines(origin + (k + 1) * dx,
                           survival - 0.03,
                           survival + 0.03)

    @staticmethod
    def _format_time(t):
        """Format a time: integral values without a decimal part."""
        t = float(t)
        if t.is_integer():
            return str(int(t))
        return str(t)

    def show_results(self):
        """Print the fitted table of every group.

        Requires fit() to have been called. Integral times print as integers;
        fractional times are printed in full instead of being truncated.
        """
        rowFormat = '{0:<9}{1:<12d}{2:<11d}{3:<13.4f}{4:<6.4f}\n'
        parts = []
        for groupNumber, table in enumerate(self.results):
            parts.append('Group {0}\n\n'.format(groupNumber) +
                         'Time At Risk Events Survival Std. Err\n')
            for row in table:
                parts.append(rowFormat.format(self._format_time(row[0]),
                                              int(row[1]), int(row[2]),
                                              row[3], row[4]))
        print(''.join(parts))