optimal_threshold.pl 9.54 KB
#!/usr/bin/perl -w

# favre@icsi.berkeley.edu
# 2007-04-24
# determine threshold that maximizes fmeasure
# 1) sort posteriors in increasing order
# 2) put a threshold in between each pair of consecutive values and compute fmeasure
# 3) keep threshold with maximum fmeasure
# you can do the same with nist error rate

$type="boostexter";

# 's' is a sentence boundary, 'n' is the opposite
my @values;
my %total;
$S="s";
$N="n";
$output="";
$output_type="";
$ref_column = undef;
$hyp_column = undef;

use Scalar::Util qw(looks_like_number);

# Return the value of option $option, taken from the next command-line
# argument; abort with an error when the command line ends first.
my $option_value = sub {
	my ($option) = @_;
	@ARGV or die("ERROR: option $option requires a value");
	return shift @ARGV;
};

# Parse the command line.  Every argument that is not an option is a decision
# threshold to evaluate and must be numeric.  defined() is required: a
# threshold of "0" is false in boolean context and would otherwise end the
# loop, silently dropping all remaining arguments.
while(defined($arg=shift))
{
	if ($arg eq "--type" or $arg eq "-t") {
		$type = $option_value->($arg);
	} elsif($arg eq "--boostexter" or $arg eq "-b") {
		$type = "boostexter";
	} elsif($arg eq "--boostexter-short" or $arg eq "-bs") {
		$type="boostexter_short";
	} elsif($arg eq "--crf" or $arg eq "-c") {
		$type="crf";
	} elsif($arg eq "--svm" or $arg eq "-s") {
		$type="svm";
	} elsif($arg eq "--raw" or $arg eq "-r") {
		$type="raw";
	} elsif($arg eq "--output" or $arg eq "-o") {
		$output_type = $option_value->($arg);
	} elsif($arg eq "--csv") {
		$output="csv";
	} elsif($arg eq "--ref") {
		$ref_column = $option_value->($arg);
	} elsif($arg eq "--hyp") {
		$hyp_column = $option_value->($arg);
	} elsif($arg eq "--class") {
		$S = $option_value->($arg);
	} elsif($arg eq "--background") {
		$N = $option_value->($arg);
	} elsif($arg eq "--help" or $arg eq "-h") {
		die("USAGE: $0 [(-o|--output) (fmeasure|nist|epr|precision|recall|decisions)] [--csv] [--ref <int>] [--hyp <int>] [--class <text>] [--background <text>] (--type|-t) (boostexter|boostexter_short|crf|svm|raw) || (-b|--boostexter|-bs|--boostexter-short|-c|--crf|-s|--svm|-r|--raw) [threshold]+");
	} else {
		# reject typos such as "--hpy" instead of evaluating them as threshold 0
		looks_like_number($arg) or die("ERROR: unknown option or non-numeric threshold \"$arg\"");
		push @thresholds, $arg;
	}
}

# validate the requested output type ("" means: print the two optimal thresholds)
my %possible_outputs=map{$_=>1}('fmeasure','nist','epr','precision','recall','','decisions');
exists $possible_outputs{$output_type} or die("ERROR: unknown output type \"$output_type\"");

