在Rust中使用TensorFlow

15 8月

本文版权归作者所有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。

转载自夜明的孤行灯

本文链接地址: https://www.huangyunkun.com/2022/08/15/use-tensorflow-with-rust/



TensorFlow是一个开源的机器学习框架,上手简单,且在大量商业场景使用,可靠性高。

机器学习常用的语言主要是python、c,官方也提供了Java版本sdk,示例也很丰富。对于Rust语言,虽然官方文档没有写,其实官方也提供了rust版本的绑定,项目名称就是rust。

在Cargo.toml中配置版本,这里的0.18.0对应TensorFlow 2.8.0

tensorflow = "0.18.0"

这个依赖中的-sys模块会根据实际情况进行处理,如果本地没有Tensorflow的库文件,就会自动下载。如果没有预编译好的库文件,或者用户指定要从源码编译,那就会在本地编译。

我这里使用的是frozen graph,在rust代码中新建一个Graph进行加载即可。

let mut graph = Graph::new();
        let model_file = MODEL_DIR
            .get_file("mobilenet/mobilenet_v2_1.4_224_frozen.pb")
            .unwrap();
        let label_file = MODEL_DIR.get_file("mobilenet/label.txt").unwrap();
        graph
            .import_graph_def(model_file.contents(), &ImportGraphDefOptions::new())
            .unwrap();

为了最后打包方便,这里使用了include_dir包,会把模型和最终产物打包到一起。

这里的模型是一个mobilenet的图片分类模型,输入为224像素。图片的处理使用image包。

let img = image::open(photo.get_store_path())?;
        let resized = image::imageops::thumbnail(&img, 224, 224);

        let mut flattened: Vec<f32> = Vec::new();
        for rgb in resized.pixels() {
            flattened.push(rgb[0] as f32 / 255.);
            flattened.push(rgb[1] as f32 / 255.);
            flattened.push(rgb[2] as f32 / 255.);
        }

        {
            let input = Tensor::new(&[1, 224, 224, 3])
                .with_values(&flattened)
                .unwrap();

            let session = Session::new(&SessionOptions::new(), &graph)?;
            let mut args = SessionRunArgs::new();

            args.add_feed(
                &graph.operation_by_name_required("input").unwrap(),
                0,
                &input,
            );
            let prediction = args.request_fetch(
                &graph.operation_by_name_required("MobilenetV2/Predictions/Softmax")?,
                0,
            );

            session.run(&mut args).unwrap();

            let prediction_res: Tensor<f32> = args.fetch(prediction)?;

            let mut i = 0;
            let mut json_vec: Vec<f32> = Vec::new();
            while i < prediction_res.len() {
                json_vec.push(prediction_res[i]);
                i += 1;
            }

            let label = labels.get(imax(&json_vec).unwrap()).unwrap();

            return Ok(label.to_string());
        }

整体使用还是很简单的,只是部分文档缺失,需要研究下。如果有其他语言的Tensorflow使用经验会好很多。

目前Tensorflow Hub做的很好了,模型可以直接从上面下载,Savemodel也可以直接加载,也可以转为frozen后使用。



本文版权归作者所有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。

转载自夜明的孤行灯

本文链接地址: https://www.huangyunkun.com/2022/08/15/use-tensorflow-with-rust/

发表回复

您的电子邮箱地址不会被公开。