この記事は、 NTT docomo Business Advent Calendar 2025 7日目の記事です。 こんにちは。イノベーションセンターの加藤です。普段はコンピュータビジョンの技術開発やAI/機械学習(ML)システムの検証に取り組んでいます。 ディープラーニングの実装をしているときに、変数のshapeを管理するのはなかなか大変です。いつのまにか次元が増えていたり、想定外のshapeがやってきたりして実行時に落ちてしまった!というのは日常茶飯事だと思います。 こういった問題に対して静的解析で何とかできないかと試行錯誤した結果を共有します。 mypyプラグインを使うモチベーション mypyプラグインの作成 初期化 jaxtyping annotationを拾う テンソル作成関数を拾う テンソル計算 ここで限界が来た(Future work) 変数付きのshape記法 次元の四則演算 stubの是非 レイヤーの型注釈 まとめ mypyプラグインを使うモチベーション Pythonのプログラムを静的検査する方法のひとつに mypy があります。これはソースコードにつけられた型アノテーションに矛盾がないか調べてくれるもので、変な代入や演算由来のエラーを未然に防ぐことができます。 しかしながら、NumPyやPyTorchなどの一般的な数値計算ライブラリにはそれなりの型がアノテーションされているものの、せいぜいテンソルの型(intやfloatなど)どまりで次元(shape)については考慮されていないため、そのままでは次元の不一致などを検出できません。 jaxtyping などのライブラリは元の型を拡張して次元などをアノテーションできるようにしてくれますが、これらは実行時解析のみをサポートしており、mypyからは扱えません。 from torch import Tensor import torch from jaxtyping import Float32, jaxtyped from beartype import beartype as typechecker from typing_extensions import reveal_type @ jaxtyped (typechecker=typechecker) def f (x: Float32[Tensor, "1 224 224" ]) -> Float32[Tensor, "1 1000" ]: print ( "processing f" ) w = torch.randn( 1000 , 224 * 224 ) x_flat = x.view( 1 , 224 * 224 ) y = x_flat @ w.t() return y.view( 1 , 1000 ) x: Float32[Tensor, "1 224 224" ] = torch.randn( 1 , 224 , 224 ) y = torch.randn( 1 , 224 , 225 ) print ( "f(x)" ) reveal_type(f(x)) # OK print ( "f(y)" ) reveal_type(f(y)) # NG """ 実行時は引数に誤ったshapeを渡した時点でエラー > python .\example.py f(x) processing f Runtime type is 'Tensor' f(y) Traceback (most recent call last): ... しかしmypyでは検出できない > mypy .\example.py example.py:19: note: Revealed type is "torch._tensor.Tensor" example.py:20: note: Revealed type is "torch._tensor.Tensor" Success: no issues found in 1 source file """ 結局プログラミングの段階ではあくまで可読性を高めるための注釈に留まり、実行時はお祈りしながら終了を待つことになります。 そこで本稿ではmypyプラグインを実装してjaxtypingの型に対する処理を追加することで、次元の整合性を実行前に検証できないかトライしてみました。もしこれができれば、mypyを使って次元込みの静的検査ができ、Visual Studio Codeのmypy拡張と連携すればプログラミング中もテンソルの次元を追うことができるようになります。 mypyプラグインの作成 初期化 uv でプロジェクトを新規作成します。 $ uv init --name jaxmy --lib Initialized project `jaxmy` $ uv add mypy jaxtyping $ uv add torch numpy pytest --optional tests src/jaxmy/mypy_plugin.py にプラグインスクリプトを作成します。 from typing import Any, Optional, List, Tuple import re from mypy.plugin import Plugin class ShapePlugin (Plugin): pass # TODO def plugin (version: str ): print ( "Hello world! version:" , version) return ShapePlugin そしてmypy実行時に自作のpluginを紐づけるには以下のようにpyprojectを編集します。 [build-system] requires = [ "hatchling" ] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = [ "src/jaxmy" ] [tool.mypy] ignore_missing_imports = true plugins = [ "jaxmy.mypy_plugin" ] mypy_path = "$MYPY_CONFIG_FILE_DIR/src/jaxmy/stubs" (mypy_pathについては後述) これでuv環境からmypyを実行するとpluginが介入するようになります。 > uv run mypy example.py Hello world! version: 1 . 18 . 2 jaxtyping annotationを拾う まずはjaxtyping記法によるアノテーションを拾うところから始めます。 jaxtypingが提供する型の実体はmypyなどの静的解析時( typing.TYPE_CHECKING == True )と実行時で異なっており、静的解析時は jaxtyping._indirection 内で定義された以下のコードが読み込まれます。 from typing import ( Annotated as BFloat16, # noqa: F401 Annotated as Bool, # noqa: F401 Annotated as Complex, # noqa: F401 Annotated as Complex64, # noqa: F401 Annotated as Complex128, # noqa: F401 Annotated as Float, # noqa: F401 ... そのためあらゆる型は Annotated[T, ...] とみなされ、これは事前に T と解決してから静的解析が走ります。 これによってjaxtyping記法が型として正しくない表記であるにもかかわらずエディタやmypyのチェックをすり抜けているのですが、 Float32[Tensor, "1 224 224"] や Int8[Tensor, "3 224 224"] などがすべて Tensor という同じ型に置き換えられてしまうため静的解析が不可能になります。 そこで、stubを注入して呼び出しを捕捉することでプラグインから触れるようにします。 # stub/jaxtyping/__init__.pyi from typing import Any, Literal, NoReturn, Union, TypeVar, Generic _ArrayType = TypeVar( "_ArrayType" ) _Shape = TypeVar( "_Shape" ) class AbstractArray (Generic[_ArrayType, _Shape]): pass class UInt2 (AbstractArray[_ArrayType, _Shape]): ... class UInt4 (AbstractArray[_ArrayType, _Shape]): ... class UInt8 (AbstractArray[_ArrayType, _Shape]): ... """以下よしなに""" これでプラグインからは get_type_analyze_hook を通して jaxtyping.Float32 などのアノテーションを拾えるようになりました。ただしjaxtyping記法はshapeの部分が型注釈として許されない文字列リテラルであるため、これを有効な型に置き換える必要があります。これを怠るとmypyがshape部分を Any に置き換えてしまいます。 from typing import Any, Optional, List, Tuple import re from mypy.plugin import Plugin, FunctionContext, AnalyzeTypeContext from mypy.types import Instance, TupleType, Type, UnboundType, LiteralType, EllipsisType, RawExpressionType, TypeStrVisitor from mypy.checker import TypeChecker def parse_dimstr (dimstr: str ) -> Optional[List[ int ]]: """Parse a dimension string like "1 3 224 224" into a list of int.""" dims: List[ int ] = [] for dim in dimstr.split( " " ): dim = dim.strip() if dim.isdigit(): dims.append( int (dim)) else : return None return dims def dump_dimlist (dimlist: List[ int ]) -> str : """Dump a list of int back into a dimension string.""" dimstrs: List[ str ] = [] for dim in dimlist: dimstrs.append( str (dim)) return " " .join(dimstrs) def construct_instance (api: TypeAnalyzerPluginInterface, dtype: str , backend: Type, dim_list: List[ int ]) -> Type: """Construct an Instance of a jaxtyping type with the given dtype, backend, and shape.""" # TODO : 本当はFloatなどのUnion型にも対応すべきだが、とりあえず保留 # shape表現をLiteralで包みjaxtypingのinstanceを返す。 return Instance( api.named_type(f "jaxtyping.{dtype}" ).type, [backend, LiteralType(value=dump_dimlist(dim_list), fallback=api.named_type( "builtins.str" ))] ) def analyze_jaxtyping (ctx: AnalyzeTypeContext) -> Type: """Parse Dtype[Array, "shape"] to the mypy-friendly type Dtype[Array, Literal["shape"]].""" typ = ctx.type # UnboundType. 何のことかはわからない if len (typ.args) != 2 : return typ backend, shape = typ.args backend = ctx.api.analyze_type(backend) # UnboundTypeなbackendを解決 (Tensorなどのinstanceになる) if not isinstance (shape, RawExpressionType) or type (shape.literal_value) is not str : return backend # fallback dtype = typ.name # e.g., "Float32" dim_str = shape.literal_value # e.g., "1 224 224" dim_list = parse_dimstr(dim_str) # validationもかねてパース if dim_list is None : return backend # fallback return construct_instance(ctx.api, dtype, backend, dim_list) DTYPE_ANNOTS = { "UInt2" , "UInt4" , "UInt8" , ...} class ShapePlugin (Plugin): def get_type_analyze_hook (self, fullname: str ): m = re.match( r"jaxtyping\.(\w+)" , fullname) if m and m.group( 1 ) in DTYPE_ANNOTS: return analyze_jaxtyping def plugin (version: str ): return ShapePlugin 今回はjaxtypingのサブセットとして数値リテラルのみ(例: Float32[Tensor, "1 3 224 224"] )をサポートします。内部的には基本リテラルで持ち( Float32[Tensor, Literal["1 3 224 224"]] )、都度バラして型推論を行います。 ※ ちなみに Float32[Tensor, Literal["1 3 224 224"]] よりも取り回しのよい内部表現を使う手もありますが、mypyには検査対象のプログラムで呼ばれているモジュール(とビルトイン)しか扱えないという制約があります。そのため、何かいい感じのオリジナル型を導入したい場合はjaxtypingそのものを改造する必要があります。 テンソル作成関数を拾う これに加えて、 torch.zeros() などの初期化用の関数を get_function_hook によって捕捉し、これらのテンソルにjaxtyping用の型を付与します。 まずstubを作成してtorchを扱えるようにします。 # stubs/torch/__init__.pyi from torch._tensor import Tensor as Tensor from typing import Any def randn (*size: int , out= None , dtype= None , **kwargs) -> Tensor: ... def rand (*size: int , out= None , dtype= None , **kwargs) -> Tensor: ... def zeros (*size: int , out= None , dtype= None , **kwargs) -> Tensor: ... def ones (*size: int , out= None , dtype= None , **kwargs) -> Tensor: ... そしてhookを作成します。この手の関数は入力の自由度が高く、引数を手でパースするのがちょっと大変です。 INITIALIZER_NAMES = { "torch.randn" , "torch.rand" , "torch.zeros" , "torch.ones" , } dtype_mapper = { # mapping torch.dtype to jaxtyping type "float32" : "Float32" , "float" : "Float32" , "float64" : "Float64" , "double" : "Float64" , ... } def hook (fullname: str ): if fullname in INITIALIZER_NAMES: return construct_from_shape return None Argument = namedtuple( 'Argument' , [ 'arg_type' , 'arg_kind' , 'arg_name' , 'arg' ]) def transpose_funcargs (ctx: FunctionContext | MethodContext) -> dict [ str , Argument]: """[引数型], [引数名], ... を [(引数型,引数名,...)] にまとめる""" ctxdict = {} for i, name in enumerate (ctx.callee_arg_names): if len (ctx.arg_kinds[i]) == 0 : continue ctxdict[name] = Argument( arg_type=ctx.arg_types[i], arg_kind=ctx.arg_kinds[i], arg_name=ctx.arg_names[i], arg=ctx.args[i] ) return ctxdict def construct_from_shape (ctx: FunctionContext): if not isinstance (ctx.api, TypeChecker): return ctx.default_return_type # 失敗時は基本的にこれを返す ctxdict = transpose_funcargs(ctx) if "size" not in ctxdict: return ctx.default_return_type args = ctxdict[ "size" ].arg_type dimensions: List[Type] = [] # shape指定にはf(1,2,3)とf((1,2,3))の二通りあるので対応 if len (args) == 1 and isinstance (args[ 0 ], TupleType): dimensions.extend(args[ 0 ].items) else : dimensions.extend(args) # すべて数値定数であるときのみ対応する if all (( isinstance (dim, Instance) and dim.last_known_value is not None and type (dim.last_known_value.value) is int ) for dim in dimensions): shape_list = [dim.last_known_value.value for dim in dimensions] if "dtype" in ctxdict: # dtype指定があるとき dtype = ctxdict[ "dtype" ] dtype_argtype = dtype.arg_type[ 0 ] if isinstance (dtype_argtype, Instance) and dtype_argtype.type.fullname in [ "torch.dtype" ]: jaxtype = dtype_mapper.get(dtype.arg[ 0 ].name, None ) if jaxtype is None : ctx.api.fail( f "Unsupported dtype {ctxdict['args'][0].name} for torch function." , ctx.context ) return ctx.default_return_type # 指定の型とshapeからjaxtyping型 DType[Tensor, Literal["shape"]] を作る return construct_instance( ctx.api, jaxtype, ctx.api.named_type( "torch.Tensor" ), shape_list ) else : ctx.api.fail( f "Unsupported dtype {dtype_argtype} for torch function." , ctx.context ) return ctx.default_return_type return construct_instance( # デフォルトdtypeはfloat32 ctx.api, "Float32" , ctx.api.named_type( "torch.Tensor" ), shape_list ) return ctx.default_return_type これで torch.randn などの返り値型がTensorからjaxtypingになりました。 def g (x: Float32[Tensor, "3 224 224" ]): ... x: Float32[Tensor, "3 224 224" ] = torch.randn( 3 , 224 , 224 ) # OK y: Float32[Tensor, "3 224 226" ] = torch.randn( 3 , 224 , 224 ) # Incompatible types in assignment g(x) # OK テンソル計算 つぎはテンソル同士の演算を定義します。考慮すべきことは以下の3つです。 型が異なる時は"偉い"方に合わせる shape不一致の時はエラー shapeのブロードキャスト(片方の次元が1の時はもう片方に合わせてもよい) ですが、いったん型の方は無視します。 まず準備としてテンソルの演算子をstubに定義します。 # stub/jaxtyping/__init__.pyi Self = TypeVar( "Self" , bound= "AbstractArray[_ArrayType, _Shape]" ) class AbstractArray (Generic[_ArrayType, _Shape]): def __add__ (self: Self, other: Any): ... def __radd__ (self: Self, other: Any): ... def __iadd__ (self: Self, other: Any) -> Self: ... def __sub__ (self: Self, other: Any): ... def __rsub__ (self: Self, other: Any): ... def __isub__ (self: Self, other: Any) -> Self: ... def __mul__ (self: Self, other: Any): ... def __rmul__ (self: Self, other: Any): ... def __imul__ (self: Self, other: Any) -> Self: ... そしてこれを get_method_hook で捕捉します。 arithmetic_names = { "__add__" , "__radd__" , "__sub__" , "__rsub__" , "__mul__" , "__rmul__" , "__pow__" , "__div__" , "__rdiv__" , ... } def decompose_instance (typ: Instance) -> Optional[Tuple[ str , Type, List[ int ]]]: """Decompose a jaxtyping type into (backend type, shape as list of ints).""" if len (typ.args) != 2 : return None backend, shape = typ.args if not isinstance (shape, RawExpressionType) or type (shape.literal_value) is not str : return None dtype = typ.name # e.g., "Float32" dim_str = shape.literal_value # e.g, "1 224 224" dim_list = parse_dimstr(dim_str) if dim_list is None : return None return dtype, backend, dim_list def tensor_arithmetic (ctx: MethodContext): self_type = ctx.type other_type = ctx.arg_types[ 0 ][ 0 ] if isinstance (self_type, Instance) and isinstance (other_type, Instance): if self_type.type.fullname.startswith( "jaxtyping." ): self_result = decompose_instance(self_type) if self_result is None : ctx.api.fail( f "Unable to parse Self as jaxtyping {self_type}" , ctx.context ) return ctx.default_return_type self_dtype, self_backend, self_dims = self_result else : ctx.api.fail( f "Self must be jaxtyping {self_type}" , ctx.context ) return ctx.default_return_type if other_type.type.fullname.startswith( "jaxtyping." ): other_result = decompose_instance(other_type) if other_result is None : ctx.api.fail( f "Unable to parse Other as jaxtyping {other_type}" , ctx.context ) return ctx.default_return_type other_dtype, other_backend, other_dims = other_result elif other_type.type.fullname in ( "builtins.int" , "builtins.float" ): other_dtype = self_dtype other_backend = self_backend other_dims = [] # scalar if repr (self_backend) != repr (other_backend): ctx.api.fail( f "Backend mismatch: {self_backend} vs {other_backend}" , ctx.context ) return ctx.default_return_type out_backend = self_backend if self_dtype != other_dtype: ctx.api.fail( f "Dtype mismatch: {self_dtype} vs {other_dtype}" , ctx.context ) return ctx.default_return_type # TODO : promote dtype out_dtype = self_dtype if self_dims == other_dims: out_dims = self_dims else : # broadcast check longest = max ( len (self_dims), len (other_dims)) self_dims = [ 1 ] * (longest - len (self_dims)) + self_dims other_dims = [ 1 ] * (longest - len (other_dims)) + other_dims out_dims = [] for d1, d2 in zip (self_dims, other_dims): if d1 == d2: out_dims.append(d1) elif d1 == 1 : out_dims.append(d2) elif d2 == 1 : out_dims.append(d1) else : ctx.api.msg.fail( f "Shape mismatch: {self_dims} vs {other_dims}" , ctx.context ) return ctx.default_return_type # fail return construct_instance(ctx.api, out_dtype, out_backend, out_dims) ctx.api.fail( f "Unknown types for tensor arithmetic: {self_type} and {other_type}" , ctx.context ) return ctx.default_return_type class ShapePlugin (Plugin): def get_method_hook (self, fullname: str ): if fullname.startswith( "jaxtyping." ): # jaxtyping.Float32.__add__など if fullname.split( "." )[- 1 ] in arithmetic_names: return tensor_arithmetic 注意点として、どうも実行時と同じように __add__ から __radd__ へのフォールバックがなされているらしく、 __add__ の処理で api.fail によるエラーを吐いても、 __radd__ の型チェックが未実装のままだとそちらで解決したことになりエラーが消えてしまうようです。ちゃんと両方処理するか、フォールバック先を無条件でfailさせる必要があります。 これで以下のテストに対応できます。 x: Float32[Tensor, "3 224 224" ] = torch.randn( 3 , 224 , 224 ) # OK y: Float32[Tensor, "3 224 226" ] = torch.randn( 3 , 224 , 226 ) # OK reveal_type(x + x) # OK reveal_type(x * 2.0 ) # OK (scalar) reveal_type(torch.randn( 1 , 224 , 224 ) + x) # OK (broadcasting) reveal_type(x + y) # Shape mismatch: [3, 224, 224] vs [3, 224, 226] ここで限界が来た(Future work) この時点でテンソルの四則演算ができるようになりましたが、ここでギブアップしてしまいました。 実用レベルにするには以下のようにまだまだやるべきことが山のようにあります。 変数付きのshape記法 jaxtypingは"batch 3 height width"のような記法に対応しており、これができれば畳み込みニューラルネットワークなど入力画像のサイズを気にしないものにも型を付けることができます。 次元の四則演算 例えばテンソルを結合したときに次元を足し算したり、upsampleでは掛け算、downsampleでは割り算などをする必要があります。そしてこれは変数を許すと鬼のように難しくなります。 例えばUNetなどは画像をdownsampleしたのちupsampleしますが、downsampleでの割り算は小数切り捨てなのでupsampleしても元に戻るとは限りません。つまり割り算と掛け算を縮約することができないため、 batch 3 height//8*8 width//8*8 のようなshapeが batch 3 height width と一致するかなどの検証をする必要があります。 これはあまりにも辛いので、「 height は8で割り切れる」のような注釈をjaxtypingに新しく設けることで割り算をうまく処理するというのが無難そうです(こうすることでUNetに中途半端なサイズの画像を入れてバグらせるというのも回避できます)。 stubの是非 プラグインがjaxtypingやPyTorchなどのライブラリ由来の型を拾うためにstubを使いましたが、果たしてこの使い方が正しいのかという懸念があります。もっとエレガントな方法はないのでしょうか…… レイヤーの型注釈 PyTorchを扱うからにはnn.Moduleに対応する必要があるでしょう。ですがニューラルネットのあらゆるレイヤーに対してjaxtypingの型検査を実装するというのは骨が折れます。 さらにmypyを基盤にする上でおそらく一番の鬼門は、PyTorchでは一般的な以下のコーディングです。 layers = nn.ModuleList([nn.Linear( 100 , 50 ), nn.ReLU(), nn.Linear( 50 , 10 )]) def forward (x: Tensor): for layer in layers: x = layer(x) return x mypyは変数の再代入があっても 対応できるらしい のですが、果たしてforループが回った後の型はつけられるか怪しいです。 まとめ この記事ではmypyプラグインの機能を利用して、PyTorchのソースコードに次元付きの型注釈がつけられないか挑戦してみました。それなりの機能は持たせられそうでしたが、実用的なレベルまでいけるかどうかは微妙そうです。