博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
修改pytorch和Keras预训练模型路径
阅读量:4154 次
发布时间:2019-05-25

本文共 2141 字,大约阅读时间需要 7 分钟。

目录

1、Pytorch预训练模型路径修改

Pytorch安装目录下有一个hub.py,改文件指定了预训练模型的加载位置。该文件存在于xxx\site-packages\torch,例如我的存在于“C:\ProgramData\Miniconda3\Lib\site-packages\torch”。

打开hub.py文件,找到load_state_dict_from_url函数,其中第二个参数
model_dir用于指定权重文件路径:model_dir (string, optional): directory in which to save the object。将该参数值由None改为权重文件位置即可,例如model_dir=‘D:/Models_Download/torch’。

def load_state_dict_from_url(url, model_dir='D:/Models_Download/torch', map_location=None, progress=True, check_hash=False, file_name=None):    r"""Loads the Torch serialized object at the given URL.    If downloaded file is a zip file, it will be automatically    decompressed.    If the object is already present in `model_dir`, it's deserialized and    returned.    The default value of `model_dir` is ``
/checkpoints`` where `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. Args: url (string): URL of the object to download model_dir (string, optional): directory in which to save the object map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) progress (bool, optional): whether or not to display a progress bar to stderr. Default: True check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention ``filename-
.ext`` where ``
`` is the first eight or more digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set. Example: >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') """

2、Keras修改预训练模型位置

Keras安装路径内并没有一个文件来定义预训练模型位置,我只能在调用预训练模型的时候指定模型文件的路径(有没有更好的设置方法?)。

base_model = vgg19.VGG19(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), include_top=False,                          weights='D:\\Models_Download\\keras\\vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')

转载地址:http://farti.baihongyu.com/

你可能感兴趣的文章
Build GingerBread on 32 bit machine.
查看>>
How to make SD Card world wide writable
查看>>
Detecting Memory Leaks in Kernel
查看>>
Linux initial RAM disk (initrd) overview
查看>>
Timestamping Linux kernel printk output in dmesg for fun and profit
查看>>
There's Much More than Intel/AMD Inside
查看>>
CentOS7 安装MySQL 5.6.43
查看>>
使用Java 导入/导出 Excel ----Jakarta POI
查看>>
本地tomcat 服务器内存不足
查看>>
IntelliJ IDAE 2018.2 汉化
查看>>
基于S5PV210的uboot移植中遇到的若干问题记录(一)DM9000网卡移植
查看>>
Openwrt源码下载与编译
查看>>
我和ip_conntrack不得不说的一些事
查看>>
Linux 查看端口使用情况
查看>>
文件隐藏
查看>>
两个linux内核rootkit--之二:adore-ng
查看>>
两个linux内核rootkit--之一:enyelkm
查看>>
关于linux栈的一个深层次的问题
查看>>
rootkit related
查看>>
配置文件的重要性------轻化操作
查看>>