commit 04d9e8221b
parent a599e78182
|
@@ -0,0 +1,8 @@
|
|||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Datasource local storage ignored files
|
||||
/../../../../:\nlp-ner-project\NER-IDCNN-action\.idea/dataSources/
|
||||
/dataSources.local.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
|
@@ -0,0 +1,11 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
|
||||
</component>
|
||||
</module>
|
|
@@ -0,0 +1,49 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<serverData>
|
||||
<paths name="chenliming@10.191.78.227:22">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="chenliming@172.30.236.253:22">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="chenliming@172.30.236.253:22 (1)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="chenliming@172.30.236.253:22 (2)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="chenliming@172.30.236.253:22 (3)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="chenliming@172.30.236.253:22 (4)">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
</serverData>
|
||||
</component>
|
||||
</project>
|
|
@@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Encoding" defaultCharsetForPropertiesFiles="UTF-8">
|
||||
<file url="PROJECT" charset="UTF-8" />
|
||||
</component>
|
||||
</project>
|
|
@@ -0,0 +1,6 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||
</profile>
|
||||
</component>
|
|
@@ -0,0 +1,4 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6" project-jdk-type="Python SDK" />
|
||||
</project>
|
|
@@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/NER-IDCNN-action.iml" filepath="$PROJECT_DIR$/.idea/NER-IDCNN-action.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@@ -0,0 +1,17 @@
|
|||
{
|
||||
"num_words": 4412,
|
||||
"word_dim": 100,
|
||||
"num_tags": 13,
|
||||
"seg_dim": 20,
|
||||
"lstm_dim": 100,
|
||||
"batch_size": 120,
|
||||
"optimizer": "adam",
|
||||
"emb_file": "data\\wiki_100.utf8",
|
||||
"clip": 5.0,
|
||||
"dropout_keep": 0.5,
|
||||
"lr": 0.001,
|
||||
"tag_schema": "BIOES",
|
||||
"pre_emb": true,
|
||||
"model_type": "idcnn",
|
||||
"is_train": true
|
||||
}
|
|
@@ -0,0 +1,315 @@
|
|||
#!/usr/bin/perl -w
|
||||
# conlleval: evaluate result of processing CoNLL-2000 shared task
|
||||
# usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
|
||||
# README: http://cnts.uia.ac.be/conll2000/chunking/output.html
|
||||
# options: l: generate LaTeX output for tables like in
|
||||
# http://cnts.uia.ac.be/conll2003/ner/example.tex
|
||||
# r: accept raw result tags (without B- and I- prefix;
|
||||
# assumes one word per chunk)
|
||||
# d: alternative delimiter tag (default is single space)
|
||||
# o: alternative outside tag (default is O)
|
||||
# note: the file should contain lines with items separated
|
||||
# by $delimiter characters (default space). The final
|
||||
# two items should contain the correct tag and the
|
||||
# guessed tag in that order. Sentences should be
|
||||
# separated from each other by empty lines or lines
|
||||
# with $boundary fields (default -X-).
|
||||
# url: http://lcg-www.uia.ac.be/conll2000/chunking/
|
||||
# started: 1998-09-25
|
||||
# version: 2004-01-26
|
||||
# author: Erik Tjong Kim Sang <erikt@uia.ua.ac.be>
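# illustrative input line (token, optional features, then gold tag and guessed tag last):
#   Leicestershire NNP B-NP I-NP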
|
||||
|
||||
use strict;
|
||||
|
||||
my $false = 0;
|
||||
my $true = 42;
|
||||
|
||||
my $boundary = "-X-"; # sentence boundary
|
||||
my $correct; # current corpus chunk tag (I,O,B)
|
||||
my $correctChunk = 0; # number of correctly identified chunks
|
||||
my $correctTags = 0; # number of correct chunk tags
|
||||
my $correctType; # type of current corpus chunk tag (NP,VP,etc.)
|
||||
my $delimiter = " "; # field delimiter
|
||||
my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979)
|
||||
my $firstItem; # first feature (for sentence boundary checks)
|
||||
my $foundCorrect = 0; # number of chunks in corpus
|
||||
my $foundGuessed = 0; # number of identified chunks
|
||||
my $guessed; # current guessed chunk tag
|
||||
my $guessedType; # type of current guessed chunk tag
|
||||
my $i; # miscellaneous counter
|
||||
my $inCorrect = $false; # currently processed chunk is correct until now
|
||||
my $lastCorrect = "O"; # previous chunk tag in corpus
|
||||
my $latex = 0; # generate LaTeX formatted output
|
||||
my $lastCorrectType = ""; # type of previously identified chunk tag
|
||||
my $lastGuessed = "O"; # previously identified chunk tag
|
||||
my $lastGuessedType = ""; # type of previous chunk tag in corpus
|
||||
my $lastType; # temporary storage for detecting duplicates
|
||||
my $line; # line
|
||||
my $nbrOfFeatures = -1; # number of features per line
|
||||
my $precision = 0.0; # precision score
|
||||
my $oTag = "O"; # outside tag, default O
|
||||
my $raw = 0; # raw input: add B to every token
|
||||
my $recall = 0.0; # recall score
|
||||
my $tokenCounter = 0; # token counter (ignores sentence breaks)
|
||||
|
||||
my %correctChunk = (); # number of correctly identified chunks per type
|
||||
my %foundCorrect = (); # number of chunks in corpus per type
|
||||
my %foundGuessed = (); # number of identified chunks per type
|
||||
|
||||
my @features; # features on line
|
||||
my @sortedTypes; # sorted list of chunk type names
|
||||
|
||||
# sanity check
|
||||
while (@ARGV and $ARGV[0] =~ /^-/) {
|
||||
if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); }
|
||||
elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); }
|
||||
elsif ($ARGV[0] eq "-d") {
|
||||
shift(@ARGV);
|
||||
if (not defined $ARGV[0]) {
|
||||
die "conlleval: -d requires delimiter character";
|
||||
}
|
||||
$delimiter = shift(@ARGV);
|
||||
} elsif ($ARGV[0] eq "-o") {
|
||||
shift(@ARGV);
|
||||
if (not defined $ARGV[0]) {
|
||||
die "conlleval: -o requires delimiter character";
|
||||
}
|
||||
$oTag = shift(@ARGV);
|
||||
} else { die "conlleval: unknown argument $ARGV[0]\n"; }
|
||||
}
|
||||
if (@ARGV) { die "conlleval: unexpected command line argument\n"; }
|
||||
# process input
|
||||
while (<STDIN>) {
|
||||
chomp($line = $_);
|
||||
@features = split(/$delimiter/,$line);
|
||||
if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; }
|
||||
elsif ($nbrOfFeatures != $#features and @features != 0) {
|
||||
printf STDERR "unexpected number of features: %d (%d)\n",
|
||||
$#features+1,$nbrOfFeatures+1;
|
||||
exit(1);
|
||||
}
|
||||
if (@features == 0 or
|
||||
$features[0] eq $boundary) { @features = ($boundary,"O","O"); }
|
||||
if (@features < 2) {
|
||||
die "conlleval: unexpected number of features in line $line\n";
|
||||
}
|
||||
if ($raw) {
|
||||
if ($features[$#features] eq $oTag) { $features[$#features] = "O"; }
|
||||
if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; }
|
||||
if ($features[$#features] ne "O") {
|
||||
$features[$#features] = "B-$features[$#features]";
|
||||
}
|
||||
if ($features[$#features-1] ne "O") {
|
||||
$features[$#features-1] = "B-$features[$#features-1]";
|
||||
}
|
||||
}
|
||||
# 20040126 ET code which allows hyphens in the types
|
||||
if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
|
||||
$guessed = $1;
|
||||
$guessedType = $2;
|
||||
} else {
|
||||
$guessed = $features[$#features];
|
||||
$guessedType = "";
|
||||
}
|
||||
pop(@features);
|
||||
if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
|
||||
$correct = $1;
|
||||
$correctType = $2;
|
||||
} else {
|
||||
$correct = $features[$#features];
|
||||
$correctType = "";
|
||||
}
|
||||
pop(@features);
|
||||
# ($guessed,$guessedType) = split(/-/,pop(@features));
|
||||
# ($correct,$correctType) = split(/-/,pop(@features));
|
||||
$guessedType = $guessedType ? $guessedType : "";
|
||||
$correctType = $correctType ? $correctType : "";
|
||||
$firstItem = shift(@features);
|
||||
|
||||
# 1999-06-26 sentence breaks should always be counted as out of chunk
|
||||
if ( $firstItem eq $boundary ) { $guessed = "O"; }
|
||||
|
||||
if ($inCorrect) {
|
||||
if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
|
||||
&endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
|
||||
$lastGuessedType eq $lastCorrectType) {
|
||||
$inCorrect=$false;
|
||||
$correctChunk++;
|
||||
$correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
|
||||
$correctChunk{$lastCorrectType}+1 : 1;
|
||||
} elsif (
|
||||
&endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) !=
|
||||
&endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or
|
||||
$guessedType ne $correctType ) {
|
||||
$inCorrect=$false;
|
||||
}
|
||||
}
|
||||
|
||||
if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
|
||||
&startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
|
||||
$guessedType eq $correctType) { $inCorrect = $true; }
|
||||
|
||||
if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) {
|
||||
$foundCorrect++;
|
||||
$foundCorrect{$correctType} = $foundCorrect{$correctType} ?
|
||||
$foundCorrect{$correctType}+1 : 1;
|
||||
}
|
||||
if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) {
|
||||
$foundGuessed++;
|
||||
$foundGuessed{$guessedType} = $foundGuessed{$guessedType} ?
|
||||
$foundGuessed{$guessedType}+1 : 1;
|
||||
}
|
||||
if ( $firstItem ne $boundary ) {
|
||||
if ( $correct eq $guessed and $guessedType eq $correctType ) {
|
||||
$correctTags++;
|
||||
}
|
||||
$tokenCounter++;
|
||||
}
|
||||
|
||||
$lastGuessed = $guessed;
|
||||
$lastCorrect = $correct;
|
||||
$lastGuessedType = $guessedType;
|
||||
$lastCorrectType = $correctType;
|
||||
}
|
||||
if ($inCorrect) {
|
||||
$correctChunk++;
|
||||
$correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
|
||||
$correctChunk{$lastCorrectType}+1 : 1;
|
||||
}
|
||||
|
||||
if (not $latex) {
|
||||
# compute overall precision, recall and FB1 (default values are 0.0)
|
||||
$precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
|
||||
$recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
|
||||
$FB1 = 2*$precision*$recall/($precision+$recall)
|
||||
if ($precision+$recall > 0);
|
||||
|
||||
# print overall performance
|
||||
printf "processed $tokenCounter tokens with $foundCorrect phrases; ";
|
||||
printf "found: $foundGuessed phrases; correct: $correctChunk.\n";
|
||||
if ($tokenCounter>0) {
|
||||
printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter;
|
||||
printf "precision: %6.2f%%; ",$precision;
|
||||
printf "recall: %6.2f%%; ",$recall;
|
||||
printf "FB1: %6.2f\n",$FB1;
|
||||
}
|
||||
}
|
||||
|
||||
# sort chunk type names
|
||||
undef($lastType);
|
||||
@sortedTypes = ();
|
||||
foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) {
|
||||
if (not($lastType) or $lastType ne $i) {
|
||||
push(@sortedTypes,($i));
|
||||
}
|
||||
$lastType = $i;
|
||||
}
|
||||
# print performance per chunk type
|
||||
if (not $latex) {
|
||||
for $i (@sortedTypes) {
|
||||
$correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
|
||||
if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; }
|
||||
else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
|
||||
if (not($foundCorrect{$i})) { $recall = 0.0; }
|
||||
else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
|
||||
if ($precision+$recall == 0.0) { $FB1 = 0.0; }
|
||||
else { $FB1 = 2*$precision*$recall/($precision+$recall); }
|
||||
printf "%17s: ",$i;
|
||||
printf "precision: %6.2f%%; ",$precision;
|
||||
printf "recall: %6.2f%%; ",$recall;
|
||||
printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i};
|
||||
}
|
||||
} else {
|
||||
print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline";
|
||||
for $i (@sortedTypes) {
|
||||
$correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
|
||||
if (not($foundGuessed{$i})) { $precision = 0.0; }
|
||||
else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
|
||||
if (not($foundCorrect{$i})) { $recall = 0.0; }
|
||||
else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
|
||||
if ($precision+$recall == 0.0) { $FB1 = 0.0; }
|
||||
else { $FB1 = 2*$precision*$recall/($precision+$recall); }
|
||||
printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\",
|
||||
$i,$precision,$recall,$FB1;
|
||||
}
|
||||
print "\\hline\n";
|
||||
$precision = 0.0;
|
||||
$recall = 0;
|
||||
$FB1 = 0.0;
|
||||
$precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
|
||||
$recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
|
||||
$FB1 = 2*$precision*$recall/($precision+$recall)
|
||||
if ($precision+$recall > 0);
|
||||
printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n",
|
||||
$precision,$recall,$FB1;
|
||||
}
|
||||
|
||||
exit 0;
|
||||
|
||||
# endOfChunk: checks if a chunk ended between the previous and current word
|
||||
# arguments: previous and current chunk tags, previous and current types
|
||||
# note: this code is capable of handling other chunk representations
|
||||
# than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
|
||||
# Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
|
||||
|
||||
sub endOfChunk {
|
||||
my $prevTag = shift(@_);
|
||||
my $tag = shift(@_);
|
||||
my $prevType = shift(@_);
|
||||
my $type = shift(@_);
|
||||
my $chunkEnd = $false;
|
||||
|
||||
if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; }
|
||||
if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; }
|
||||
if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; }
|
||||
if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
|
||||
|
||||
if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; }
|
||||
if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; }
|
||||
if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; }
|
||||
if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
|
||||
|
||||
if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) {
|
||||
$chunkEnd = $true;
|
||||
}
|
||||
|
||||
# corrected 1998-12-22: these chunks are assumed to have length 1
|
||||
if ( $prevTag eq "]" ) { $chunkEnd = $true; }
|
||||
if ( $prevTag eq "[" ) { $chunkEnd = $true; }
|
||||
|
||||
return($chunkEnd);
|
||||
}
|
||||
|
||||
# startOfChunk: checks if a chunk started between the previous and current word
|
||||
# arguments: previous and current chunk tags, previous and current types
|
||||
# note: this code is capable of handling other chunk representations
|
||||
# than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
|
||||
# Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
|
||||
|
||||
sub startOfChunk {
|
||||
my $prevTag = shift(@_);
|
||||
my $tag = shift(@_);
|
||||
my $prevType = shift(@_);
|
||||
my $type = shift(@_);
|
||||
my $chunkStart = $false;
|
||||
|
||||
if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; }
|
||||
if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; }
|
||||
if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; }
|
||||
if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
|
||||
|
||||
if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; }
|
||||
if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; }
|
||||
if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; }
|
||||
if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
|
||||
|
||||
if ($tag ne "O" and $tag ne "." and $prevType ne $type) {
|
||||
$chunkStart = $true;
|
||||
}
|
||||
|
||||
# corrected 1998-12-22: these chunks are assumed to have length 1
|
||||
if ( $tag eq "[" ) { $chunkStart = $true; }
|
||||
if ( $tag eq "]" ) { $chunkStart = $true; }
|
||||
|
||||
return($chunkStart);
|
||||
}
|
|
@@ -0,0 +1,297 @@
|
|||
# Python version of the evaluation script from CoNLL'00-
|
||||
# Originates from: https://github.com/spyysalo/conlleval.py
|
||||
|
||||
|
||||
# Intentional differences:
|
||||
# - accept any space as delimiter by default
|
||||
# - optional file argument (default STDIN)
|
||||
# - option to set boundary (-b argument)
|
||||
# - LaTeX output (-l argument) not supported
|
||||
# - raw tags (-r argument) not supported
|
||||
|
||||
import sys
|
||||
import re
|
||||
import codecs
|
||||
from collections import defaultdict, namedtuple
|
||||
|
||||
ANY_SPACE = '<SPACE>'
|
||||
|
||||
|
||||
class FormatError(Exception):
|
||||
pass
|
||||
|
||||
Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore')
|
||||
|
||||
|
||||
class EvalCounts(object):
|
||||
def __init__(self):
|
||||
self.correct_chunk = 0 # number of correctly identified chunks
|
||||
self.correct_tags = 0 # number of correct chunk tags
|
||||
self.found_correct = 0 # number of chunks in corpus
|
||||
self.found_guessed = 0 # number of identified chunks
|
||||
self.token_counter = 0 # token counter (ignores sentence breaks)
|
||||
|
||||
# counts by type
|
||||
self.t_correct_chunk = defaultdict(int)
|
||||
self.t_found_correct = defaultdict(int)
|
||||
self.t_found_guessed = defaultdict(int)
|
||||
|
||||
|
||||
def parse_args(argv):
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(
|
||||
description='evaluate tagging results using CoNLL criteria',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
arg = parser.add_argument
|
||||
arg('-b', '--boundary', metavar='STR', default='-X-',
|
||||
help='sentence boundary')
|
||||
arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE,
|
||||
help='character delimiting items in input')
|
||||
arg('-o', '--otag', metavar='CHAR', default='O',
|
||||
help='alternative outside tag')
|
||||
arg('file', nargs='?', default=None)
|
||||
return parser.parse_args(argv)
|
||||
|
||||
|
||||
def parse_tag(t):
|
||||
m = re.match(r'^([^-]*)-(.*)$', t)
|
||||
return m.groups() if m else (t, '')
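# e.g. parse_tag('B-PER') -> ('B', 'PER'); parse_tag('O') -> ('O', '')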
|
||||
|
||||
|
||||
def evaluate(iterable, options=None):
|
||||
if options is None:
|
||||
options = parse_args([]) # use defaults
|
||||
|
||||
counts = EvalCounts()
|
||||
num_features = None # number of features per line
|
||||
in_correct = False # currently processed chunk is correct until now
|
||||
last_correct = 'O' # previous chunk tag in corpus
|
||||
last_correct_type = '' # type of previously identified chunk tag
|
||||
last_guessed = 'O' # previously identified chunk tag
|
||||
last_guessed_type = '' # type of previous chunk tag in corpus
|
||||
|
||||
for line in iterable:
|
||||
line = line.rstrip('\r\n')
|
||||
|
||||
if options.delimiter == ANY_SPACE:
|
||||
features = line.split()
|
||||
else:
|
||||
features = line.split(options.delimiter)
|
||||
|
||||
if num_features is None:
|
||||
num_features = len(features)
|
||||
elif num_features != len(features) and len(features) != 0:
|
||||
raise FormatError('unexpected number of features: %d (%d)' %
|
||||
(len(features), num_features))
|
||||
|
||||
if len(features) == 0 or features[0] == options.boundary:
|
||||
features = [options.boundary, 'O', 'O']
|
||||
if len(features) < 3:
|
||||
raise FormatError('unexpected number of features in line %s' % line)
|
||||
|
||||
guessed, guessed_type = parse_tag(features.pop())
|
||||
correct, correct_type = parse_tag(features.pop())
|
||||
first_item = features.pop(0)
|
||||
|
||||
if first_item == options.boundary:
|
||||
guessed = 'O'
|
||||
|
||||
end_correct = end_of_chunk(last_correct, correct,
|
||||
last_correct_type, correct_type)
|
||||
end_guessed = end_of_chunk(last_guessed, guessed,
|
||||
last_guessed_type, guessed_type)
|
||||
start_correct = start_of_chunk(last_correct, correct,
|
||||
last_correct_type, correct_type)
|
||||
start_guessed = start_of_chunk(last_guessed, guessed,
|
||||
last_guessed_type, guessed_type)
|
||||
|
||||
if in_correct:
|
||||
if (end_correct and end_guessed and
|
||||
last_guessed_type == last_correct_type):
|
||||
in_correct = False
|
||||
counts.correct_chunk += 1
|
||||
counts.t_correct_chunk[last_correct_type] += 1
|
||||
elif (end_correct != end_guessed or guessed_type != correct_type):
|
||||
in_correct = False
|
||||
|
||||
if start_correct and start_guessed and guessed_type == correct_type:
|
||||
in_correct = True
|
||||
|
||||
if start_correct:
|
||||
counts.found_correct += 1
|
||||
counts.t_found_correct[correct_type] += 1
|
||||
if start_guessed:
|
||||
counts.found_guessed += 1
|
||||
counts.t_found_guessed[guessed_type] += 1
|
||||
if first_item != options.boundary:
|
||||
if correct == guessed and guessed_type == correct_type:
|
||||
counts.correct_tags += 1
|
||||
counts.token_counter += 1
|
||||
|
||||
last_guessed = guessed
|
||||
last_correct = correct
|
||||
last_guessed_type = guessed_type
|
||||
last_correct_type = correct_type
|
||||
|
||||
if in_correct:
|
||||
counts.correct_chunk += 1
|
||||
counts.t_correct_chunk[last_correct_type] += 1
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
def uniq(iterable):
|
||||
seen = set()
|
||||
return [i for i in iterable if not (i in seen or seen.add(i))]
|
||||
|
||||
|
||||
def calculate_metrics(correct, guessed, total):
|
||||
tp, fp, fn = correct, guessed-correct, total-correct
|
||||
p = 0 if tp + fp == 0 else 1.*tp / (tp + fp)
|
||||
r = 0 if tp + fn == 0 else 1.*tp / (tp + fn)
|
||||
f = 0 if p + r == 0 else 2 * p * r / (p + r)
|
||||
return Metrics(tp, fp, fn, p, r, f)
|
||||
|
||||
|
||||
def metrics(counts):
|
||||
c = counts
|
||||
overall = calculate_metrics(
|
||||
c.correct_chunk, c.found_guessed, c.found_correct
|
||||
)
|
||||
by_type = {}
|
||||
for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)):
|
||||
by_type[t] = calculate_metrics(
|
||||
c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t]
|
||||
)
|
||||
return overall, by_type
|
||||
|
||||
|
||||
def report(counts, out=None):
|
||||
if out is None:
|
||||
out = sys.stdout
|
||||
|
||||
overall, by_type = metrics(counts)
|
||||
|
||||
c = counts
|
||||
out.write('processed %d tokens with %d phrases; ' %
|
||||
(c.token_counter, c.found_correct))
|
||||
out.write('found: %d phrases; correct: %d.\n' %
|
||||
(c.found_guessed, c.correct_chunk))
|
||||
|
||||
if c.token_counter > 0:
|
||||
out.write('accuracy: %6.2f%%; ' %
|
||||
(100.*c.correct_tags/c.token_counter))
|
||||
out.write('precision: %6.2f%%; ' % (100.*overall.prec))
|
||||
out.write('recall: %6.2f%%; ' % (100.*overall.rec))
|
||||
out.write('FB1: %6.2f\n' % (100.*overall.fscore))
|
||||
|
||||
for i, m in sorted(by_type.items()):
|
||||
out.write('%17s: ' % i)
|
||||
out.write('precision: %6.2f%%; ' % (100.*m.prec))
|
||||
out.write('recall: %6.2f%%; ' % (100.*m.rec))
|
||||
out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
|
||||
|
||||
|
||||
def report_notprint(counts, out=None):
|
||||
if out is None:
|
||||
out = sys.stdout
|
||||
|
||||
overall, by_type = metrics(counts)
|
||||
|
||||
c = counts
|
||||
final_report = []
|
||||
line = []
|
||||
line.append('processed %d tokens with %d phrases; ' %
|
||||
(c.token_counter, c.found_correct))
|
||||
line.append('found: %d phrases; correct: %d.\n' %
|
||||
(c.found_guessed, c.correct_chunk))
|
||||
final_report.append("".join(line))
|
||||
|
||||
if c.token_counter > 0:
|
||||
line = []
|
||||
line.append('accuracy: %6.2f%%; ' %
|
||||
(100.*c.correct_tags/c.token_counter))
|
||||
line.append('precision: %6.2f%%; ' % (100.*overall.prec))
|
||||
line.append('recall: %6.2f%%; ' % (100.*overall.rec))
|
||||
line.append('FB1: %6.2f\n' % (100.*overall.fscore))
|
||||
final_report.append("".join(line))
|
||||
|
||||
for i, m in sorted(by_type.items()):
|
||||
line = []
|
||||
line.append('%17s: ' % i)
|
||||
line.append('precision: %6.2f%%; ' % (100.*m.prec))
|
||||
line.append('recall: %6.2f%%; ' % (100.*m.rec))
|
||||
line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i]))
|
||||
final_report.append("".join(line))
|
||||
return final_report
|
||||
|
||||
|
||||
def end_of_chunk(prev_tag, tag, prev_type, type_):
|
||||
# check if a chunk ended between the previous and current word
|
||||
# arguments: previous and current chunk tags, previous and current types
|
||||
chunk_end = False
|
||||
|
||||
if prev_tag == 'E': chunk_end = True
|
||||
if prev_tag == 'S': chunk_end = True
|
||||
|
||||
if prev_tag == 'B' and tag == 'B': chunk_end = True
|
||||
if prev_tag == 'B' and tag == 'S': chunk_end = True
|
||||
if prev_tag == 'B' and tag == 'O': chunk_end = True
|
||||
if prev_tag == 'I' and tag == 'B': chunk_end = True
|
||||
if prev_tag == 'I' and tag == 'S': chunk_end = True
|
||||
if prev_tag == 'I' and tag == 'O': chunk_end = True
|
||||
|
||||
if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
|
||||
chunk_end = True
|
||||
|
||||
# these chunks are assumed to have length 1
|
||||
if prev_tag == ']': chunk_end = True
|
||||
if prev_tag == '[': chunk_end = True
|
||||
|
||||
return chunk_end
|
||||
|
||||
|
||||
def start_of_chunk(prev_tag, tag, prev_type, type_):
|
||||
# check if a chunk started between the previous and current word
|
||||
# arguments: previous and current chunk tags, previous and current types
|
||||
chunk_start = False
|
||||
|
||||
if tag == 'B': chunk_start = True
|
||||
if tag == 'S': chunk_start = True
|
||||
|
||||
if prev_tag == 'E' and tag == 'E': chunk_start = True
|
||||
if prev_tag == 'E' and tag == 'I': chunk_start = True
|
||||
if prev_tag == 'S' and tag == 'E': chunk_start = True
|
||||
if prev_tag == 'S' and tag == 'I': chunk_start = True
|
||||
if prev_tag == 'O' and tag == 'E': chunk_start = True
|
||||
if prev_tag == 'O' and tag == 'I': chunk_start = True
|
||||
|
||||
if tag != 'O' and tag != '.' and prev_type != type_:
|
||||
chunk_start = True
|
||||
|
||||
# these chunks are assumed to have length 1
|
||||
if tag == '[': chunk_start = True
|
||||
if tag == ']': chunk_start = True
|
||||
|
||||
return chunk_start
|
||||
|
||||
|
||||
def return_report(input_file):
|
||||
with codecs.open(input_file, "r", "utf8") as f:
|
||||
counts = evaluate(f)
|
||||
return report_notprint(counts)
|
||||
|
||||
|
||||
def main(argv):
|
||||
args = parse_args(argv[1:])
|
||||
|
||||
if args.file is None:
|
||||
counts = evaluate(sys.stdin, args)
|
||||
else:
|
||||
with open(args.file) as f:
|
||||
counts = evaluate(f, args)
|
||||
report(counts)
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main(sys.argv))
|
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
|
@@ -0,0 +1,123 @@
|
|||
|
||||
import codecs
|
||||
import data_utils
|
||||
def load_sentences(path):
|
||||
"""
|
||||
Load the dataset. Each line contains at least one character and one tag.
|
||||
Sentences are separated by blank lines.
|
||||
Returns the list of sentences.
|
||||
:param path:
|
||||
:return:
|
||||
"""
|
||||
# holds all sentences
|
||||
sentences = []
|
||||
# temporarily holds the current sentence
|
||||
sentence = []
|
||||
for line in codecs.open(path, 'r', encoding='utf-8'):
|
||||
# strip whitespace from both ends
|
||||
line = line.strip()
|
||||
# an empty line marks the boundary between two sentences
|
||||
if not line:
|
||||
if len(sentence) > 0:
|
||||
sentences.append(sentence)
|
||||
# reset sentence: the current sentence is finished
|
||||
sentence = []
|
||||
else:
|
||||
if line[0] == " ":
|
||||
continue
|
||||
else:
|
||||
word = line.split()
|
||||
assert len(word) >= 2
|
||||
sentence.append(word)
|
||||
# after the loop, make sure the last sentence is not lost
|
||||
if len(sentence) > 0:
|
||||
sentences.append(sentence)
|
||||
return sentences
|
||||
|
||||
def update_tag_scheme(sentences, tag_scheme):
|
||||
"""
|
||||
Convert the tags of each sentence to the given tagging scheme.
|
||||
:param sentences:
|
||||
:param tag_scheme:
|
||||
:return:
|
||||
"""
|
||||
for i, s in enumerate(sentences):
|
||||
tags = [w[-1] for w in s]
|
||||
if not data_utils.check_bio(tags):
|
||||
s_str = "\n".join(" ".join(w) for w in s)
|
||||
raise Exception("输入的句子应为BIO编码,请检查输入句子%i:\n%s" % (i, s_str))
|
||||
|
||||
if tag_scheme == "BIO":
|
||||
for word, new_tag in zip(s, tags):
|
||||
word[-1] = new_tag
|
||||
|
||||
if tag_scheme == "BIOES":
|
||||
new_tags = data_utils.bio_to_bioes(tags)
|
||||
for word, new_tag in zip(s, new_tags):
|
||||
word[-1] = new_tag
|
||||
else:
|
||||
raise Exception("非法目标编码")
|
||||
|
||||
def word_mapping(sentences):
|
||||
"""
|
||||
Build the character dictionary.
|
||||
:param sentences:
|
||||
:return:
|
||||
"""
|
||||
word_list = [[x[0] for x in s] for s in sentences]
|
||||
dico = data_utils.create_dico(word_list)
|
||||
dico['<PAD>'] = 10000001
|
||||
dico['<UNK>'] = 10000000
|
||||
word_to_id, id_to_word = data_utils.create_mapping(dico)
|
||||
return dico, word_to_id, id_to_word
|
||||
|
||||
def tag_mapping(sentences):
|
||||
"""
|
||||
Build the tag dictionary.
|
||||
:param sentences:
|
||||
:return:
|
||||
"""
|
||||
tag_list = [[x[1] for x in s] for s in sentences]
|
||||
dico = data_utils.create_dico(tag_list)
|
||||
tag_to_id, id_to_tag = data_utils.create_mapping(dico)
|
||||
return dico, tag_to_id, id_to_tag
|
||||
|
||||
def prepare_dataset(sentences, word_to_id, tag_to_id, train=True):
|
||||
"""
|
||||
Preprocess the data; each entry of the returned list contains:
|
||||
-word_list
|
||||
-word_id_list
|
||||
-seg feature list (word segmentation features)
|
||||
-tag_id_list
|
||||
:param sentences:
|
||||
:param word_to_id:
|
||||
:param tag_to_id:
|
||||
:param train:
|
||||
:return:
|
||||
"""
|
||||
none_index = tag_to_id['O']
|
||||
|
||||
data = []
|
||||
for s in sentences:
|
||||
word_list = [ w[0] for w in s]
|
||||
word_id_list = [word_to_id[w if w in word_to_id else '<UNK>'] for w in word_list]
|
||||
segs = data_utils.get_seg_features("".join(word_list))
|
||||
if train:
|
||||
tag_id_list = [tag_to_id[w[-1]] for w in s]
|
||||
else:
|
||||
tag_id_list = [none_index for w in s]
|
||||
data.append([word_list, word_id_list, segs,tag_id_list])
|
||||
|
||||
return data
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
path = "data/ner.dev"
|
||||
sentences = load_sentences(path)
|
||||
update_tag_scheme(sentences,"BIOES")
|
||||
_, word_to_id, id_to_word = word_mapping(sentences)
|
||||
_, tag_to_id, id_to_tag = tag_mapping(sentences)
|
||||
dev_data = prepare_dataset(sentences, word_to_id, tag_to_id)
|
||||
data_utils.BatchManager(dev_data, 120)
|
|
@@ -0,0 +1,260 @@
|
|||
|
||||
import jieba
|
||||
import math
|
||||
import random
|
||||
import codecs
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
def check_bio(tags):
|
||||
"""
|
||||
Check whether the input tags use BIO encoding.
|
||||
If they are not valid BIO,
|
||||
the error falls into one of these cases:
|
||||
(1) the tag is not a BIO tag
|
||||
(2) the first tag is I
|
||||
(3) the current tag is not B and the previous tag is not O
|
||||
:param tags:
|
||||
:return:
|
||||
"""
|
||||
for i, tag in enumerate(tags):
|
||||
if tag == 'O':
|
||||
continue
|
||||
tag_list = tag.split("-")
|
||||
if len(tag_list) != 2 or tag_list[0] not in set(['B','I']):
|
||||
# invalid tag
|
||||
return False
|
||||
if tag_list[0] == 'B':
|
||||
continue
|
||||
elif i == 0 or tags[i-1] == 'O':
|
||||
# if the I tag is at the first position or follows an O, convert it to B
|
||||
tags[i] = 'B' + tag[1:]
|
||||
elif tags[i-1][1:] == tag[1:]:
|
||||
# if the entity type is the same as the previous tag's type, keep it
|
||||
continue
|
||||
else:
|
||||
# if the types differ, start a new entity with a B tag
|
||||
tags[i] = 'B' + tag[1:]
|
||||
return True
|
||||
|
||||
def bio_to_bioes(tags):
|
||||
"""
|
||||
Convert BIO tags to BIOES tags.
|
||||
Return the new tag list.
|
||||
:param tags:
|
||||
:return:
|
||||
"""
|
||||
new_tags = []
|
||||
for i, tag in enumerate(tags):
|
||||
if tag == 'O':
|
||||
# keep O tags unchanged
|
||||
new_tags.append(tag)
|
||||
elif tag.split('-')[0] == 'B':
|
||||
# tag starts with B: check what follows
|
||||
# if this is not the last tag and the next tag is I, the entity spans several characters
|
||||
if (i+1) < len(tags) and tags[i+1].split('-')[0] == 'I':
|
||||
# keep the B tag
|
||||
new_tags.append(tag)
|
||||
else:
|
||||
# otherwise it is a single-character entity, so replace B- with S-
|
||||
new_tags.append(tag.replace('B-','S-'))
|
||||
elif tag.split('-')[0] == 'I':
|
||||
# tag starts with I: check what follows
|
||||
# if this is not the last tag and the next tag is I, the entity continues
|
||||
if (i+1) < len(tags) and tags[i+1].split('-')[0] == 'I':
|
||||
# keep the I tag
|
||||
new_tags.append(tag)
|
||||
else:
|
||||
# otherwise this I- tag ends the entity, so replace I- with E-
|
||||
new_tags.append(tag.replace('I-', 'E-'))
|
||||
|
||||
else:
|
||||
raise Exception('Invalid tag encoding')
|
||||
return new_tags
|
||||
|
||||
def bioes_to_bio(tags):
|
||||
"""
|
||||
BIOES->BIO
|
||||
:param tags:
|
||||
:return:
|
||||
"""
|
||||
new_tags = []
|
||||
for i, tag in enumerate(tags):
|
||||
if tag.split('-')[0] == "B":
|
||||
new_tags.append(tag)
|
||||
elif tag.split('-')[0] == "I":
|
||||
new_tags.append(tag)
|
||||
elif tag.split('-')[0] == "S":
|
||||
new_tags.append(tag.replace('S-','B-'))
|
||||
elif tag.split('-')[0] == "E":
|
||||
new_tags.append(tag.replace('E-','I-'))
|
||||
elif tag.split('-')[0] == "O":
|
||||
new_tags.append(tag)
|
||||
else:
|
||||
raise Exception('Invalid tag encoding format')
|
||||
return new_tags
|
||||
|
||||
|
||||
def create_dico(item_list):
|
||||
"""
|
||||
Count how often each item occurs across all the item lists in item_list.
|
||||
Returns a dict mapping item -> count.
|
||||
:param item_list:
|
||||
:return:
|
||||
"""
|
||||
assert type(item_list) is list
|
||||
dico = {}
|
||||
for items in item_list:
|
||||
for item in items:
|
||||
if item not in dico:
|
||||
dico[item] = 1
|
||||
else:
|
||||
dico[item] += 1
|
||||
return dico
|
||||
|
||||
def create_mapping(dico):
|
||||
"""
|
||||
Create the item_to_id and id_to_item mappings.
|
||||
Items are ordered by frequency (descending), then lexicographically.
|
||||
:param dico:
|
||||
:return:
|
||||
"""
|
||||
sorted_items = sorted(dico.items(), key=lambda x:(-x[1],x[0]))
|
||||
id_to_item = {i:v[0] for i,v in enumerate(sorted_items)}
|
||||
item_to_id = {v:k for k, v in id_to_item.items()}
|
||||
return item_to_id, id_to_item
|
||||
|
||||
def get_seg_features(words):
|
||||
"""
|
||||
Segment the input text with jieba.
|
||||
Use a BIES-like encoding: 0 = single-character word, 1 = word start, 2 = word middle, 3 = word end.
|
||||
:param words:
|
||||
:return:
|
||||
"""
|
||||
seg_features = []
|
||||
|
||||
word_list = list(jieba.cut(words))
|
||||
|
||||
for word in word_list:
|
||||
if len(word) == 1:
|
||||
seg_features.append(0)
|
||||
else:
|
||||
temp = [2] * len(word)
|
||||
temp[0] = 1
|
||||
temp[-1] = 3
|
||||
seg_features.extend(temp)
|
||||
return seg_features
|
||||
|
||||
def load_word2vec(emb_file, id_to_word, word_dim, old_weights):
|
||||
"""
|
||||
:param emb_file:
|
||||
:param id_to_word:
|
||||
:param word_dim:
|
||||
:param old_weights:
|
||||
:return:
|
||||
"""
|
||||
new_weights = old_weights
|
||||
pre_trained = {}
|
||||
emb_invalid = 0
|
||||
for i, line in enumerate(codecs.open(emb_file, 'r', encoding='utf-8')):
|
||||
line = line.rstrip().split()
|
||||
if len(line) == word_dim + 1:
|
||||
pre_trained[line[0]] = np.array(
|
||||
[float(x) for x in line[1:]]
|
||||
).astype(np.float32)
|
||||
else:
|
||||
emb_invalid = emb_invalid + 1
|
||||
|
||||
if emb_invalid > 0:
|
||||
print('warning: %i invalid lines' % emb_invalid)
|
||||
|
||||
num_words = len(id_to_word)
|
||||
for i in range(num_words):
|
||||
word = id_to_word[i]
|
||||
if word in pre_trained:
|
||||
new_weights[i] = pre_trained[word]
|
||||
else:
|
||||
pass
|
||||
print('loaded %i pretrained character embeddings' % len(pre_trained))
|
||||
|
||||
return new_weights
|
||||
|
||||
def augment_with_pretrained(dico_train, emb_path, test_words):
|
||||
"""
|
||||
:param dico_train:
|
||||
:param emb_path:
|
||||
:param test_words:
|
||||
:return:
|
||||
"""
|
||||
assert os.path.isfile(emb_path)
|
||||
|
||||
# load the pretrained embeddings
|
||||
pretrained = set(
|
||||
[
|
||||
line.rsplit()[0].strip() for line in codecs.open(emb_path, 'r', encoding='utf-8')
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
if test_words is None:
|
||||
for word in pretrained:
|
||||
if word not in dico_train:
|
||||
dico_train[word] = 0
|
||||
else:
|
||||
for word in test_words:
|
||||
if any(x in pretrained for x in
|
||||
[word, word.lower()]
|
||||
) and word not in dico_train:
|
||||
dico_train[word] = 0
|
||||
|
||||
word_to_id, id_to_word = create_mapping(dico_train)
|
||||
|
||||
return dico_train, word_to_id, id_to_word
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class BatchManager(object):
|
||||
def __init__(self, data, batch_size):
|
||||
self.batch_data = self.sort_and_pad(data, batch_size)
|
||||
self.len_data = len(self.batch_data)
|
||||
def sort_and_pad(self, data, batch_size):
|
||||
num_batch = int(math.ceil(len(data) / batch_size))
|
||||
sorted_data = sorted(data, key=lambda x:len(x[0]))
|
||||
batch_data = list()
|
||||
for i in range(num_batch):
|
||||
batch_data.append(self.pad_data(sorted_data[i*batch_size : (i+1)*batch_size]))
|
||||
return batch_data
|
||||
|
||||
@staticmethod
|
||||
def pad_data(data):
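# pads every sentence in the batch to the length of the longest one with trailing 0s,
# so the four returned lists are rectangular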
|
||||
word_list = []
|
||||
word_id_list = []
|
||||
seg_list = []
|
||||
tag_id_list = []
|
||||
max_length = max([len(sentence[0]) for sentence in data])
|
||||
for line in data:
|
||||
words, word_ids, segs, tag_ids = line
|
||||
padding = [0] * (max_length - len(words))
|
||||
word_list.append(words + padding)
|
||||
word_id_list.append(word_ids + padding)
|
||||
seg_list.append(segs + padding)
|
||||
tag_id_list.append(tag_ids + padding)
|
||||
return [word_list, word_id_list, seg_list,tag_id_list]
|
||||
|
||||
def iter_batch(self, shuffle=False):
|
||||
if shuffle:
|
||||
random.shuffle(self.batch_data)
|
||||
for idx in range(self.len_data):
|
||||
yield self.batch_data[idx]
|
||||
|
||||
|
|
@@ -0,0 +1,184 @@
|
|||
|
||||
# standard and third-party imports
|
||||
import os
|
||||
import tensorflow as tf
|
||||
import pickle
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
# project-local imports
|
||||
import data_utils
|
||||
import data_loader
|
||||
import model_utils
|
||||
from model import Model
|
||||
|
||||
from data_utils import load_word2vec
|
||||
|
||||
flags = tf.app.flags
|
||||
|
||||
# run mode switches
|
||||
flags.DEFINE_boolean('train', False, 'whether to run training')
|
||||
flags.DEFINE_boolean('clean', False, 'whether to clean previous output files')
|
||||
|
||||
# model configuration
|
||||
flags.DEFINE_integer('seg_dim', 20, 'seg embedding size')
|
||||
flags.DEFINE_integer('word_dim', 100, 'word embedding')
|
||||
flags.DEFINE_integer('lstm_dim', 100, 'Num of hidden unis in lstm')
|
||||
flags.DEFINE_string('tag_schema', 'BIOES', 'tagging scheme')
|
||||
|
||||
# training hyperparameters
|
||||
flags.DEFINE_float('clip', 5, 'Gradient clip')
|
||||
flags.DEFINE_float('dropout', 0.5, 'Dropout rate')
|
||||
flags.DEFINE_integer('batch_size', 120, 'batch_size')
|
||||
flags.DEFINE_float('lr', 0.001, 'learning rate')
|
||||
flags.DEFINE_string('optimizer', 'adam', 'optimizer')
|
||||
flags.DEFINE_boolean('pre_emb', True, 'whether to use pretrained embeddings')
|
||||
|
||||
flags.DEFINE_integer('max_epoch', 100, 'maximum number of training epochs')
|
||||
flags.DEFINE_integer('setps_chech', 100, 'steps per checkpoint')
|
||||
flags.DEFINE_string('ckpt_path', os.path.join('modelfile', 'ckpt'), 'directory where model checkpoints are saved')
|
||||
flags.DEFINE_string('log_file', 'train.log', 'training log file')
|
||||
flags.DEFINE_string('map_file', 'maps.pkl', 'file storing the word and tag mappings')
|
||||
flags.DEFINE_string('vocab_file', 'vocab.json', 'character vocabulary file')
|
||||
flags.DEFINE_string('config_file', 'config_file', 'model configuration file')
|
||||
flags.DEFINE_string('result_path', 'result', 'directory for evaluation results')
|
||||
flags.DEFINE_string('emb_file', os.path.join('data', 'wiki_100.utf8'), 'path to the pretrained embedding file')
|
||||
flags.DEFINE_string('train_file', os.path.join('data', 'ner.train'), 'path to the training data')
|
||||
flags.DEFINE_string('dev_file', os.path.join('data', 'ner.dev'), 'path to the validation data')
|
||||
flags.DEFINE_string('test_file', os.path.join('data', 'ner.test'), 'path to the test data')
|
||||
|
||||
flags.DEFINE_string('model_type', "idcnn", "model type: bilstm or idcnn")
|
||||
|
||||
FLAGS = tf.app.flags.FLAGS
|
||||
assert FLAGS.clip < 5.1, 'gradient clipping must not be too large'
|
||||
assert 0 < FLAGS.dropout < 1, 'dropout must be between 0 and 1'
|
||||
assert FLAGS.lr > 0, 'lr must be greater than 0'
|
||||
assert FLAGS.optimizer in ['adam', 'sgd', 'adagrad'], 'optimizer must be one of adam, sgd, adagrad'
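# illustrative usage (assuming this file is the training entry point, e.g. main.py):
#   python main.py --train=True --clean=True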
|
||||
|
||||
def evaluate(sess, model, name, manager, id_to_tag, logger):
|
||||
logger.info('evaluate:{}'.format(name))
|
||||
ner_results = model.evaluate(sess, manager, id_to_tag)
|
||||
eval_lines = model_utils.test_ner(ner_results, FLAGS.result_path)
|
||||
for line in eval_lines:
|
||||
logger.info(line)
|
||||
f1 = float(eval_lines[1].strip().split()[-1])
|
||||
|
||||
if name == "dev":
|
||||
best_test_f1 = model.best_dev_f1.eval()
|
||||
if f1 > best_test_f1:
|
||||
tf.assign(model.best_dev_f1, f1).eval()
|
||||
logger.info('new best dev f1 score:{:>.3f}'.format(f1))
|
||||
return f1 > best_test_f1
|
||||
elif name == "test":
|
||||
best_test_f1 = model.best_test_f1.eval()
|
||||
if f1 > best_test_f1:
|
||||
tf.assign(model.best_test_f1, f1).eval()
|
||||
logger.info('new best test f1 score:{:>.3f}'.format(f1))
|
||||
return f1 > best_test_f1
|
||||
|
||||
|
||||
|
||||
def train():
|
||||
# load the datasets
|
||||
train_sentences = data_loader.load_sentences(FLAGS.train_file)
|
||||
dev_sentences = data_loader.load_sentences(FLAGS.dev_file)
|
||||
test_sentences = data_loader.load_sentences(FLAGS.test_file)
|
||||
|
||||
# convert the tagging scheme from BIO to BIOES
|
||||
data_loader.update_tag_scheme(train_sentences, FLAGS.tag_schema)
|
||||
data_loader.update_tag_scheme(test_sentences, FLAGS.tag_schema)
|
||||
data_loader.update_tag_scheme(dev_sentences, FLAGS.tag_schema)
|
||||
|
||||
# build the word and tag mappings
|
||||
if not os.path.isfile(FLAGS.map_file):
|
||||
if FLAGS.pre_emb:
|
||||
dico_words_train = data_loader.word_mapping(train_sentences)[0]
|
||||
dico_word, word_to_id, id_to_word = data_utils.augment_with_pretrained(
|
||||
dico_words_train.copy(),
|
||||
FLAGS.emb_file,
|
||||
list(
|
||||
itertools.chain.from_iterable(
|
||||
[[w[0] for w in s] for s in test_sentences]
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
_, word_to_id, id_to_word = data_loader.word_mapping(train_sentences)
|
||||
|
||||
_, tag_to_id, id_to_tag = data_loader.tag_mapping(train_sentences)
|
||||
|
||||
with open(FLAGS.map_file, "wb") as f:
|
||||
pickle.dump([word_to_id, id_to_word, tag_to_id, id_to_tag], f)
|
||||
else:
|
||||
with open(FLAGS.map_file, 'rb') as f:
|
||||
word_to_id, id_to_word, tag_to_id, id_to_tag = pickle.load(f)
|
||||
|
||||
train_data = data_loader.prepare_dataset(
|
||||
train_sentences, word_to_id, tag_to_id
|
||||
)
|
||||
|
||||
dev_data = data_loader.prepare_dataset(
|
||||
dev_sentences, word_to_id, tag_to_id
|
||||
)
|
||||
|
||||
test_data = data_loader.prepare_dataset(
|
||||
test_sentences, word_to_id, tag_to_id
|
||||
)
|
||||
|
||||
train_manager = data_utils.BatchManager(train_data, FLAGS.batch_size)
|
||||
dev_manager = data_utils.BatchManager(dev_data, FLAGS.batch_size)
|
||||
test_manager = data_utils.BatchManager(test_data, FLAGS.batch_size)
|
||||
|
||||
print('train_data_num %i, dev_data_num %i, test_data_num %i' % (len(train_data), len(dev_data), len(test_data)))
|
||||
|
||||
model_utils.make_path(FLAGS)
|
||||
|
||||
if os.path.isfile(FLAGS.config_file):
|
||||
config = model_utils.load_config(FLAGS.config_file)
|
||||
else:
|
||||
config = model_utils.config_model(FLAGS, word_to_id, tag_to_id)
|
||||
model_utils.save_config(config, FLAGS.config_file)
|
||||
|
||||
log_path = os.path.join("log", FLAGS.log_file)
|
||||
logger = model_utils.get_logger(log_path)
|
||||
model_utils.print_config(config, logger)
|
||||
|
||||
tf_config = tf.ConfigProto()
|
||||
tf_config.gpu_options.allow_growth = True
|
||||
steps_per_epoch = train_manager.len_data
|
||||
with tf.Session(config = tf_config) as sess:
|
||||
model = model_utils.create(sess, Model, FLAGS.ckpt_path, load_word2vec, config, id_to_word, logger)
|
||||
logger.info("开始训练")
|
||||
loss = []
|
||||
for i in range(FLAGS.max_epoch):
|
||||
for batch in train_manager.iter_batch(shuffle=True):
|
||||
step, batch_loss = model.run_step(sess, True, batch)
|
||||
loss.append(batch_loss)
|
||||
if step % FLAGS.setps_chech == 0:
|
||||
iteration = step // steps_per_epoch + 1
|
||||
logger.info("iteration:{} step{}/{},NER loss:{:>9.6f}".format(iterstion, step%steps_per_epoch, steps_per_epoch, np.mean(loss)))
|
||||
loss = []
|
||||
|
||||
best = evaluate(sess,model,"dev", dev_manager, id_to_tag, logger)
|
||||
|
||||
if best:
|
||||
model_utils.save_model(sess, model, FLAGS.ckpt_path, logger)
|
||||
evaluate(sess, model, "test", test_manager, id_to_tag, logger)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def main(_):
|
||||
if FLAGS.train:
|
||||
train()
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run(main)
|
||||
|
Binary file not shown.
|
@@ -0,0 +1,432 @@
|
|||
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
from tensorflow.contrib.layers.python.layers import initializers
|
||||
from tensorflow.contrib.crf import crf_log_likelihood
|
||||
from tensorflow.contrib.crf import viterbi_decode
|
||||
import tensorflow.contrib.rnn as rnn
|
||||
import data_utils
|
||||
|
||||
class Model(object):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.lr = config['lr']
|
||||
self.word_dim = config['word_dim']
|
||||
self.lstm_dim = config['lstm_dim']
|
||||
self.seg_dim = config['seg_dim']
|
||||
self.num_tags = config['num_tags']
|
||||
self.num_words = config['num_words']
|
||||
self.num_sges = 4
|
||||
|
||||
self.global_step = tf.Variable(0, trainable=False)
|
||||
self.best_dev_f1 = tf.Variable(0.0, trainable=False)
|
||||
self.best_test_f1 = tf.Variable(0.0, trainable=False)
|
||||
self.initializer = initializers.xavier_initializer()
|
||||
|
||||
# placeholders
|
||||
self.word_inputs = tf.placeholder(dtype=tf.int32, shape=[None, None], name="wordInputs")
|
||||
self.seg_inputs = tf.placeholder(dtype=tf.int32, shape=[None, None], name="SegInputs")
|
||||
self.targets = tf.placeholder(dtype=tf.int32, shape=[None, None], name="Targets")
|
||||
|
||||
self.dropout = tf.placeholder(dtype=tf.float32, name="Dropout")
|
||||
|
||||
used = tf.sign(tf.abs(self.word_inputs))
|
||||
length = tf.reduce_sum(used, reduction_indices=1)
|
||||
self.lengths = tf.cast(length, tf.int32)
|
||||
self.batch_size = tf.shape(self.word_inputs)[0]
|
||||
self.num_setps = tf.shape(self.word_inputs)[-1]
|
||||
|
||||
# which network architecture to use
|
||||
self.model_type = config['model_type']
|
||||
|
||||
# IDCNN hyperparameters
|
||||
self.layers = [
|
||||
{
|
||||
'dilation':1
|
||||
},
|
||||
{
|
||||
'dilation':1
|
||||
},
|
||||
{
|
||||
'dilation':2
|
||||
}
|
||||
]
|
||||
self.filter_width = 3
|
||||
self.num_filter = self.lstm_dim
|
||||
self.embedding_dim = self.word_dim + self.seg_dim
|
||||
self.repeat_times = 4
|
||||
self.cnn_output_width = 0
|
||||
|
||||
# embedding layer: character and segmentation features
|
||||
embedding = self.embedding_layer(self.word_inputs, self.seg_inputs, config)
|
||||
|
||||
|
||||
if self.model_type == 'bilstm':
|
||||
# BiLSTM input (with dropout)
|
||||
lstm_inputs = tf.nn.dropout(embedding, self.dropout)
|
||||
|
||||
# BiLSTM output
|
||||
lstm_outputs = self.biLSTM_layer(lstm_inputs, self.lstm_dim, self.lengths)
|
||||
|
||||
# projection layer
|
||||
self.logits = self.project_layer(lstm_outputs)
|
||||
|
||||
elif self.model_type == 'idcnn':
|
||||
# IDCNN input (with dropout)
|
||||
idcnn_inputs = tf.nn.dropout(embedding, self.dropout)
|
||||
|
||||
# IDCNN layer
|
||||
idcnn_outputs = self.IDCNN_layer(idcnn_inputs)
|
||||
|
||||
# projection layer
|
||||
self.logits = self.project_layer_idcnn(idcnn_outputs)
|
||||
|
||||
|
||||
# CRF loss
|
||||
self.loss = self.crf_loss_layer(self.logits, self.lengths)
|
||||
|
||||
with tf.variable_scope('optimizer'):
|
||||
optimizer = self.config['optimizer']
|
||||
if optimizer == "sgd":
|
||||
self.opt = tf.train.GradientDescentOptimizer(self.lr)
|
||||
elif optimizer == "adam":
|
||||
self.opt = tf.train.AdamOptimizer(self.lr)
|
||||
elif optimizer == "adgrad":
|
||||
self.opt = tf.train.AdagradDAOptimizer(self.lr)
|
||||
else:
|
||||
raise Exception("优化器错误")
|
||||
|
||||
grad_vars = self.opt.compute_gradients(self.loss)
|
||||
capped_grad_vars = [[tf.clip_by_value(g, -self.config['clip'], self.config['clip']),v] for g, v in grad_vars]
|
||||
|
||||
self.train_op = self.opt.apply_gradients(capped_grad_vars, self.global_step)
|
||||
|
||||
# model saver
|
||||
self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
|
||||
|
||||
def embedding_layer(self, word_inputs, seg_inputs, config, name=None):
|
||||
"""
|
||||
:param word_inputs: character id inputs
|
||||
:param seg_inputs: segmentation features
|
||||
:param config: configuration dict
|
||||
:param name: variable scope name
|
||||
:return:
|
||||
"""
|
||||
embedding = []
|
||||
with tf.variable_scope("word_embedding" if not name else name), tf.device('/cpu:0'):
|
||||
self.word_lookup = tf.get_variable(
|
||||
name = "word_embedding",
|
||||
shape = [self.num_words, self.word_dim],
|
||||
initializer = self.initializer
|
||||
)
|
||||
embedding.append(tf.nn.embedding_lookup(self.word_lookup, word_inputs))
|
||||
|
||||
if config['seg_dim']:
|
||||
with tf.variable_scope("seg_embedding"), tf.device('/cpu:0'):
|
||||
self.seg_lookup = tf.get_variable(
|
||||
name = "seg_embedding",
|
||||
shape = [self.num_sges, self.seg_dim],
|
||||
initializer = self.initializer
|
||||
)
|
||||
embedding.append(tf.nn.embedding_lookup(self.seg_lookup, seg_inputs))
|
||||
embed = tf.concat(embedding, axis=-1)
|
||||
return embed
|
||||
|
||||
def biLSTM_layer(self, lstm_inputs, lstm_dim, lengths,name=None):
|
||||
"""
|
||||
:param lstm_inputs: [batch_size, num_steps, emb_size]
|
||||
:param lstm_dim:
|
||||
:param name:
|
||||
:return: [batch_size, num_steps, 2*lstm_dim]
|
||||
"""
|
||||
with tf.variable_scope("word_biLSTM" if not name else name):
|
||||
lstm_cell = {}
|
||||
for direction in ['forward', 'backward']:
|
||||
with tf.variable_scope(direction):
|
||||
lstm_cell[direction] = rnn.CoupledInputForgetGateLSTMCell(
|
||||
lstm_dim,
|
||||
use_peepholes=True,
|
||||
initializer = self.initializer,
|
||||
state_is_tuple=True
|
||||
)
|
||||
outputs, final_status = tf.nn.bidirectional_dynamic_rnn(
|
||||
lstm_cell['forward'],
|
||||
lstm_cell['backward'],
|
||||
lstm_inputs,
|
||||
dtype = tf.float32,
|
||||
|
||||
sequence_length = lengths
|
||||
)
|
||||
|
||||
return tf.concat(outputs, axis=2)
|
||||
|
||||
def IDCNN_layer(self, idcnn_inputs, name = None):
|
||||
"""
|
||||
:param idcnn_inputs: [batch_size, num_steps, emb_size]
|
||||
:param name:
|
||||
:return: [batch_size, num_steps, cnn_output_width]
|
||||
"""
|
||||
idcnn_inputs = tf.expand_dims(idcnn_inputs, 1)
|
||||
reuse = False
|
||||
if not self.config['is_train']:
|
||||
reuse = True
|
||||
with tf.variable_scope('idcnn' if not name else name):
|
||||
shape = [1, self.filter_width, self.embedding_dim, self.num_filter]
|
||||
|
||||
filter_weights = tf.get_variable(
|
||||
"idcnn_filter",
|
||||
shape=[1, self.filter_width, self.embedding_dim, self.num_filter],
|
||||
initializer=self.initializer
|
||||
)
|
||||
|
||||
layer_input = tf.nn.conv2d(
|
||||
idcnn_inputs,
|
||||
filter_weights,
|
||||
strides=[1,1,1,1],
|
||||
padding='SAME',
|
||||
name='init_layer'
|
||||
)
|
||||
|
||||
finalOutFromLayers = []
|
||||
totalWidthForLastDim = 0
|
||||
for j in range(self.repeat_times):
|
||||
for i in range(len(self.layers)):
|
||||
dilation = self.layers[i]['dilation']
|
||||
isLast = True if i == (len(self.layers) - 1) else False
|
||||
with tf.variable_scope('conv-layer-%d' % i, reuse = tf.AUTO_REUSE):
|
||||
w = tf.get_variable(
|
||||
'fliter_w',
|
||||
shape=[1, self.filter_width, self.num_filter, self.num_filter],
|
||||
initializer=self.initializer
|
||||
)
|
||||
|
||||
b = tf.get_variable('filterB', shape=[self.num_filter])
|
||||
|
||||
conv = tf.nn.atrous_conv2d(
|
||||
layer_input,
|
||||
w,
|
||||
rate=dilation,
|
||||
padding="SAME"
|
||||
)
|
||||
|
||||
conv = tf.nn.bias_add(conv, b)
|
||||
|
||||
conv = tf.nn.relu(conv)
|
||||
|
||||
if isLast:
|
||||
finalOutFromLayers.append(conv)
|
||||
totalWidthForLastDim = totalWidthForLastDim + self.num_filter
|
||||
layer_input = conv
|
||||
|
||||
finalOut = tf.concat(axis=3, values=finalOutFromLayers)
|
||||
keepProb = 1.0 if reuse else 0.5
|
||||
finalOut = tf.nn.dropout(finalOut, keepProb)
|
||||
|
||||
finalOut = tf.squeeze(finalOut, [1])
|
||||
finalOut = tf.reshape(finalOut, [-1, totalWidthForLastDim])
|
||||
self.cnn_output_width = totalWidthForLastDim
|
||||
return finalOut
|
||||
|
||||
def project_layer_idcnn(self, idcnn_outputs, name=None):
|
||||
"""
|
||||
:param idcnn_outputs: [batch_size, num_steps, emb_size]
|
||||
:param name:
|
||||
:return: [batch_size, num_steps, num_tags]
|
||||
"""
|
||||
with tf.variable_scope('idcnn_project' if not name else name):
|
||||
|
||||
with tf.variable_scope('idcnn_logits'):
|
||||
W = tf.get_variable(
|
||||
"W",
|
||||
shape=[self.cnn_output_width, self.num_tags],
|
||||
dtype=tf.float32,
|
||||
initializer=self.initializer
|
||||
)
|
||||
|
||||
b = tf.get_variable(
|
||||
"b",
|
||||
initializer=tf.constant(0.001, shape=[self.num_tags])
|
||||
)
|
||||
|
||||
pred = tf.nn.xw_plus_b(idcnn_outputs, W, b)
|
||||
|
||||
return tf.reshape(pred, [-1, self.num_setps, self.num_tags])
|
||||
|
||||
|
||||
def project_layer(self, lstm_outputs, name=None):
|
||||
"""
|
||||
:param lstm_outputs: [batch_size, num_steps, emb_size]
|
||||
:param name:
|
||||
:return: [batch_size, num_steps, num_tags]
|
||||
"""
|
||||
with tf.variable_scope('project_layer' if not name else name):
|
||||
with tf.variable_scope('hidden_layer'):
|
||||
W = tf.get_variable(
|
||||
"W",
|
||||
shape=[self.lstm_dim*2, self.lstm_dim],
|
||||
dtype=tf.float32,
|
||||
initializer=self.initializer
|
||||
)
|
||||
b = tf.get_variable(
|
||||
"b",
|
||||
shape=[self.lstm_dim],
|
||||
dtype=tf.float32,
|
||||
initializer=tf.zeros_initializer()
|
||||
)
|
||||
out_put = tf.reshape(lstm_outputs, shape=[-1, self.lstm_dim*2])
|
||||
hidden = tf.tanh(tf.nn.xw_plus_b(out_put, W, b))
|
||||
|
||||
with tf.variable_scope('logits'):
|
||||
W = tf.get_variable(
|
||||
"W",
|
||||
shape=[self.lstm_dim, self.num_tags],
|
||||
dtype=tf.float32,
|
||||
initializer=self.initializer
|
||||
)
|
||||
b = tf.get_variable(
|
||||
"b",
|
||||
shape=[self.num_tags],
|
||||
dtype=tf.float32,
|
||||
initializer=tf.zeros_initializer()
|
||||
)
|
||||
|
||||
pred = tf.nn.xw_plus_b(hidden,W, b)
|
||||
return tf.reshape(pred, [-1, self.num_setps, self.num_tags])
|
||||
|
||||
def crf_loss_layer(self, project_logits, lenghts, name=None):
|
||||
"""
|
||||
:param project_logits: [batch_size, num_steps, num_tags]
|
||||
:param lenghts:
|
||||
:param name:
|
||||
:return: scalar loss
|
||||
"""
|
||||
with tf.variable_scope('crf_loss' if not name else name):
|
||||
small_value = -10000.0
|
||||
start_logits = tf.concat(
|
||||
[
|
||||
small_value *
|
||||
tf.ones(shape= [self.batch_size, 1, self.num_tags]),
|
||||
tf.zeros(shape=[self.batch_size, 1, 1])
|
||||
],
|
||||
axis = -1
|
||||
)
|
||||
|
||||
pad_logits = tf.cast(
|
||||
small_value *
|
||||
tf.ones(shape=[self.batch_size, self.num_setps, 1]),
|
||||
dtype=tf.float32
|
||||
)
|
||||
|
||||
logits = tf.concat(
|
||||
[project_logits, pad_logits],
|
||||
axis=-1
|
||||
)
|
||||
logits = tf.concat(
|
||||
[start_logits, logits],
|
||||
axis=1
|
||||
)
|
||||
|
||||
targets = tf.concat(
|
||||
[tf.cast(
|
||||
self.num_tags*tf.ones([self.batch_size, 1]) ,
|
||||
tf.int32
|
||||
),
|
||||
self.targets
|
||||
]
|
||||
,
|
||||
axis = -1
|
||||
)
|
||||
|
||||
self.trans = tf.get_variable(
|
||||
"transitions",
|
||||
shape=[self.num_tags+1, self.num_tags+1],
|
||||
initializer=self.initializer
|
||||
)
|
||||
|
||||
log_likehood, self.trans = crf_log_likelihood(
|
||||
inputs=logits,
|
||||
tag_indices=targets,
|
||||
transition_params=self.trans,
|
||||
sequence_lengths=lenghts +1
|
||||
)
|
||||
return tf.reduce_mean(-log_likehood)
|
||||
|
||||
def decode(self, logits, lengths, matrix):
|
||||
"""
|
||||
:param logits: [batch_size, num_steps, num_tags]
|
||||
:param lengths:
|
||||
:param matrix:
|
||||
:return:
|
||||
"""
|
||||
paths = []
|
||||
small = -1000.0
|
||||
start = np.asarray([[small]*self.num_tags + [0]])
|
||||
for score, length in zip(logits, lengths):
|
||||
score = score[:length]
|
||||
pad = small * np.ones([length,1])
|
||||
logits = np.concatenate([score, pad], axis=1)
|
||||
logits = np.concatenate([start, logits], axis=0)
|
||||
path,_ = viterbi_decode(logits, matrix)
|
||||
|
||||
paths.append(path[1:])
|
||||
return paths
|
||||
|
||||
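    # Illustrative call to decode (a sketch; the dummy values assume
    # batch_size=1, num_steps=2, num_tags=3):
    #   logits  = np.zeros([1, 2, 3])          # per-token tag scores
    #   lengths = [2]                          # real sentence lengths
    #   matrix  = np.zeros([4, 4])             # (num_tags + 1)^2 transition scores
    #   model.decode(logits, lengths, matrix)  # -> [[tag_id_1, tag_id_2]]
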
    def create_feed_dict(self, is_train, batch):
        """
        :param is_train: whether this batch is used for training
        :param batch: (raw strings, word ids, seg features, tag ids)
        :return: feed dict for sess.run
        """
        _, words, segs, tags = batch
        feed_dict = {
            self.word_inputs: np.asarray(words),
            self.seg_inputs: np.asarray(segs),
            self.dropout: 1.0
        }
        if is_train:
            feed_dict[self.targets] = np.asarray(tags)
            feed_dict[self.dropout] = self.config['dropout_keep']
        return feed_dict

    def run_step(self, sess, is_train, batch):
        """
        :param sess: TensorFlow session
        :param is_train: whether to run the training op
        :param batch: batch produced by the data manager
        :return: (global_step, loss) when training, else (lengths, logits)
        """
        feed_dict = self.create_feed_dict(is_train, batch)
        if is_train:
            global_step, loss, _ = sess.run(
                [self.global_step, self.loss, self.train_op], feed_dict
            )
            return global_step, loss
        else:
            lengths, logits = sess.run([self.lengths, self.logits], feed_dict)
            return lengths, logits

    def evaluate(self, sess, data_manager, id_to_tag):
        """
        :param sess: TensorFlow session
        :param data_manager: batch iterator over the evaluation data
        :param id_to_tag: mapping from tag id to tag string
        :return: list of per-sentence results, each a list of "char gold pred" lines
        """
        results = []
        trans = self.trans.eval()
        for batch in data_manager.iter_batch():
            strings = batch[0]
            tags = batch[-1]
            lengths, logits = self.run_step(sess, False, batch)
            batch_paths = self.decode(logits, lengths, trans)
            for i in range(len(strings)):
                result = []
                string = strings[i][:lengths[i]]
                gold = data_utils.bioes_to_bio([id_to_tag[int(x)] for x in tags[i][:lengths[i]]])
                pred = data_utils.bioes_to_bio([id_to_tag[int(x)] for x in batch_paths[i][:lengths[i]]])
                for char, gold_tag, pred_tag in zip(string, gold, pred):
                    result.append(" ".join([char, gold_tag, pred_tag]))
                results.append(result)
        return results

@ -0,0 +1,157 @@
import os
import json
import logging
from collections import OrderedDict
from conlleval import return_report
import codecs
import tensorflow as tf

def get_logger(log_file):
    """
    Build a logger that writes to both a file and the console.
    :param log_file: path of the log file
    :return: configured logging.Logger instance
    """
    # create a logger instance named after the log file
    logger = logging.getLogger(log_file)
    # set the logger's global level to DEBUG
    logger.setLevel(logging.DEBUG)
    # file handler at DEBUG level
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.DEBUG)
    # console handler at INFO level
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    # log record format
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    # add formatter to ch and fh
    ch.setFormatter(formatter)
    fh.setFormatter(formatter)
    # add ch and fh to logger
    logger.addHandler(ch)
    logger.addHandler(fh)
    return logger

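# Minimal usage sketch for get_logger; the file name "train.log" is only an
# assumed example.
def _demo_get_logger():
    logger = get_logger("train.log")
    logger.info("visible on the console and in train.log")
    logger.debug("written to train.log only (the console handler is INFO level)")
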
def config_model(FLAGS, word_to_id, tag_to_id):
    """
    Collect the model hyper-parameters from the command-line flags.
    :param FLAGS: parsed command-line flags
    :param word_to_id: word-to-id mapping
    :param tag_to_id: tag-to-id mapping
    :return: OrderedDict with the model configuration
    """
    config = OrderedDict()
    config['num_words'] = len(word_to_id)
    config['word_dim'] = FLAGS.word_dim
    config['num_tags'] = len(tag_to_id)
    config['seg_dim'] = FLAGS.seg_dim
    config['lstm_dim'] = FLAGS.lstm_dim
    config['batch_size'] = FLAGS.batch_size
    config['emb_file'] = FLAGS.emb_file

    config['clip'] = FLAGS.clip
    config['dropout_keep'] = 1.0 - FLAGS.dropout
    config['optimizer'] = FLAGS.optimizer
    config['lr'] = FLAGS.lr
    config['tag_schema'] = FLAGS.tag_schema
    config['pre_emb'] = FLAGS.pre_emb
    config['model_type'] = FLAGS.model_type
    config['is_train'] = FLAGS.train
    return config

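# Illustrative sketch of how the FLAGS object consumed by config_model could be
# defined with tf.app.flags in the training script. Only the flag names come
# from config_model above; the default values and help texts are assumptions.
def _demo_define_flags():
    flags = tf.app.flags
    flags.DEFINE_integer("word_dim", 100, "embedding size of a word")
    flags.DEFINE_integer("seg_dim", 20, "embedding size of the segmentation feature")
    flags.DEFINE_integer("lstm_dim", 100, "hidden size of the LSTM / number of IDCNN filters")
    flags.DEFINE_integer("batch_size", 20, "batch size")
    flags.DEFINE_string("optimizer", "adam", "optimizer name")
    flags.DEFINE_string("emb_file", "wiki_100.utf8", "path of the pretrained embeddings")
    flags.DEFINE_float("clip", 5.0, "gradient clipping value")
    flags.DEFINE_float("dropout", 0.5, "dropout rate")
    flags.DEFINE_float("lr", 0.001, "initial learning rate")
    flags.DEFINE_string("tag_schema", "bioes", "tagging schema")
    flags.DEFINE_boolean("pre_emb", True, "whether to use pretrained embeddings")
    flags.DEFINE_string("model_type", "idcnn", "model type: idcnn or bilstm")
    flags.DEFINE_boolean("train", True, "whether to run training")
    return flags.FLAGS
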
def make_path(params):
    """
    Create the output directories if they do not exist yet.
    :param params: parsed flags with result_path and ckpt_path attributes
    :return:
    """
    if not os.path.isdir(params.result_path):
        os.makedirs(params.result_path)
    if not os.path.isdir(params.ckpt_path):
        os.makedirs(params.ckpt_path)
    if not os.path.isdir('log'):
        os.makedirs('log')

def save_config(config, config_file):
    """
    Save the configuration as JSON.
    :param config: configuration dict
    :param config_file: target file path
    :return:
    """
    with open(config_file, 'w', encoding='utf-8') as f:
        json.dump(config, f, ensure_ascii=False, indent=4)

def load_config(config_file):
    """
    Load the configuration from a JSON file.
    :param config_file: path of the JSON config file
    :return: configuration dict
    """
    with open(config_file, encoding='utf-8') as f:
        return json.load(f)

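# Illustrative save/load round trip; "demo_config.json" and the sample values
# are assumptions made for the example.
def _demo_config_round_trip():
    config = OrderedDict([("num_tags", 13), ("lstm_dim", 100)])
    save_config(config, "demo_config.json")
    assert load_config("demo_config.json")["lstm_dim"] == 100
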
def print_config(config, logger):
    """
    Log every model parameter.
    :param config: configuration dict
    :param logger: logger returned by get_logger
    :return:
    """
    for k, v in config.items():
        logger.info("{}:\t{}".format(k.ljust(15), v))

def create(sess, Model, ckpt_path, load_word2vec, config, id_to_word, logger):
    """
    Build the model and either restore it from a checkpoint or initialize it.
    :param sess: TensorFlow session
    :param Model: model class to instantiate
    :param ckpt_path: checkpoint directory
    :param load_word2vec: function that fills the embedding matrix from a pretrained file
    :param config: model configuration dict
    :param id_to_word: id-to-word mapping
    :param logger: logger returned by get_logger
    :return: restored or freshly initialized model instance
    """
    model = Model(config)

    ckpt = tf.train.get_checkpoint_state(ckpt_path)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        logger.info("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        logger.info("Training the model from scratch")
        sess.run(tf.global_variables_initializer())
        if config['pre_emb']:
            emb_weights = sess.run(model.word_lookup.read_value())
            emb_weights = load_word2vec(config['emb_file'], id_to_word, config['word_dim'], emb_weights)
            sess.run(model.word_lookup.assign(emb_weights))
            logger.info("Pretrained embeddings loaded successfully")
    return model

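# Illustrative usage sketch for create: it assumes a Model class and a
# load_word2vec helper with the signatures used above, existing
# word_to_id / id_to_word / tag_to_id mappings, and a FLAGS.ckpt_path flag.
def _demo_create(Model, load_word2vec, word_to_id, id_to_word, tag_to_id, FLAGS):
    logger = get_logger(os.path.join('log', 'train.log'))
    config = config_model(FLAGS, word_to_id, tag_to_id)
    with tf.Session() as sess:
        model = create(sess, Model, FLAGS.ckpt_path, load_word2vec, config, id_to_word, logger)
        print_config(config, logger)
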
def test_ner(results, path):
    """
    Write the predictions to disk and run the CoNLL evaluation script.
    :param results: list of sentences, each a list of "char gold pred" lines
    :param path: output directory
    :return: evaluation report lines from conlleval
    """
    output_file = os.path.join(path, 'ner_predict.utf8')
    with codecs.open(output_file, "w", encoding="utf-8") as f_write:
        to_write = []
        for line in results:
            for inner_line in line:
                to_write.append(inner_line + "\n")
            to_write.append("\n")
        f_write.writelines(to_write)
    eval_lines = return_report(output_file)
    return eval_lines

def save_model(sess, model, path, logger):
    """
    :param sess: TensorFlow session
    :param model: model instance holding the saver
    :param path: checkpoint directory
    :param logger: logger returned by get_logger
    :return:
    """
    checkpoint_path = os.path.join(path, "ner.ckpt")
    model.saver.save(sess, checkpoint_path)
    logger.info('Model saved')

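# Illustrative end-to-end sketch: evaluate a trained model, score the
# predictions with conlleval, and save a checkpoint. The data_manager,
# id_to_tag and the default result/ckpt paths are assumptions.
def _demo_evaluate_and_save(sess, model, data_manager, id_to_tag, logger,
                            result_path="result", ckpt_path="ckpt"):
    results = model.evaluate(sess, data_manager, id_to_tag)
    eval_lines = test_ner(results, result_path)
    for line in eval_lines:
        logger.info(line)
    save_model(sess, model, ckpt_path, logger)
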
@ -0,0 +1,2 @@
model_checkpoint_path: "ner.ckpt"
all_model_checkpoint_paths: "ner.ckpt"
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large