2012年10月17日 星期三

Refactoring - TWSE Statistics

首先把共用的部分拆出來。Copy/Paste 不是好的 reuse 技術:倘若程式有錯,常常改了 A 卻忘了改 A 的複本,這樣就很麻煩。與其同樣的工作做兩次以上,不如把它抽出來。


./src/common/sourcing_twse.py

# coding: big5

import csv
import logging
import os
import shutil
import sqlite3
import subprocess

from datetime import date
from datetime import datetime

class SourcingTwse():
    """Shared pipeline for TWSE monthly statistics reports.

    A concrete report subclass sets the attributes initialised in
    ``__init__`` and implements its own XLS => CSV conversion; this class
    supplies the report-independent steps:

        URL --wget--> ZIP --7-Zip--> XLS --(subclass)--> CSV --> SQLite

    Every step iterates over ``self.DATES`` (first day of each month), which
    ``init_dates`` fills in.
    """

    def __init__(self):
        self.LOGGER = logging.getLogger()
        # %-format template; its single %s receives the month as 'YYYYMM'.
        self.URL_TEMPLATE = ''
        # datetime.date objects, first day of each month, ascending.
        self.DATES = []
        self.ZIP_DIR = ''
        self.XLS_DIR = ''
        self.CSV_DIR = ''
        self.DB_FILE = './db/stocktotal.db'
        # Parameterised INSERT; one '?' placeholder per CSV column.
        self.SQL_INSERT = ''

    def init_dates(self, begin_date, end_date):
        """Set ``self.DATES`` to the first day of every month in a range.

        Both bounds are 'YYYY-MM-DD' strings; only their year and month are
        used, and both bound months are included.  Previously stored dates
        are replaced, so calling this twice does not process months twice.
        """
        begin = datetime.strptime(begin_date, '%Y-%m-%d')
        end = datetime.strptime(end_date, '%Y-%m-%d')
        # Count months from year 0 so one range() walks across year ends.
        first_month = 12 * begin.year + begin.month - 1
        past_last_month = 12 * end.year + end.month
        self.DATES = []
        for months in range(first_month, past_last_month):
            year, month_index = divmod(months, 12)
            self.DATES.append(date(year, month_index + 1, 1))

    def source_url_to_zip(self, dest_dir):
        """Download the ZIP archive of every month in ``self.DATES``.

        ``self.URL_TEMPLATE`` is filled with the month as 'YYYYMM'; each
        archive is saved as '<dest_dir>/YYYY-MM.zip'.
        """
        if not os.path.exists(dest_dir):
            os.makedirs(dest_dir)
        for month in self.DATES:
            url = self.URL_TEMPLATE % month.strftime('%Y%m')
            dest_file = self.get_filename(dest_dir, month, 'zip')
            self.__wget(url, dest_file)

    def source_zip_to_xls(self, src_dir, dest_dir):
        """Extract '<src_dir>/YYYY-MM.zip' to '<dest_dir>/YYYY-MM.xls' per month."""
        assert os.path.isdir(src_dir)
        if not os.path.exists(dest_dir):
            os.makedirs(dest_dir)
        for month in self.DATES:
            src_file = self.get_filename(src_dir, month, 'zip')
            dest_file = self.get_filename(dest_dir, month, 'xls')
            self.source_zip_to_xls_single(src_file, dest_dir, dest_file)

    def source_zip_to_xls_single(self, src_file, dest_dir, dest_file):
        """Extract the single file inside ``src_file`` and copy it to ``dest_file``.

        The archive is unpacked into the scratch directory
        '<dest_dir>/sevenzip_output_dir', which is always removed afterwards.
        If 7-Zip produced no output directory, the failure is logged and
        nothing is copied.
        """
        assert os.path.isfile(src_file)
        assert os.path.isdir(dest_dir)

        sevenzip_output_dir = os.path.join(dest_dir, 'sevenzip_output_dir')
        # Drop leftovers of an earlier interrupted run so the single-file
        # check below only sees this archive's content.
        if os.path.exists(sevenzip_output_dir):
            shutil.rmtree(sevenzip_output_dir)
        self.__sevenzip_extract(src_file, sevenzip_output_dir)
        if not os.path.exists(sevenzip_output_dir):
            self.LOGGER.info('''%s => Failure to extract''' % src_file)
            return

        try:
            file_list = os.listdir(sevenzip_output_dir)
            # Each archive must contain exactly one file.  Compare with ==:
            # `is` on ints relies on a CPython caching detail.
            assert len(file_list) == 1
            sevenzip_output_file = os.path.join(sevenzip_output_dir, file_list[0])
            shutil.copy(sevenzip_output_file, dest_file)
        finally:
            shutil.rmtree(sevenzip_output_dir)

    def source_csv_to_sqlite(self, src_dir, dest_db, sql_insert):
        """Load '<src_dir>/YYYY-MM.csv' of every month into ``dest_db``.

        Months without a CSV file are skipped silently.
        """
        assert os.path.isdir(src_dir)
        assert os.path.isfile(dest_db)
        for month in self.DATES:
            src_file = self.get_filename(src_dir, month, 'csv')
            if os.path.isfile(src_file):
                self.source_csv_to_sqlite_single(src_file, dest_db, sql_insert)

    def source_csv_to_sqlite_single(self, src_file, dest_db, sql_insert):
        """Execute ``sql_insert`` once per CSV row of ``src_file``, then commit.

        Row values are bound as strings via '?' placeholders.  The file and
        the database connection are closed even if an insert fails; in that
        case nothing from this file is committed.
        """
        self.LOGGER.debug('''%s => %s''' % (src_file, dest_db))
        # newline='' is what the csv module requires for its file objects.
        with open(src_file, 'r', newline='') as fd:
            conn = sqlite3.connect(dest_db)
            try:
                cursor = conn.cursor()
                for row in csv.reader(fd):
                    cursor.execute(sql_insert, row)
                    self.LOGGER.debug(row)
                conn.commit()
                cursor.close()
            finally:
                conn.close()

    def get_filename(self, src_dir, date, ext):
        """Return '<src_dir>/YYYY-MM.<ext>' for the month of ``date``."""
        return os.path.join(src_dir, date.strftime('%Y-%m') + '.' + ext)

    def __wget(self, url, dest_file):
        """Download ``url`` into ``dest_file`` with the bundled wget.exe.

        Arguments are passed as a list rather than a shell string, so paths
        containing spaces and URL characters such as '%' or '&' reach wget
        unmodified.  A non-zero exit code is logged as a warning.
        """
        wget = os.path.abspath('./src/thirdparty/wget/wget.exe')
        assert os.path.isfile(wget)
        rc = subprocess.call([wget, '-N', url, '--waitretry=3', '-O', dest_file])
        if rc != 0:
            self.LOGGER.warning('''%s => wget exit code %d''' % (url, rc))

    def __sevenzip_extract(self, src_file, dest_dir):
        """Extract ``src_file`` flat ('e', no sub-folders) into ``dest_dir`` with 7z.exe.

        Existing files are overwritten without prompting ('-y').  A non-zero
        exit code is logged as a warning.
        """
        sevenzip = os.path.abspath('./src/thirdparty/sevenzip/7z.exe')
        assert os.path.isfile(sevenzip)
        # 7-Zip expects the output directory glued to the switch: -o<dir>.
        rc = subprocess.call([sevenzip, 'e', src_file, '-y', '-o' + dest_dir])
        if rc != 0:
            self.LOGGER.warning('''%s => 7z exit code %d''' % (src_file, rc))