# read data and put them in an array of pairs (value,label)
# Each --type has its own input layout; every branch appends
# [numeric score, reference label] pairs to @values.  Blank lines are skipped.
if($type eq "boostexter")
{
	# "correct label = X" sets the reference label of the following example;
	# the score of class $S is on a line of the form "<score> : <class> ".
	my $label;
	while(<>)
	{
		chomp();
		if(/^correct label = (.)/)
		{
			$label=$1;
		}
		elsif(/(\S+) : \Q$S\E $/)
		{
			push @values,[1*$1,$label];
		}
	}
}
elsif($type eq "boostexter_short")
{
	# one example per line, labels 1/0; reference and score columns are
	# selectable with --ref/--hyp (0-based, defaults 1 and 3)
	$S="1";
	$N="0";
	defined $ref_column or $ref_column = 1;
	defined $hyp_column or $hyp_column = 3;
	while(<>)
	{
		chomp();
		my @parts = split(" ");
		@parts or next;
		push @values,[1*$parts[$hyp_column],$parts[$ref_column]];
	}
}
elsif($type eq "crf")
{
	# the reference label is the 4th column from the end; the posterior of $S
	# is the right-most field of the form "$S/<posterior>"
	while(<>)
	{
		chomp();
		my @parts=split(" ");
		$#parts>3 or next;
		my $value;
		for my $field(reverse @parts)
		{
			# \Q...\E: class names are matched literally, not as regex syntax
			($value)=$field=~/^\Q$S\E\/(.*)/;
			defined $value and last;
		}
		defined $value or die("ERROR: could not find class posterior \"$_\"");
		my $label=$parts[$#parts-3];
		$total{$label}++;
		push @values,[1*$value,$label];
	}
}
elsif($type eq "svm" or $type eq "raw")
{
	# "<score> <label>" per line; svm output uses the labels 1 / -1
	if($type eq "svm")
	{
		$S="1";
		$N="-1";
	}
	while(<>)
	{
		chomp();
		# split " " ignores leading whitespace, unlike split /\s+/
		my ($value,$label)=split(" ");
		defined $value or next;
		$total{$label}++;
		push @values,[1*$value,$label];
	}
}
else
{
	die("ERROR: unknown file type [$type]");
}

scalar(@values) == 0 and die("ERROR: not enough values");

# sort according to posteriors
@values=sort{$a->[0]<=>$b->[0]}@values;

$output eq "csv" and print STDERR "type, nist, fmeasure, recall, precision, p-r diff, threshold\n";

# Threshold-independent statistics, computed once: the number of reference
# $S ("ok") and $N ("err") items and the log-likelihood sums of the normalized
# cross entropy (NCE).  NCE is only defined when both classes occur and every
# log argument is positive (i.e. the scores behave like posteriors of $S);
# otherwise it is reported as "n/a" instead of dying inside log().
my $nb_ok = 0;
my $nb_err = 0;
my $nceAcu1 = 0;
my $nceAcu2 = 0;
my $nce_defined = 1;
for my $value(@values)
{
	if($value->[1] eq $S)
	{
		$nb_ok++;
		if($value->[0] > 0) { $nceAcu1+=log($value->[0]) } else { $nce_defined = 0 }
	}
	if($value->[1] eq $N)
	{
		$nb_err++;
		if($value->[0] < 1) { $nceAcu2+=log(1-$value->[0]) } else { $nce_defined = 0 }
	}
}
my $nb_total = $nb_ok + $nb_err;
my $NCE;
if($nce_defined and $nb_ok > 0 and $nb_err > 0)
{
	# entropy of the class prior; strictly positive since both classes occur
	my $prior = $nb_ok/$nb_total;
	my $Hmax=-$nb_ok*log($prior) - $nb_err*log(1-$prior);
	$NCE=($Hmax+$nceAcu1+$nceAcu2)/$Hmax;
}

# $numerator/$denominator as a percentage, or "n/a" when the denominator is 0
my $percent = sub {
	my ($numerator, $denominator) = @_;
	return $denominator == 0 ? "n/a" : $numerator/$denominator*100;
};

for my $threshold(@thresholds)
{
	# counting convention shared with fmeasure/nist/epr: tp = $S at or above
	# the threshold, fp = $S below it, fn = $N at or above it, tn = $N below it
	my ($tp,$fp,$tn,$fn)=(0) x 4;
	for my $value(@values)
	{
		$value->[0]>=$threshold and $value->[1] eq $S and $tp++;
		$value->[0]>=$threshold and $value->[1] eq $N and $fn++;
		$value->[0]<$threshold and $value->[1] eq $S and $fp++;
		$value->[0]<$threshold and $value->[1] eq $N and $tn++;
	}
	my $cer = $percent->($fp+$fn, $nb_total);
	my $ca  = $percent->($tp+$tn, $nb_total);
	my $perr = $percent->($tn, $nb_err);
	my $pok = $percent->($tp, $nb_ok);

	print "\n --> $threshold :\n";
	print "Errors detected : $perr % - not detected : $fn\n";
	print "OK     detected : $pok % - not detected : $fp\n";

	print "NCE : ".(defined $NCE ? $NCE : "n/a")."\n";
	print "CER : $cer and CA : $ca \n";
	my ($fmeasure,$recall,$precision)=&fmeasure($tp,$fp,$tn,$fn);
	my $nist=&nist($tp,$fp,$tn,$fn);
	$output_type eq "fmeasure" and print "".(100*$fmeasure)."\n";
	$output_type eq "nist" and print "".(100*$nist)."\n";
	$output_type eq "epr" and print "".(100*abs($recall-$precision))."\n";
	$output_type eq "precision" and print "".(100*$precision)."\n";
	$output_type eq "recall" and print "".(100*$recall)."\n";
	$output ne "csv" and printf STDERR "perf:         nist %.2f fmeasure %.2f recall %.2f precision %.2f at $threshold\n",100*$nist,100*$fmeasure,100*$recall,100*$precision;
	$output eq "csv" and print STDERR join(",","at threshold",100*$nist,100*$fmeasure,100*$recall,100*$precision,100*abs($recall-$precision),$threshold)."\n";
}
# only the default and "decisions" modes go on to search the optimal threshold
$output_type ne "" and $output_type ne "decisions" and exit 0;

# compute fmeasure from true positives, false positives, true negatives and false negatives
#
# NOTE: the callers in this script count with a swapped convention: "fp"
# holds reference $S items below the threshold (misses) and "fn" holds $N
# items at or above it (false alarms).  Under that convention tp/(tp+fp) is
# the recall and tp/(tp+fn) the precision, which is what is returned.
# Returns (fmeasure, recall, precision), or (0,0,0) when a ratio is undefined.
sub fmeasure
{
	my($tp,$fp,$tn,$fn)=@_;
	$tp+$fp==0 and return (0,0,0);
	my $recall=$tp/($tp+$fp);
	$tp+$fn==0 and return (0,0,0);
	my $precision=$tp/($tp+$fn);
	$precision+$recall==0 and return (0,0,0);
	return (2*$recall*$precision/($recall+$precision),$recall,$precision);
}

# compute the nist error rate from the same parameters
#
# The callers count "fp" = reference $S items below the threshold (misses)
# and "fn" = $N items at or above it (false alarms), so the rate is
# (misses + false alarms) / (number of reference $S items) = (fn+fp)/(tp+fp).
# Returns 0 when there is no reference $S item (the rate is undefined).
sub nist
{
	my($tp,$fp,$tn,$fn)=@_;
	# guard the actual denominator: previously tp+fn was tested, which divided
	# by zero without $S items and reported 0 when every item was missed
	my $reference=$tp+$fp;
	$reference==0 and return 0;
	return ($fn+$fp)/$reference;
}

# compute the equal precision and recall (similar to EER) => the prior is reproduced
#
# Returns |recall - precision| with the callers' counting convention
# (fp = $S items below the threshold, fn = $N items at or above it), i.e.
# recall = tp/(tp+fp) and precision = tp/(tp+fn); 0 when either is undefined.
sub epr
{
	my($tp,$fp,$tn,$fn)=@_;
	$tp+$fp==0 and return 0;
	my $recall=$tp/($tp+$fp);
	$tp+$fn==0 and return 0;
	my $precision=$tp/($tp+$fn);
	return abs($recall-$precision);
}

# find the optimal threshold by crawling the sorted posterior array in ascending order
# parameters: total number of s, total number of n, array of indexes of candidates
#
# Reads the file-level @values (sorted by ascending score), $S, $N and $output.
# Each distinct candidate score is tried as a threshold; the items before it
# in @values are the ones classified below the threshold.  The max-fmeasure,
# min-nist and equal-precision/recall operating points are reported on STDERR
# (CSV rows when $output eq "csv"); a criterion that was never reached has
# all its fields set to -1.0.
# Returns (max-fmeasure threshold, min-nist threshold), or an empty list when
# either class is absent.
sub find_threshold
{
	my ($total_s, $total_n, @candidates) = @_;
	if($total_s == 0 or $total_n == 0)
	{
		warn("WARNING: cannot search a threshold without both \"$S\" and \"$N\" labels\n");
		return;
	}
	my %count=($S=>0,$N=>0);
	my (%fmeasure_scoring, %nist_scoring, %epr_scoring);
	my $threshold=$values[0]->[0]-1;
	for my $id(@candidates)
	{
		# update error counts and get maximum performance threshold; a threshold
		# is only tried at the first item of each run of equal scores, so %count
		# then holds exactly the items strictly below it
		if($threshold<$values[$id]->[0])
		{
			$threshold=$values[$id]->[0];
			# fp = $S items below the threshold (misses),
			# fn = $N items at or above it (false alarms)
			my $tp=$total_s-$count{$S};
			my $fp=$count{$S};
			my $tn=$count{$N};
			my $fn=$total_n-$count{$N};
			my ($fmeasure,$recall,$precision)=&fmeasure($tp,$fp,$tn,$fn);
			my $nist=&nist($tp,$fp,$tn,$fn);
			my $epr=&epr($tp,$fp,$tn,$fn);
			my %point=(
				threshold=>$threshold,
				fmeasure=>$fmeasure,
				recall=>$recall,
				precision=>$precision,
				nist=>$nist,
				epr=>$epr,
			);
			if(!defined $fmeasure_scoring{fmeasure} or $fmeasure>$fmeasure_scoring{fmeasure})
			{
				%fmeasure_scoring=%point;
			}
			if(!defined $nist_scoring{nist} or $nist<$nist_scoring{nist})
			{
				%nist_scoring=%point;
			}
			if((!defined $epr_scoring{epr} or $epr<$epr_scoring{epr}) and defined $recall and $recall!=0 and defined $precision and $precision!=0)
			{
				%epr_scoring=%point;
			}
		}
		$count{$values[$id]->[1]}++;
	}

	# default unreached criteria to -1.0 (a lexical loop variable: the
	# script-global $type must not be used here)
	my @fields=('nist','fmeasure','recall','precision','epr');
	for my $scoring(\%fmeasure_scoring,\%nist_scoring,\%epr_scoring)
	{
		for my $key(@fields,'threshold')
		{
			defined $scoring->{$key} or $scoring->{$key}=-1.0;
		}
	}

	# [csv row name, fixed-width text prefix, scores]
	my @reports=(
		["max fmeasure","max_fmeasure: ",\%fmeasure_scoring],
		["min nist",    "min_nist:     ",\%nist_scoring],
		["equal P R",   "equal_p_r:    ",\%epr_scoring],
	);
	for my $report(@reports)
	{
		my ($csv_name,$text_name,$scoring)=@$report;
		my @scaled=map{100*$scoring->{$_}}@fields;
		if($output eq "csv")
		{
			print STDERR join(",",$csv_name,@scaled,$scoring->{threshold})."\n";
		}
		else
		{
			# the threshold is passed through %s rather than interpolated into the format
			printf STDERR "%snist %.2f fmeasure %.2f recall %.2f precision %.2f epr %.2f at %s\n",
				$text_name,@scaled,$scoring->{threshold};
		}
	}
	return ($fmeasure_scoring{threshold},$nist_scoring{threshold});
}

# Count the reference items of each class, then try every position of the
# sorted @values as a candidate threshold.
my @candidates=(0 .. $#values);
my ($total_s, $total_n) = (0, 0);
for my $id(@candidates)
{
	$values[$id]->[1] eq $N and $total_n++;
	$values[$id]->[1] eq $S and $total_s++;
}
my ($threshold1,$threshold2)=&find_threshold($total_s,$total_n,@candidates);

# default mode: print "<max-fmeasure threshold> <min-nist threshold>"; the
# fields are left empty when find_threshold found no threshold
$output ne "csv" and $output_type eq "" and print join(" ", map { defined $_ ? $_ : "" } $threshold1, $threshold2)."\n";