Fixing the 503 error when loading the MNIST dataset
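by Qerogram

If you're getting a 503 error when downloading MNIST, just change the value of the base_url variable in the download_mnist function and run the script. It downloads the dataset and saves it as CSV files.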
If you get an import error, install these libraries with pip and run it again:
* requests, numpy, pandas
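Instead of editing base_url by hand whenever a host goes down, you could try several mirrors in order and use the first one that answers. The sketch below is only an illustration under assumptions: the MIRRORS tuple, fetch_with_fallback and _urlopen_bytes are hypothetical names, not part of the original script; the first URL is the one this post uses and the second is assumed to be the original MNIST host, which may be the one returning 503. It uses only the standard library, and the fetch parameter lets you swap in requests or a fake for testing.

```python
import urllib.error
import urllib.request
from typing import Callable, Sequence

# Hypothetical mirror list, tried in order (assumption: both serve the same .gz files).
MIRRORS = (
    "https://ossci-datasets.s3.amazonaws.com/mnist/",  # base_url used in this post
    "http://yann.lecun.com/exdb/mnist/",               # original host (may return 503)
)


def _urlopen_bytes(url: str) -> bytes:
    # urlopen raises urllib.error.HTTPError (a subclass of OSError) on 503 and other error statuses
    with urllib.request.urlopen(url, timeout=30) as resp:
        return resp.read()


def fetch_with_fallback(
    filename: str,
    base_urls: Sequence[str] = MIRRORS,
    fetch: Callable[[str], bytes] = _urlopen_bytes,
) -> bytes:
    """Return the bytes of base_url + filename from the first mirror that succeeds."""
    errors = []
    for base in base_urls:
        try:
            return fetch(base + filename)
        except OSError as exc:  # covers HTTPError, URLError and timeouts
            errors.append(f"{base}: {exc}")
    raise RuntimeError("all mirrors failed:\n" + "\n".join(errors))
```

For example, fetch_with_fallback("train-labels-idx1-ubyte.gz") would return the compressed label file from whichever mirror responds first.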
```python
# Author : qerogram
import gzip
import hashlib
import os
import struct

import numpy as np
import pandas as pd
import requests


def getMd5(data):
    # Only used to build a random temporary file name
    h = hashlib.new("md5")
    h.update(data)
    return h.hexdigest()


def fileDownload(url):
    # Download a .gz file under a random name, then decompress it into "<name>_/<name>"
    filename = getMd5(os.urandom(16))
    res = requests.get(url)
    res.raise_for_status()  # stop with a clear error on 503 and other HTTP failures
    with open(filename, "wb") as f:
        f.write(res.content)
    os.makedirs(filename + "_", exist_ok=True)
    with open(filename + "_/" + filename, "wb") as out_f, gzip.GzipFile(filename) as zip_f:
        out_f.write(zip_f.read())
    return filename


def removeFile(filename):
    # Delete the decompressed file, its folder, and the downloaded .gz file
    os.remove(filename + "_/" + filename)
    os.rmdir(filename + "_")
    os.remove(filename)


def download_mnist(method):
    # Change this value if the server returns 503
    base_url = "https://ossci-datasets.s3.amazonaws.com/mnist/"
    download_link = {
        "train": ("train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz"),
        "test": ("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"),
    }
    images_path, labels_path = download_link[method]

    # Label file: 8-byte big-endian header (magic, count), then one uint8 per label
    label_file_name = fileDownload(base_url + labels_path)
    with open(label_file_name + "_/" + label_file_name, "rb") as lbpath:
        magic, n = struct.unpack(">II", lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    removeFile(label_file_name)

    # Image file: 16-byte header (magic, count, rows, cols), then rows*cols uint8 pixels per image
    image_file_name = fileDownload(base_url + images_path)
    with open(image_file_name + "_/" + image_file_name, "rb") as imgpath:
        magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
        print(f"count of row = {num}, count of column = {rows * cols}")
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(num, rows * cols)
    removeFile(image_file_name)
    return images, labels


if __name__ == "__main__":
    X_train, y_train = download_mnist("train")
    pd.DataFrame(X_train).to_csv("train_dataset.csv", index=False)
    pd.DataFrame(y_train).to_csv("train_label.csv", index=False)

    X_test, y_test = download_mnist("test")
    pd.DataFrame(X_test).to_csv("test_dataset.csv", index=False)
    pd.DataFrame(y_test).to_csv("test_label.csv", index=False)
```
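The script above writes each .gz file and its decompressed copy to disk before parsing it with struct and np.fromfile. If you prefer to skip the temporary files, the same IDX layout (big-endian header, then raw uint8 data) can be parsed directly from the downloaded bytes. This is a sketch, not the author's code: parse_idx_labels and parse_idx_images are hypothetical helpers, and the magic-number checks (2049 for labels, 2051 for images) are an added safety check that the original does not perform.

```python
import gzip
import struct

import numpy as np

# IDX magic numbers used by the MNIST files
LABEL_MAGIC = 2049  # 0x00000801: 1-D array of uint8 labels
IMAGE_MAGIC = 2051  # 0x00000803: 3-D array of uint8 images (count, rows, cols)


def parse_idx_labels(gz_bytes: bytes) -> np.ndarray:
    """Decompress a gzipped IDX label file held in memory and return a 1-D uint8 array."""
    data = gzip.decompress(gz_bytes)
    magic, n = struct.unpack(">II", data[:8])
    if magic != LABEL_MAGIC:
        raise ValueError(f"unexpected label magic number {magic}")
    labels = np.frombuffer(data, dtype=np.uint8, offset=8)
    if labels.size != n:
        raise ValueError(f"header says {n} labels, payload has {labels.size}")
    return labels


def parse_idx_images(gz_bytes: bytes) -> np.ndarray:
    """Decompress a gzipped IDX image file held in memory and return a (count, rows*cols) array."""
    data = gzip.decompress(gz_bytes)
    magic, num, rows, cols = struct.unpack(">IIII", data[:16])
    if magic != IMAGE_MAGIC:
        raise ValueError(f"unexpected image magic number {magic}")
    pixels = np.frombuffer(data, dtype=np.uint8, offset=16)
    # reshape raises ValueError if the payload is not exactly num * rows * cols bytes
    return pixels.reshape(num, rows * cols)
```

With requests, the input would simply be res.content, for example parse_idx_labels(requests.get(url).content). To get 28x28 images back from the flattened rows (or from the saved CSV via pd.read_csv("train_dataset.csv").to_numpy()), reshape with images.reshape(-1, 28, 28).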
Reference
[1] File-parsing code, m.blog.naver.com/PostView.nhn?blogId=msnayana&logNo=220917297905&proxyReferer=https:%2F%2Fwww.google.com%2F