接著就可以專心處理 Excel => SQLite 之間的中繼檔(CSV)。


./src/market_statistics/sourcing.py

# coding: big5

import csv
import logging
import os
import xlrd

from datetime import date
from datetime import datetime

from ..common import sourcing_twse

class Sourcing(sourcing_twse.SourcingTwse):
    """Source the TWSE "Statistics of Securities Market" monthly report.

    Each month's workbook is parsed into 17-column CSV rows: report date,
    activity date, 'yearly'/'monthly', then 14 numeric columns.  These match
    ``SQL_INSERT`` and are loaded into the ``MarketStatistics`` table.
    """

    # Workbook layout cut-overs, compared against the report month (1st day).
    # From 2003-06 the report is a single 15-column sheet ("v2").
    _V2_FIRST_MONTH = date(2003, 6, 1)
    # Earlier reports split the figures over two sheets ("v1").  Only months
    # strictly after 2000-09 are parsed; older months produce no CSV.
    _V1_LAST_SKIPPED_MONTH = date(2000, 9, 1)
    # v1 reports up to 2001-06 put the 'Month' header one row higher.
    _V1_HEADER_MOVE_MONTH = date(2001, 6, 1)

    def __init__(self):
        # The base class sets LOGGER, DATES and DB_FILE; only the
        # report-specific settings are overridden here.
        super().__init__()
        self.URL_TEMPLATE = '''http://www.twse.com.tw/ch/inc/download.php?l1=Securities+Trading+Monthly+Statistics&l2=Statistics+of+Securities+Market&url=/ch/statistics/download/02/001/%s_C02001.zip'''
        self.ZIP_DIR = '''./dataset/market_statistics/zip/'''
        self.XLS_DIR = '''./dataset/market_statistics/xls/'''
        self.CSV_DIR = '''./dataset/market_statistics/csv/'''
        self.SQL_INSERT = '''insert or ignore into MarketStatistics(
                report_date,
                activity_date,
                report_type,
                total_trading_value,
                listed_co_number,
                capital_issued,
                total_listed_shares,
                market_capitalization,
                trading_volume,
                trading_value,
                trans_number,
                average_taiex,
                volume_turnover_rate,
                per,
                dividend_yield,
                pbr,
                trading_days
            ) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)'''

    def source(self, begin_date, end_date):
        """Run the whole pipeline for every month from begin_date to end_date.

        Both arguments are 'YYYY-MM-DD' strings; only year and month are used.
        """
        self.init_dates(begin_date, end_date)
        self.source_url_to_zip(self.ZIP_DIR)
        self.source_zip_to_xls(self.ZIP_DIR, self.XLS_DIR)
        self.source_xls_to_csv(self.XLS_DIR, self.CSV_DIR)
        self.source_csv_to_sqlite(self.CSV_DIR, self.DB_FILE, self.SQL_INSERT)

    def source_xls_to_csv(self, src_dir, dest_dir):
        """Convert the XLS file of every month in ``self.DATES``, newest first."""
        assert os.path.isdir(src_dir)
        if not os.path.exists(dest_dir):
            os.makedirs(dest_dir)
        for month in reversed(self.DATES):
            src_file = self.get_filename(src_dir, month, 'xls')
            self.source_xls_to_csv_single(src_file, dest_dir, month)

    def source_xls_to_csv_single(self, src_file, dest_dir, date):
        """Convert one workbook with whichever layout parser covers ``date``.

        Each parser returns immediately for months outside its range.
        """
        assert os.path.isfile(src_file)
        assert os.path.isdir(dest_dir)
        self.__source_v2_xls_to_csv_single(src_file, dest_dir, date)
        self.__source_v1_xls_to_csv_single(src_file, dest_dir, date)

    def __source_v2_xls_to_csv_single(self, src_file, dest_dir, date):
        """Parse a 2003-06-or-later report: one 15-column sheet, data from row 13."""
        if date < self._V2_FIRST_MONTH:
            return
        book = xlrd.open_workbook(src_file)
        sheet = book.sheet_by_index(0)
        # Use == for ints; `is` only works by accident of CPython caching.
        assert sheet.ncols == 15
        assert sheet.cell(12, 14).value == 'Days'
        assert sheet.cell(12, 0).value.strip() == 'Month'

        rows = []
        for r in self.__build_sheet_records(sheet, 13):
            r = self.__remove_comment_mark([date.strftime('%Y-%m-%d')] + r)
            assert len(r) == 17
            rows.append(r)
        self.__write_csv(dest_dir, date, rows)

    def __source_v1_xls_to_csv_single(self, src_file, dest_dir, date):
        """Parse a 2000-10 .. 2003-05 report whose columns span two sheets.

        Sheet 0 holds the first 12 columns and sheet 1 the rest.  Rows are
        paired by position and must describe the same period.
        """
        if date >= self._V2_FIRST_MONTH or date <= self._V1_LAST_SKIPPED_MONTH:
            return
        book = xlrd.open_workbook(src_file)
        main_sheet = book.sheet_by_index(0)
        assert main_sheet.ncols == 12

        if date > self._V1_HEADER_MOVE_MONTH:
            assert main_sheet.cell(12, 0).value.strip() == 'Month'
        else:
            # Up to 2001-06 the header sits on row 11 and row 12 is blank.
            assert main_sheet.cell(11, 0).value.strip() == 'Month'
            assert main_sheet.cell(12, 0).value.strip() == ''
        main_records = self.__build_sheet_records(main_sheet, 13)

        rest_sheet = book.sheet_by_index(1)
        assert rest_sheet.ncols == 13
        assert rest_sheet.cell(10, 0).value.strip() == 'Month'
        rest_records = self.__build_sheet_records(rest_sheet, 11)

        assert len(main_records) == len(rest_records)

        rows = []
        for main, rest in zip(main_records, rest_records):
            assert len(main) == 13
            assert len(rest) == 14
            # Same activity date and same 'yearly'/'monthly' type.
            assert main[:2] == rest[:2]
            r = [date.strftime('%Y-%m-%d')] + main[:-2] + rest[2:6] + rest[-2:-1]
            r = self.__remove_comment_mark(r)
            assert len(r) == 17
            rows.append(r)
        self.__write_csv(dest_dir, date, rows)

    def __write_csv(self, dest_dir, month, rows):
        """Write ``rows`` to '<dest_dir>/YYYY-MM.csv' for ``month``, closing the file."""
        dest_file = self.get_filename(dest_dir, month, 'csv')
        with open(dest_file, 'w', newline='') as fd:
            csv_writer = csv.writer(fd)
            for r in rows:
                csv_writer.writerow(r)
                self.LOGGER.debug('''%s => %s''' % (r, dest_file))

    def __build_sheet_records(self, sheet, begin_row):
        """Collect the yearly and monthly rows of ``sheet`` from ``begin_row`` on.

        Column 0 labels each row.  A yearly row looks like '93(2004)'.  A
        monthly row looks like '95年  1月', or just ' 2月', in which case the
        year of the last labelled month is reused.  ROC years are converted
        to AD by adding 1911.

        Each kept row becomes
        ``[activity_date, 'yearly' | 'monthly'] + remaining cell values``.
        activity_date is 'YYYY-01-01' for yearly rows and 'YYYY-MM-01' for
        monthly rows.  Rows ending in ')月' (partial-year summaries) are
        skipped, and scanning stops at the footnote rows starting with '註'.
        """
        rv = []

        monthly_curr_year = ''
        for curr_row in range(begin_row, sheet.nrows):
            r = sheet.row_values(curr_row)
            first_cell = r[0].strip()

            if first_cell.startswith('註'):  # Footnotes: end of the data.
                break
            if first_cell.endswith(')月'):  # Partial current-year summary.
                continue
            if first_cell.endswith(')'):  # Yearly record, e.g. 93(2004).
                curr_date = '''%s-01-01''' % first_cell[first_cell.index('(')+1 : -1]
                sheet_record = [curr_date, 'yearly'] + r[1:]
                rv.append(sheet_record)
            if first_cell.endswith('月'):  # Monthly record, e.g. 95年  1月.
                curr_month = 0
                if '年' in first_cell:
                    monthly_curr_year = int(first_cell[:first_cell.index('年')]) + 1911
                    curr_month = int(first_cell[first_cell.index('年')+1 : first_cell.index('月')])
                else:
                    curr_month = int(first_cell[:first_cell.index('月')])
                curr_date = '''%s-%02d-01''' % (monthly_curr_year, curr_month)
                sheet_record = [curr_date, 'monthly'] + r[1:]
                rv.append(sheet_record)
        return rv

    def __remove_comment_mark(self, csv_record):
        """Strip footnote marks from the numeric columns of a CSV record.

        The first three columns are copied unchanged.  Each later value that
        float() accepts is kept as-is.  Otherwise its last space-separated
        token is kept with ',' thousands separators removed, and it must then
        be numeric or ValueError is raised.
        """
        rv = csv_record[:3]
        for value in csv_record[3:]:
            try:
                float(value)
                rv.append(value)
            except ValueError:
                # rsplit, unlike rindex, also copes with values that have no
                # space at all (e.g. '1,234'), and strip() ignores trailing
                # blanks.
                fixed_value = value.strip().rsplit(' ', 1)[-1].replace(',', '')
                float(fixed_value)
                rv.append(fixed_value)
        return rv


./src/listed_co_statistics/sourcing.py

# coding: big5

import csv
import logging
import os
import xlrd

from datetime import date
from datetime import datetime

from ..common import sourcing_twse
from ..common import str_util as str_util

class Sourcing(sourcing_twse.SourcingTwse):
    """Source the TWSE "P/E Ratio & Yield of Listed Stocks" monthly report.

    Each workbook row is turned into a 6-column CSV record: report date,
    stock code, latest price, PER, yield, PBR.  These match ``SQL_INSERT``
    and are loaded into the ``ListedCoStatistics`` table.
    """

    # Workbook layout cut-overs, compared against the report month (1st day).
    # From 2007-04: one 10-column table ("v3").
    _V3_FIRST_MONTH = date(2007, 4, 1)
    # 2000-09 .. 2007-03: two tables side by side, 21 columns ("v2").
    # Before 2000-09 ("v1"): two tables side by side; cells may carry extra
    # text and are cleaned by the __fix_* helpers.
    _V2_FIRST_MONTH = date(2000, 9, 1)
    # v1 header position varies by month.
    _V1_ODD_HEADER_MONTH = date(2000, 5, 1)
    _V1_LOW_HEADER_LAST_MONTH = date(1999, 7, 1)

    def __init__(self):
        # The base class sets LOGGER, DATES and DB_FILE; only the
        # report-specific settings are overridden here.
        super().__init__()
        self.URL_TEMPLATE = '''http://www.twse.com.tw/ch/inc/download.php?l1=Listed+Companies+Monthly+Statistics&l2=P%%2FE+Ratio+%%26+Yield+of+Listed+Stocks&url=/ch/statistics/download/04/001/%s_C04001.zip'''
        self.ZIP_DIR = '''./dataset/listed_co_statistics/zip/'''
        self.XLS_DIR = '''./dataset/listed_co_statistics/xls/'''
        self.CSV_DIR = '''./dataset/listed_co_statistics/csv/'''
        self.SQL_INSERT = '''insert or ignore into ListedCoStatistics(
                report_date,
                stock_code,
                latest_price,
                per,
                yield,
                pbr
            ) values(?, ?, ?, ?, ?, ?)'''

    def source(self, begin_date, end_date):
        """Run the whole pipeline for every month from begin_date to end_date.

        Both arguments are 'YYYY-MM-DD' strings; only year and month are used.
        """
        self.init_dates(begin_date, end_date)
        self.source_url_to_zip(self.ZIP_DIR)
        self.source_zip_to_xls(self.ZIP_DIR, self.XLS_DIR)
        self.source_xls_to_csv(self.XLS_DIR, self.CSV_DIR)
        self.source_csv_to_sqlite(self.CSV_DIR, self.DB_FILE, self.SQL_INSERT)

    def source_xls_to_csv(self, src_dir, dest_dir):
        """Convert the XLS file of every month in ``self.DATES``, newest first."""
        assert os.path.isdir(src_dir)
        if not os.path.exists(dest_dir):
            os.makedirs(dest_dir)
        for month in reversed(self.DATES):
            src_file = self.get_filename(src_dir, month, 'xls')
            self.source_xls_to_csv_single(src_file, dest_dir, month)

    # CSV fields: Report Date, Stock's Code, Latest Price, PER, Yield, PBR
    def source_xls_to_csv_single(self, src_file, dest_dir, date):
        """Convert one workbook with whichever layout parser covers ``date``.

        Each parser returns immediately for months outside its range.
        """
        assert os.path.isfile(src_file)
        assert os.path.isdir(dest_dir)
        self.__source_v3_xls_to_csv_single(src_file, dest_dir, date)
        self.__source_v2_xls_to_csv_single(src_file, dest_dir, date)
        self.__source_v1_xls_to_csv_single(src_file, dest_dir, date)

    def __source_v3_xls_to_csv_single(self, src_file, dest_dir, date):
        """Parse a 2007-04-or-later report: one 10-column table, data from row 5."""
        if date < self._V3_FIRST_MONTH:
            return

        book = xlrd.open_workbook(src_file)
        sheet = book.sheet_by_index(0)
        # Use == for ints; `is` only works by accident of CPython caching.
        assert sheet.ncols == 10
        assert sheet.cell(4, 0).value.strip() == 'Code & Name'
        # NOTE(review): both entries read 'PBR'; an alternate spelling may
        # have been lost when this code was copied -- confirm.
        assert sheet.cell(4, 8).value.strip() in ('PBR', 'PBR')

        self.__write_csv(dest_dir, date, self.__build_sheet_records(sheet, 0, 5))

    def __source_v2_xls_to_csv_single(self, src_file, dest_dir, date):
        """Parse a 2000-09 .. 2007-03 report: two tables at columns 0 and 11."""
        if date >= self._V3_FIRST_MONTH or date < self._V2_FIRST_MONTH:
            return

        book = xlrd.open_workbook(src_file)
        sheet = book.sheet_by_index(0)
        assert sheet.ncols == 21
        assert sheet.cell(4, 0).value.strip() in ('Code & Name', 'CODE & NAME')
        assert sheet.cell(4, 11).value.strip() in ('Code & Name', 'CODE & NAME')
        assert sheet.cell(4, 8).value.strip() in ('PBR', 'PBR')
        assert sheet.cell(4, 19).value.strip() in ('PBR', 'PBR')

        # Left table first, then the right one.
        records = list(self.__build_sheet_records(sheet, 0, 5))
        records += self.__build_sheet_records(sheet, 11, 5)
        self.__write_csv(dest_dir, date, records)

    def __source_v1_xls_to_csv_single(self, src_file, dest_dir, date):
        """Parse a pre-2000-09 report: two tables at columns 0 and 6, untidy cells."""
        if date >= self._V2_FIRST_MONTH:
            return

        book = xlrd.open_workbook(src_file)
        sheet = book.sheet_by_index(0)
        if date == self._V1_ODD_HEADER_MONTH:
            header_last_row = 5
        elif date <= self._V1_LOW_HEADER_LAST_MONTH:
            header_last_row = 8
        else:
            header_last_row = 4

        assert sheet.ncols in (17, 11)
        assert sheet.cell(header_last_row, 0).value.strip() in ('Code & Name', 'CODE & NAME')
        assert sheet.cell(header_last_row, 6).value.strip() in ('Code & Name', 'CODE & NAME')
        assert sheet.cell(header_last_row, 4).value.strip() in ('PBR', 'PBR')
        assert sheet.cell(header_last_row, 10).value.strip() in ('PBR', 'PBR')

        begin_row = header_last_row + 1
        records = list(self.__build_bad_sheet_records(sheet, 0, begin_row))
        records += self.__build_bad_sheet_records(sheet, 6, begin_row)
        self.__write_csv(dest_dir, date, records)

    def __write_csv(self, dest_dir, month, records):
        """Prefix each record with the report date and write '<dest_dir>/YYYY-MM.csv'.

        Every record must have 5 fields, so each written row has 6.  The file
        is closed even if a row fails that check.
        """
        report_date = month.strftime('%Y-%m-%d')
        dest_file = self.get_filename(dest_dir, month, 'csv')
        with open(dest_file, 'w', newline='') as fd:
            csv_writer = csv.writer(fd)
            for record in records:
                r = [report_date] + record
                assert len(r) == 6
                csv_writer.writerow(r)
                self.LOGGER.debug('''%s => %s''' % (r, dest_file))

    def __build_sheet_records(self, sheet, begin_col, begin_row):
        """Yield [code, price, PER, yield, PBR] for one table of a v2/v3 sheet.

        The values are read from columns begin_col + 0/3/5/7/9.  Rows without
        a code, or with all four values empty, are skipped.  Numeric codes
        become ints and text codes lose their spaces.
        """
        for curr_row in range(begin_row, sheet.nrows):
            r = sheet.row_values(curr_row)
            code = r[begin_col]
            if code == '':
                continue
            values = [r[begin_col + offset] for offset in (3, 5, 7, 9)]
            if all(value == '' for value in values):
                continue
            if isinstance(code, float):
                code = int(code)
            elif isinstance(code, str):
                code = code.replace(' ', '')
            yield [code] + values

    def __build_bad_sheet_records(self, sheet, begin_col, begin_row):
        """Yield cleaned [code, price, PER, yield, PBR] for one table of a v1 sheet.

        The values are read from columns begin_col .. begin_col + 4.  Rows
        without a code, or with all four values unparseable, are skipped.
        """
        for curr_row in range(begin_row, sheet.nrows):
            r = sheet.row_values(curr_row)
            stock_code = self.__fix_stock_code(r[begin_col])
            latest_price = self.__fix_real_number(r[begin_col + 1])
            per = self.__fix_real_number(r[begin_col + 2])
            dividend_yield = self.__fix_real_number(r[begin_col + 3])
            pbr = self.__fix_real_number(r[begin_col + 4])

            if stock_code == '':
                continue
            if latest_price == '' and per == '' and dividend_yield == '' and pbr == '':
                continue
            yield [stock_code, latest_price, per, dividend_yield, pbr]

    def __fix_stock_code(self, bad_stock_code):
        """Extract the stock code from a v1 'code & name' cell.

        A numeric cell is returned as an int, as in __build_sheet_records.
        For text, spaces are removed.  If the first four remaining characters
        are digits, they are returned; otherwise the whole space-free text is
        returned.
        """
        if isinstance(bad_stock_code, float):
            # Numeric cells have no .replace(); treat them like v2/v3 codes.
            return int(bad_stock_code)
        space_removed = bad_stock_code.replace(' ', '')
        stock_code = space_removed[0:4]
        if stock_code.isdigit():  # Quickly get possible stock_code
            return stock_code
        return space_removed

    def __fix_real_number(self, bad_str):
        """Return ``bad_str`` as a float, or its first float-like token, or ''."""
        if str_util.is_float(bad_str):
            return float(bad_str)
        assert str_util.is_str(bad_str)
        for test_str in bad_str.split():
            if str_util.is_float(test_str):
                return float(test_str)
        return ''


./src/common/str_util.py

def is_float(test_str):
    """Return True if ``float(test_str)`` succeeds, else False.

    Accepts numbers as well as numeric strings such as '1.5' or ' 2 '.
    Values that float() rejects by type, e.g. None or a list, now return
    False instead of raising TypeError.
    """
    try:
        float(test_str)
    except (TypeError, ValueError):
        return False
    return True

def is_str(test_str):
    """Return True if ``str(test_str)`` succeeds, else False.

    This checks convertibility, not type: numbers and None return True.
    A failing conversion, e.g. an object whose __str__ returns a non-string,
    raises TypeError rather than ValueError, so both are caught.
    """
    try:
        str(test_str)
    except (TypeError, ValueError):
        return False
    return True


./src/common/date_util.py

import datetime

def get_last_month(today=None):
    """Return the first day of the month before ``today``'s month.

    ``today`` defaults to ``datetime.date.today()``.  Pass any date to get
    the result relative to that day instead.  The result is always a
    ``datetime.date``.
    """
    if today is None:
        today = datetime.date.today()
    first = datetime.date(day=1, month=today.month, year=today.year)
    # One day before the 1st of this month is always in the previous month.
    last_month = first - datetime.timedelta(days=1)
    return datetime.date(day=1, month=last_month.month, year=last_month.year)

def get_this_month(today=None):
    """Return the first day of ``today``'s month as a ``datetime.date``.

    ``today`` defaults to ``datetime.date.today()``.
    """
    if today is None:
        today = datetime.date.today()
    return datetime.date(day=1, month=today.month, year=today.year)

def get_yesterday(today=None):
    """Return the day before ``today`` (default ``datetime.date.today()``)."""
    if today is None:
        today = datetime.date.today()
    return today - datetime.timedelta(days=1)


SQLite3 Schema:

-- Market-wide figures from the TWSE "Statistics of Securities Market"
-- monthly report; one row per (report month, period, yearly/monthly).
create table if not exists MarketStatistics
(
    creation_dt datetime default current_timestamp,
    -- First day of the report month, 'YYYY-MM-DD'.
    report_date datetime not null,
    -- Period described by the row: 'YYYY-01-01' for yearly rows,
    -- 'YYYY-MM-01' for monthly rows.
    activity_date datetime not null,
    -- 'yearly' or 'monthly'.
    report_type text not null,
    total_trading_value real,
    listed_co_number real,
    capital_issued real,
    total_listed_shares real,
    market_capitalization real,
    trading_volume real,
    trading_value real,
    trans_number real,
    average_taiex real,
    volume_turnover_rate real,
    per real,
    dividend_yield real,
    pbr real,
    trading_days int,
    -- Re-loading the same CSV is harmless: duplicate keys are ignored.
    unique (report_date, activity_date, report_type) on conflict ignore
);

-- Per-stock figures from the TWSE "P/E Ratio & Yield of Listed Stocks"
-- monthly report; one row per (report month, stock).
create table if not exists ListedCoStatistics
(
    creation_dt datetime default current_timestamp,
    -- First day of the report month, 'YYYY-MM-DD'.
    report_date datetime not null,
    -- Normally the numeric stock code; for old reports it may hold the
    -- space-free code-and-name text when no 4-digit code was found.
    stock_code text not null,
    latest_price real,
    per real,
    yield real,
    pbr real,
    -- Re-loading the same CSV is harmless: duplicate keys are ignored.
    unique (report_date, stock_code) on conflict ignore
);


最後是 sourcing 操作界面,原來台塑也有天天都便宜的時候。

./source_market_statistics.py

import logging
import sys

import src.market_statistics.sourcing as sourcing
import src.common.logger as logger
import src.common.date_util as date_util

# Earliest month passed to Sourcing.source() / init_dates() by this script.
FIRST_DAY = '1999-01-01'

def source_all():
    """Source every monthly report from FIRST_DAY through last month."""
    logger.config_root(level=logging.DEBUG)
    until = str(date_util.get_last_month())
    sourcer = sourcing.Sourcing()
    sourcer.source(FIRST_DAY, until)

def source_last_month():
    """Source only last month's report."""
    logger.config_root(level=logging.DEBUG)
    month = str(date_util.get_last_month())
    sourcing.Sourcing().source(month, month)
 
def source_csv_to_sqlite_all():
    """Reload every existing CSV file (FIRST_DAY .. last month) into SQLite.

    Download, extraction and XLS parsing are skipped.
    """
    logger.config_root(level=logging.DEBUG)
    until = str(date_util.get_last_month())
    sourcer = sourcing.Sourcing()
    sourcer.init_dates(FIRST_DAY, until)
    sourcer.source_csv_to_sqlite(sourcer.CSV_DIR, sourcer.DB_FILE, sourcer.SQL_INSERT)

def test(begin_date=None, end_date=None):
    """Manual debugging hook: source an arbitrary month range.

    With no arguments it only configures logging and builds a Sourcing
    object, as the old scratch version did.  ``end_date`` defaults to
    ``begin_date`` (a single month).  Handy ranges:
        test('1999-01-01', '2012-09-01')  # whole history
        test('2012-09-01')                # a recent month
        test('2003-05-01')                # last month of the two-sheet layout
    """
    logger.config_root(level=logging.DEBUG)
    s = sourcing.Sourcing()
    if begin_date is not None:
        s.source(begin_date, begin_date if end_date is None else end_date)

 
 
def main():
    # Default entry point: source only last month's report.
    source_last_month()

if __name__ == '__main__':
    # main() returns None, so sys.exit() ends the process with status 0.
    sys.exit(main())


./source_listed_co_statistics.py

import logging
import sys

import src.listed_co_statistics.sourcing as sourcing
import src.common.logger as logger
import src.common.date_util as date_util

# Earliest month passed to Sourcing.source() / init_dates() by this script.
FIRST_DAY = '1999-03-01'

def source_all():
    """Source every monthly report from FIRST_DAY through last month."""
    logger.config_root(level=logging.DEBUG)
    until = str(date_util.get_last_month())
    sourcer = sourcing.Sourcing()
    sourcer.source(FIRST_DAY, until)

def source_last_month():
    """Source only last month's report."""
    logger.config_root(level=logging.DEBUG)
    month = str(date_util.get_last_month())
    sourcing.Sourcing().source(month, month)
 
def source_csv_to_sqlite_all():
    """Reload every existing CSV file (FIRST_DAY .. last month) into SQLite.

    Download, extraction and XLS parsing are skipped.
    """
    logger.config_root(level=logging.DEBUG)
    until = str(date_util.get_last_month())
    sourcer = sourcing.Sourcing()
    sourcer.init_dates(FIRST_DAY, until)
    sourcer.source_csv_to_sqlite(sourcer.CSV_DIR, sourcer.DB_FILE, sourcer.SQL_INSERT)

def test(begin_date=None, end_date=None):
    """Manual debugging hook: source an arbitrary month range.

    With no arguments it only configures logging and builds a Sourcing
    object, as the old scratch version did.  ``end_date`` defaults to
    ``begin_date`` (a single month).  Handy ranges:
        test('2000-09-01')  # earliest report in the 21-column layout
        test('2000-08-01')  # newest report in the older, untidy layout
        test('2012-09-01')  # a recent month
    """
    logger.config_root(level=logging.DEBUG)
    s = sourcing.Sourcing()
    if begin_date is not None:
        s.source(begin_date, begin_date if end_date is None else end_date)
 
 
 
def main():
    # Default entry point: source only last month's report.
    source_last_month()

if __name__ == '__main__':
    # main() returns None, so sys.exit() ends the process with status 0.
    sys.exit(main())



沒有留言:

張貼留言